from __future__ import annotations
import os
import typing as t
import warnings
from abc import ABCMeta, abstractproperty
from datetime import datetime, timedelta
from json import JSONDecodeError
from jinja2 import Environment, FileSystemLoader
from piccolo.apps.user.tables import BaseUser
from starlette.endpoints import HTTPEndpoint, Request
from starlette.exceptions import HTTPException
from starlette.responses import (
HTMLResponse,
JSONResponse,
PlainTextResponse,
RedirectResponse,
)
from starlette.status import HTTP_303_SEE_OTHER
from piccolo_api.session_auth.tables import SessionsBase
from piccolo_api.shared.auth.hooks import LoginHooks
from piccolo_api.shared.auth.styles import Styles
if t.TYPE_CHECKING: # pragma: no cover
from jinja2 import Template
from starlette.responses import Response
from piccolo_api.shared.auth.captcha import Captcha
TEMPLATE_DIR = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "templates"
)
LOGIN_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "session_login.html")
LOGOUT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "session_logout.html")
class SessionLogoutEndpoint(HTTPEndpoint, metaclass=ABCMeta):
@abstractproperty
def _session_table(self) -> t.Type[SessionsBase]:
raise NotImplementedError
@abstractproperty
def _cookie_name(self) -> str:
raise NotImplementedError
@abstractproperty
def _redirect_to(self) -> t.Optional[str]:
raise NotImplementedError
@abstractproperty
def _logout_template(self) -> Template:
raise NotImplementedError
@abstractproperty
def _styles(self) -> t.Optional[Styles]:
raise NotImplementedError
def _render_template(
self, request: Request, template_context: t.Dict[str, t.Any] = {}
) -> HTMLResponse:
# If CSRF middleware is present, we have to include a form field with
# the CSRF token. It only works if CSRFMiddleware has
# allow_form_param=True, otherwise it only looks for the token in the
# header.
csrftoken = request.scope.get("csrftoken")
csrf_cookie_name = request.scope.get("csrf_cookie_name")
return HTMLResponse(
self._logout_template.render(
csrftoken=csrftoken,
csrf_cookie_name=csrf_cookie_name,
request=request,
styles=self._styles,
**template_context,
)
)
async def get(self, request: Request) -> HTMLResponse:
return self._render_template(request)
async def post(self, request: Request) -> Response:
cookie = request.cookies.get(self._cookie_name, None)
if not cookie:
raise HTTPException(
status_code=401, detail="The session cookie wasn't found."
)
await self._session_table.remove_session(token=cookie)
if self._redirect_to is not None:
response: Response = RedirectResponse(
url=self._redirect_to, status_code=HTTP_303_SEE_OTHER
)
else:
response = PlainTextResponse("Successfully logged out")
response.set_cookie(self._cookie_name, "", max_age=0)
return response
class SessionLoginEndpoint(HTTPEndpoint, metaclass=ABCMeta):
@abstractproperty
def _auth_table(self) -> t.Type[BaseUser]:
raise NotImplementedError
@abstractproperty
def _session_table(self) -> t.Type[SessionsBase]:
raise NotImplementedError
@abstractproperty
def _session_expiry(self) -> timedelta:
raise NotImplementedError
@abstractproperty
def _max_session_expiry(self) -> timedelta:
raise NotImplementedError
@abstractproperty
def _cookie_name(self) -> str:
raise NotImplementedError
@abstractproperty
def _redirect_to(self) -> t.Optional[str]:
"""
Where to redirect to after login is successful.
"""
raise NotImplementedError
@abstractproperty
def _production(self) -> bool:
"""
If True, apply more stringent security.
"""
raise NotImplementedError
@abstractproperty
def _login_template(self) -> Template:
raise NotImplementedError
@abstractproperty
def _hooks(self) -> t.Optional[LoginHooks]:
raise NotImplementedError
@abstractproperty
def _captcha(self) -> t.Optional[Captcha]:
raise NotImplementedError
@abstractproperty
def _styles(self) -> t.Optional[Styles]:
raise NotImplementedError
def _render_template(
self,
request: Request,
template_context: t.Dict[str, t.Any] = {},
status_code=200,
) -> HTMLResponse:
# If CSRF middleware is present, we have to include a form field with
# the CSRF token. It only works if CSRFMiddleware has
# allow_form_param=True, otherwise it only looks for the token in the
# header.
csrftoken = request.scope.get("csrftoken")
csrf_cookie_name = request.scope.get("csrf_cookie_name")
return HTMLResponse(
self._login_template.render(
csrftoken=csrftoken,
csrf_cookie_name=csrf_cookie_name,
request=request,
captcha=self._captcha,
styles=self._styles,
**template_context,
),
status_code=status_code,
)
def _get_error_response(
self, request, error: str, response_format: t.Literal["html", "plain"]
) -> Response:
if response_format == "html":
return self._render_template(
request, template_context={"error": error}, status_code=401
)
else:
return PlainTextResponse(
status_code=401, content=f"Login failed: {error}"
)
async def get(self, request: Request) -> HTMLResponse:
return self._render_template(request)
async def post(self, request: Request) -> Response:
# Some middleware (for example CSRF) has already awaited the request
# body, and adds it to the request.
body: t.Any = request.scope.get("form")
if not body:
try:
body = await request.json()
except JSONDecodeError:
body = await request.form()
username = body.get("username", None)
password = body.get("password", None)
return_html = body.get("format") == "html"
if (not username) or (not password):
error_message = "Missing username or password"
if return_html:
return self._render_template(
request,
template_context={"error": error_message},
)
else:
raise HTTPException(status_code=422, detail=error_message)
# Run pre_login hooks
if self._hooks and self._hooks.pre_login:
hooks_response = await self._hooks.run_pre_login(username=username)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)
# Check CAPTCHA
if self._captcha:
token = body.get(self._captcha.token_field, None)
validate_response = await self._captcha.validate(token=token)
if isinstance(validate_response, str):
if return_html:
return self._render_template(
request,
template_context={"error": validate_response},
)
else:
raise HTTPException(
status_code=401, detail=validate_response
)
# Attempt login
user_id = await self._auth_table.login(
username=username, password=password
)
if user_id:
# Run login_success hooks
if self._hooks and self._hooks.login_success:
hooks_response = await self._hooks.run_login_success(
username=username, user_id=user_id
)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)
else:
# Run login_failure hooks
if self._hooks and self._hooks.login_failure:
hooks_response = await self._hooks.run_login_failure(
username=username
)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)
if return_html:
return self._render_template(
request,
template_context={
"error": "The username or password is incorrect."
},
)
else:
raise HTTPException(status_code=401, detail="Login failed")
now = datetime.now()
expiry_date = now + self._session_expiry
max_expiry_date = now + self._max_session_expiry
session: SessionsBase = await self._session_table.create_session(
user_id=user_id,
expiry_date=expiry_date,
max_expiry_date=max_expiry_date,
)
if self._redirect_to is not None:
response: Response = RedirectResponse(
url=self._redirect_to, status_code=HTTP_303_SEE_OTHER
)
else:
response = JSONResponse(
content={"message": "logged in"}, status_code=200
)
if not self._production:
message = (
"If running sessions in production, make sure 'production' "
"is set to True, and serve under HTTPS."
)
warnings.warn(message)
cookie_value = t.cast(str, session.token)
response.set_cookie(
key=self._cookie_name,
value=cookie_value,
httponly=True,
secure=self._production,
max_age=int(self._max_session_expiry.total_seconds()),
samesite="lax",
)
return response
[docs]
def session_login(
auth_table: t.Type[BaseUser] = BaseUser,
session_table: t.Type[SessionsBase] = SessionsBase,
session_expiry: timedelta = timedelta(hours=1),
max_session_expiry: timedelta = timedelta(days=7),
redirect_to: t.Optional[str] = "/",
production: bool = False,
cookie_name: str = "id",
template_path: t.Optional[str] = None,
hooks: t.Optional[LoginHooks] = None,
captcha: t.Optional[Captcha] = None,
styles: t.Optional[Styles] = None,
) -> t.Type[SessionLoginEndpoint]:
"""
An endpoint for creating a user session.
:param auth_table:
Which table to authenticate the username and password with. It
defaults to :class:`BaseUser <piccolo.apps.user.tables.BaseUser>`.
:param session_table:
Which table to store the session in. If defaults to
:class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`.
:param session_expiry:
How long the session will last.
:param max_session_expiry:
If the session is refreshed (see the ``increase_expiry`` parameter for
:class:`SessionsAuthBackend <piccolo_api.session_auth.middleware.SessionsAuthBackend>`),
it can only be refreshed up to a certain limit, after which the session
is void.
:param redirect_to:
Where to redirect to after successful login.
:param production:
Adds additional security measures. Use this in production, when serving
your app over HTTPS.
:param cookie_name:
The name of the cookie used to store the session token. Only override
this if the name of the cookie clashes with other cookies.
:param template_path:
If you want to override the default login HTML template, you can do
so by specifying the absolute path to a custom template. For example
``'/some_directory/login.html'``. Refer to the default template at
``piccolo_api/templates/session_login.html`` as a basis for your
custom template.
:param hooks:
Allows you to run custom logic at various points in the login process.
See :class:`LoginHooks <piccolo_api.shared.auth.hooks.LoginHooks>`.
:param captcha:
Integrate a CAPTCHA service, to provide protection against bots.
See :class:`Captcha <piccolo_api.shared.auth.captcha.Captcha>`.
:param styles:
Modify the appearance of the HTML template using CSS.
""" # noqa: E501
template_path = (
LOGIN_TEMPLATE_PATH if template_path is None else template_path
)
directory, filename = os.path.split(template_path)
environment = Environment(
loader=FileSystemLoader(directory), autoescape=True
)
login_template = environment.get_template(filename)
class _SessionLoginEndpoint(SessionLoginEndpoint):
_auth_table = auth_table
_session_table = session_table
_session_expiry = session_expiry
_max_session_expiry = max_session_expiry
_redirect_to = redirect_to
_production = production
_cookie_name = cookie_name
_login_template = login_template
_hooks = hooks
_captcha = captcha
_styles = styles or Styles()
return _SessionLoginEndpoint
[docs]
def session_logout(
session_table: t.Type[SessionsBase] = SessionsBase,
cookie_name: str = "id",
redirect_to: t.Optional[str] = None,
template_path: t.Optional[str] = None,
styles: t.Optional[Styles] = None,
) -> t.Type[SessionLogoutEndpoint]:
"""
An endpoint for clearing a user session.
:param session_table:
Which table to store the session in. It defaults to
:class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`.
:param cookie_name:
The name of the cookie used to store the session token. Only override
this if the name of the cookie clashes with other cookies.
:param redirect_to:
Where to redirect to after logging out.
:param template_path:
If you want to override the default logout HTML template, you can do
so by specifying the absolute path to a custom template. For example
``'/some_directory/logout.html'``. Refer to the default template at
``piccolo_api/templates/logout.html`` as a basis for your
custom template.
:param styles:
Modify the appearance of the HTML template using CSS.
""" # noqa: E501
template_path = (
LOGOUT_TEMPLATE_PATH if template_path is None else template_path
)
directory, filename = os.path.split(template_path)
environment = Environment(
loader=FileSystemLoader(directory), autoescape=True
)
logout_template = environment.get_template(filename)
class _SessionLogoutEndpoint(SessionLogoutEndpoint):
_session_table = session_table
_cookie_name = cookie_name
_redirect_to = redirect_to
_logout_template = logout_template
_styles = styles or Styles()
return _SessionLogoutEndpoint