Source code for piccolo_api.session_auth.endpoints

from __future__ import annotations

import os
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from datetime import datetime, timedelta
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from jinja2 import Environment, FileSystemLoader
from piccolo.apps.user.tables import BaseUser
from starlette.endpoints import HTTPEndpoint, Request
from starlette.exceptions import HTTPException
from starlette.responses import (
    HTMLResponse,
    JSONResponse,
    PlainTextResponse,
    RedirectResponse,
)
from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED

from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.session_auth.tables import SessionsBase
from piccolo_api.shared.auth.hooks import LoginHooks
from piccolo_api.shared.auth.styles import Styles

if TYPE_CHECKING:  # pragma: no cover
    from jinja2 import Template
    from starlette.responses import Response

    from piccolo_api.shared.auth.captcha import Captcha

TEMPLATE_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "templates"
)

LOGIN_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "session_login.html")

LOGOUT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "session_logout.html")


class SessionLogoutEndpoint(HTTPEndpoint, metaclass=ABCMeta):
    @property
    @abstractmethod
    def _session_table(self) -> type[SessionsBase]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _cookie_name(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def _redirect_to(self) -> Optional[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _logout_template(self) -> Template:
        raise NotImplementedError

    @property
    @abstractmethod
    def _styles(self) -> Optional[Styles]:
        raise NotImplementedError

    def _render_template(
        self, request: Request, template_context: dict[str, Any] = {}
    ) -> HTMLResponse:
        # If CSRF middleware is present, we have to include a form field with
        # the CSRF token. It only works if CSRFMiddleware has
        # allow_form_param=True, otherwise it only looks for the token in the
        # header.
        csrftoken = request.scope.get("csrftoken")
        csrf_cookie_name = request.scope.get("csrf_cookie_name")

        return HTMLResponse(
            self._logout_template.render(
                csrftoken=csrftoken,
                csrf_cookie_name=csrf_cookie_name,
                request=request,
                styles=self._styles,
                **template_context,
            )
        )

    async def get(self, request: Request) -> HTMLResponse:
        return self._render_template(request)

    async def post(self, request: Request) -> Response:
        cookie = request.cookies.get(self._cookie_name, None)
        if not cookie:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail="The session cookie wasn't found.",
            )
        await self._session_table.remove_session(token=cookie)

        if self._redirect_to is not None:
            response: Response = RedirectResponse(
                url=self._redirect_to, status_code=HTTP_303_SEE_OTHER
            )
        else:
            response = PlainTextResponse("Successfully logged out")

        response.set_cookie(self._cookie_name, "", max_age=0)
        return response


class SessionLoginEndpoint(HTTPEndpoint, metaclass=ABCMeta):
    @property
    @abstractmethod
    def _auth_table(self) -> type[BaseUser]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _session_table(self) -> type[SessionsBase]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _session_expiry(self) -> timedelta:
        raise NotImplementedError

    @property
    @abstractmethod
    def _max_session_expiry(self) -> timedelta:
        raise NotImplementedError

    @property
    @abstractmethod
    def _cookie_name(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def _redirect_to(self) -> Optional[str]:
        """
        Where to redirect to after login is successful.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _production(self) -> bool:
        """
        If True, apply more stringent security.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _login_template(self) -> Template:
        raise NotImplementedError

    @property
    @abstractmethod
    def _hooks(self) -> Optional[LoginHooks]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _captcha(self) -> Optional[Captcha]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _styles(self) -> Optional[Styles]:
        raise NotImplementedError

    @property
    @abstractmethod
    def _mfa_providers(self) -> Optional[Sequence[MFAProvider]]:
        raise NotImplementedError

    def _render_template(
        self,
        request: Request,
        template_context: dict[str, Any] = {},
        status_code=200,
    ) -> HTMLResponse:
        # If CSRF middleware is present, we have to include a form field with
        # the CSRF token. It only works if CSRFMiddleware has
        # allow_form_param=True, otherwise it only looks for the token in the
        # header.
        csrftoken = request.scope.get("csrftoken")
        csrf_cookie_name = request.scope.get("csrf_cookie_name")

        return HTMLResponse(
            self._login_template.render(
                csrftoken=csrftoken,
                csrf_cookie_name=csrf_cookie_name,
                request=request,
                captcha=self._captcha,
                styles=self._styles,
                **template_context,
            ),
            status_code=status_code,
        )

    def _get_error_response(
        self, request, error: str, response_format: Literal["html", "plain"]
    ) -> Response:
        if response_format == "html":
            return self._render_template(
                request,
                template_context={"error": error},
                status_code=HTTP_401_UNAUTHORIZED,
            )
        else:
            return PlainTextResponse(
                status_code=HTTP_401_UNAUTHORIZED,
                content=f"Login failed: {error}",
            )

    async def get(self, request: Request) -> HTMLResponse:
        return self._render_template(request)

    async def post(self, request: Request) -> Response:
        # Some middleware (for example CSRF) has already awaited the request
        # body, and adds it to the request.
        body: Any = request.scope.get("form")

        if not body:
            try:
                body = await request.json()
            except JSONDecodeError:
                body = await request.form()

        username = body.get("username")
        password = body.get("password")
        return_html = body.get("format") == "html"

        if (not username) or (not password):
            error_message = "Missing username or password"
            if return_html:
                return self._render_template(
                    request,
                    template_context={"error": error_message},
                )
            else:
                raise HTTPException(status_code=422, detail=error_message)

        # Run pre_login hooks
        if self._hooks and self._hooks.pre_login:
            hooks_response = await self._hooks.run_pre_login(username=username)
            if isinstance(hooks_response, str):
                return self._get_error_response(
                    request=request,
                    error=hooks_response,
                    response_format="html" if return_html else "plain",
                )

        # Check CAPTCHA
        if self._captcha:
            token = body.get(self._captcha.token_field, None)
            validate_response = await self._captcha.validate(token=token)
            if isinstance(validate_response, str):
                if return_html:
                    return self._render_template(
                        request,
                        template_context={"error": validate_response},
                    )
                else:
                    raise HTTPException(
                        status_code=HTTP_401_UNAUTHORIZED,
                        detail=validate_response,
                    )

        # Attempt login
        user_id = await self._auth_table.login(
            username=username, password=password
        )

        if user_id:
            # Apply MFA
            if mfa_providers := self._mfa_providers:
                user = (
                    await self._auth_table.objects()
                    .where(self._auth_table.id == user_id)
                    .first()
                )

                assert user is not None

                if enrolled_mfa_providers := [
                    mfa_provider
                    for mfa_provider in mfa_providers
                    if await mfa_provider.is_user_enrolled(user=user)
                ]:
                    mfa_code = body.get("mfa_code")

                    if mfa_code is None:
                        has_sent_code: list[bool] = []
                        for mfa_provider in enrolled_mfa_providers:
                            # Send the code (only used with things like email
                            # and SMS MFA).
                            has_sent_code.append(
                                await mfa_provider.send_code(user=user)
                            )

                        message = "MFA code required"
                        if any(has_sent_code):
                            message += " (we sent you a code)"

                        if return_html:
                            return self._render_template(
                                request,
                                template_context={
                                    "error": message,
                                    "show_mfa_input": True,
                                    "mfa_provider_names": [
                                        mfa_provider.name
                                        for mfa_provider in enrolled_mfa_providers  # noqa: E501
                                    ],
                                },
                            )
                        else:
                            raise HTTPException(
                                status_code=HTTP_401_UNAUTHORIZED,
                                detail=message,
                            )

                    # Work out which MFA provider to use:
                    if len(enrolled_mfa_providers) == 1:
                        active_mfa_provider = enrolled_mfa_providers[0]
                    else:
                        mfa_provider_name = body.get("mfa_provider_name")

                        if mfa_provider_name is None:
                            raise HTTPException(
                                status_code=HTTP_401_UNAUTHORIZED,
                                detail="MFA provider must be specified",
                            )

                        filtered_mfa_providers = [
                            i
                            for i in enrolled_mfa_providers
                            if i.name == mfa_provider_name
                        ]

                        if len(filtered_mfa_providers) == 0:
                            raise HTTPException(
                                status_code=HTTP_401_UNAUTHORIZED,
                                detail="MFA provider not recognised.",
                            )

                        if len(filtered_mfa_providers) > 1:
                            raise HTTPException(
                                status_code=HTTP_401_UNAUTHORIZED,
                                detail=(
                                    "Multiple matching MFA providers found."
                                ),
                            )

                        active_mfa_provider = filtered_mfa_providers[0]

                    if not await active_mfa_provider.authenticate_user(
                        user=user, code=mfa_code
                    ):
                        if return_html:
                            return self._render_template(
                                request,
                                template_context={
                                    "error": "MFA failed",
                                    "show_mfa_input": True,
                                    "mfa_provider_names": {
                                        mfa_provider.name
                                        for mfa_provider in enrolled_mfa_providers  # noqa: E501
                                    },
                                },
                            )
                        else:
                            raise HTTPException(
                                status_code=HTTP_401_UNAUTHORIZED,
                                detail="MFA failed",
                            )

            # Run login_success hooks
            if self._hooks and self._hooks.login_success:
                hooks_response = await self._hooks.run_login_success(
                    username=username, user_id=user_id
                )
                if isinstance(hooks_response, str):
                    return self._get_error_response(
                        request=request,
                        error=hooks_response,
                        response_format="html" if return_html else "plain",
                    )
        else:
            # Run login_failure hooks
            if self._hooks and self._hooks.login_failure:
                hooks_response = await self._hooks.run_login_failure(
                    username=username
                )
                if isinstance(hooks_response, str):
                    return self._get_error_response(
                        request=request,
                        error=hooks_response,
                        response_format="html" if return_html else "plain",
                    )

            if return_html:
                return self._render_template(
                    request,
                    template_context={
                        "error": "The username or password is incorrect."
                    },
                )
            else:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Login failed"
                )

        now = datetime.now()
        expiry_date = now + self._session_expiry
        max_expiry_date = now + self._max_session_expiry

        session: SessionsBase = await self._session_table.create_session(
            user_id=user_id,
            expiry_date=expiry_date,
            max_expiry_date=max_expiry_date,
        )

        if self._redirect_to is not None:
            response: Response = RedirectResponse(
                url=self._redirect_to, status_code=HTTP_303_SEE_OTHER
            )
        else:
            response = JSONResponse(
                content={"message": "logged in"}, status_code=200
            )

        if not self._production:
            message = (
                "If running sessions in production, make sure 'production' "
                "is set to True, and serve under HTTPS."
            )
            warnings.warn(message)

        cookie_value = cast(str, session.token)

        response.set_cookie(
            key=self._cookie_name,
            value=cookie_value,
            httponly=True,
            secure=self._production,
            max_age=int(self._max_session_expiry.total_seconds()),
            samesite="lax",
        )
        return response


[docs] def session_login( auth_table: type[BaseUser] = BaseUser, session_table: type[SessionsBase] = SessionsBase, session_expiry: timedelta = timedelta(hours=1), max_session_expiry: timedelta = timedelta(days=7), redirect_to: Optional[str] = "/", production: bool = False, cookie_name: str = "id", template_path: Optional[str] = None, hooks: Optional[LoginHooks] = None, captcha: Optional[Captcha] = None, styles: Optional[Styles] = None, mfa_providers: Optional[Sequence[MFAProvider]] = None, ) -> type[SessionLoginEndpoint]: """ An endpoint for creating a user session. :param auth_table: Which table to authenticate the username and password with. It defaults to :class:`BaseUser <piccolo.apps.user.tables.BaseUser>`. :param session_table: Which table to store the session in. If defaults to :class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`. :param session_expiry: How long the session will last. :param max_session_expiry: If the session is refreshed (see the ``increase_expiry`` parameter for :class:`SessionsAuthBackend <piccolo_api.session_auth.middleware.SessionsAuthBackend>`), it can only be refreshed up to a certain limit, after which the session is void. :param redirect_to: Where to redirect to after successful login. :param production: Adds additional security measures. Use this in production, when serving your app over HTTPS. :param cookie_name: The name of the cookie used to store the session token. Only override this if the name of the cookie clashes with other cookies. :param template_path: If you want to override the default login HTML template, you can do so by specifying the absolute path to a custom template. For example ``'/some_directory/login.html'``. Refer to the default template at ``piccolo_api/templates/session_login.html`` as a basis for your custom template. :param hooks: Allows you to run custom logic at various points in the login process. See :class:`LoginHooks <piccolo_api.shared.auth.hooks.LoginHooks>`. :param captcha: Integrate a CAPTCHA service, to provide protection against bots. See :class:`Captcha <piccolo_api.shared.auth.captcha.Captcha>`. :param styles: Modify the appearance of the HTML template using CSS. :param mfa_providers: Add additional security to the login process using Multi-Factor Authentication. """ # noqa: E501 template_path = ( LOGIN_TEMPLATE_PATH if template_path is None else template_path ) directory, filename = os.path.split(template_path) environment = Environment( loader=FileSystemLoader(directory), autoescape=True ) login_template = environment.get_template(filename) class _SessionLoginEndpoint(SessionLoginEndpoint): _auth_table = auth_table _session_table = session_table _session_expiry = session_expiry _max_session_expiry = max_session_expiry _redirect_to = redirect_to _production = production _cookie_name = cookie_name _login_template = login_template _hooks = hooks _captcha = captcha _styles = styles or Styles() _mfa_providers = mfa_providers return _SessionLoginEndpoint
[docs] def session_logout( session_table: type[SessionsBase] = SessionsBase, cookie_name: str = "id", redirect_to: Optional[str] = None, template_path: Optional[str] = None, styles: Optional[Styles] = None, ) -> type[SessionLogoutEndpoint]: """ An endpoint for clearing a user session. :param session_table: Which table to store the session in. It defaults to :class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`. :param cookie_name: The name of the cookie used to store the session token. Only override this if the name of the cookie clashes with other cookies. :param redirect_to: Where to redirect to after logging out. :param template_path: If you want to override the default logout HTML template, you can do so by specifying the absolute path to a custom template. For example ``'/some_directory/logout.html'``. Refer to the default template at ``piccolo_api/templates/logout.html`` as a basis for your custom template. :param styles: Modify the appearance of the HTML template using CSS. """ # noqa: E501 template_path = ( LOGOUT_TEMPLATE_PATH if template_path is None else template_path ) directory, filename = os.path.split(template_path) environment = Environment( loader=FileSystemLoader(directory), autoescape=True ) logout_template = environment.get_template(filename) class _SessionLogoutEndpoint(SessionLogoutEndpoint): _session_table = session_table _cookie_name = cookie_name _redirect_to = redirect_to _logout_template = logout_template _styles = styles or Styles() return _SessionLogoutEndpoint