# Source code for piccolo_api.token_auth.middleware

from __future__ import annotations

import hmac
import typing as t
from abc import ABCMeta, abstractmethod

from piccolo.apps.user.tables import BaseUser as BaseUserTable
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    AuthenticationError,
    BaseUser,
    SimpleUser,
)
from starlette.requests import HTTPConnection

from piccolo_api.shared.auth import User
from piccolo_api.shared.auth.excluded_paths import check_excluded_paths
from piccolo_api.token_auth.tables import TokenAuth


class TokenAuthProvider(metaclass=ABCMeta):
    """
    Abstract base for token providers.

    Subclass it and implement :meth:`get_user` to plug your own token
    verification logic into ``TokenAuthBackend``.
    """

    @abstractmethod
    async def get_user(self, token: str) -> BaseUser:
        """
        Map ``token`` to the user that owns it.

        The providers in this module raise ``AuthenticationError`` when the
        token isn't valid, rather than returning a sentinel value.
        """
class SecretTokenAuthProvider(TokenAuthProvider):
    """
    Checks that the token belongs to a predefined list of tokens.

    This is useful for very simple authentication use cases - such as internal
    microservices, where the client is trusted.
    """

    def __init__(self, tokens: t.Sequence[str]):
        """
        :param tokens:
            The accepted tokens. A single string is treated as one token,
            rather than as a sequence of characters.
        """
        if isinstance(tokens, str):
            # A bare ``str`` is itself a ``Sequence[str]``, and ``in`` on a
            # string does substring matching - e.g. ``"" in "secret"`` is
            # ``True``. Wrap it so it's only ever compared as a whole token.
            tokens = (tokens,)
        self.tokens = tokens

    async def get_user(self, token: str) -> SimpleUser:
        """
        Return a ``SimpleUser`` if ``token`` matches one of the configured
        tokens.

        Comparisons use :func:`hmac.compare_digest` to reduce the risk of
        timing attacks, and every configured token is checked (no early exit
        on a match).

        :raises AuthenticationError:
            If the token isn't recognised.
        """
        # ``compare_digest`` only accepts ASCII ``str``, so compare bytes
        # instead - header values may contain non-ASCII characters.
        supplied = token.encode("utf-8")
        matched = False
        for candidate in self.tokens:
            # ``|=`` rather than ``any(...)`` so the loop never short-circuits.
            matched |= hmac.compare_digest(supplied, candidate.encode("utf-8"))

        if matched:
            return SimpleUser(username="secret_token_user")
        raise AuthenticationError("Token not recognised")
class PiccoloTokenAuthProvider(TokenAuthProvider):
    """
    Resolves tokens which are stored in a Piccolo database table.
    """

    def __init__(
        self,
        auth_table: t.Type[BaseUserTable] = BaseUserTable,
        token_table: t.Type[TokenAuth] = TokenAuth,
    ):
        """
        :param auth_table:
            The user table to load the token's owner from.
        :param token_table:
            The table used to map a token to a user id.
        """
        self.auth_table = auth_table
        self.token_table = token_table

    async def get_user(self, token: str) -> User:
        """
        Look up the user id for ``token`` via ``token_table.get_user_id``,
        then load that row from ``auth_table`` by primary key.

        :raises AuthenticationError:
            If no user id is found for the token, or no user row exists for
            that id.
        """
        owner_id = await self.token_table.get_user_id(token)
        # NOTE(review): a falsy id (``None``, but also ``0``) is treated as
        # "no such token" - confirm ``get_user_id``'s return contract.
        if not owner_id:
            raise AuthenticationError()

        users = self.auth_table
        lookup = (
            users.objects()
            .where(users._meta.primary_key == owner_id)
            .first()
        )
        row = await lookup.run()
        if row is None:
            raise AuthenticationError()
        return User(user=row)
# Shared module-level provider, used as the default ``token_auth_provider``
# for ``TokenAuthBackend``. It uses ``PiccoloTokenAuthProvider``'s default
# tables (``BaseUserTable`` and ``TokenAuth``).
DEFAULT_PROVIDER = PiccoloTokenAuthProvider()
class TokenAuthBackend(AuthenticationBackend):
    """
    Starlette authentication backend which authenticates requests using a
    bearer token from the ``Authorization`` header.
    """

    def __init__(
        self,
        token_auth_provider: TokenAuthProvider = DEFAULT_PROVIDER,
        excluded_paths: t.Optional[t.Sequence[str]] = None,
    ):
        """
        :param token_auth_provider:
            Used to verify that a token is correct.
        :param excluded_paths:
            These paths don't require a token - useful if you want to exclude
            a few URLs, such as docs.
        """
        super().__init__()
        self.token_auth_provider = token_auth_provider
        self.excluded_paths = excluded_paths or []

    def extract_token(self, header: str) -> str:
        """
        Extract the token from an ``Authorization: Bearer <token>`` header
        value.

        The ``Bearer`` scheme is matched case-insensitively, as HTTP
        authentication schemes are case-insensitive (RFC 7235), and whitespace
        around the token is ignored.

        :raises AuthenticationError:
            If the header doesn't use the ``Bearer`` scheme, or the token is
            empty.
        """
        # Split on the first space only: "<scheme> <credentials>". The old
        # ``split("Bearer ")`` approach accepted "Bearer " anywhere in the
        # header and let an empty token through.
        scheme, _, credentials = header.strip().partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            raise AuthenticationError("The header is in the wrong format.")
        return token

    @check_excluded_paths
    async def authenticate(
        self, conn: HTTPConnection
    ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]:
        """
        Authenticate ``conn`` using its ``Authorization`` bearer token.

        Excluded paths are handled by the ``check_excluded_paths`` decorator.
        Presumably it returns ``None`` for them without requiring a token,
        which would explain the ``Optional`` return type - confirm in
        ``piccolo_api.shared.auth.excluded_paths``.

        :returns:
            ``AuthCredentials`` with the ``authenticated`` scope, plus the
            user returned by ``token_auth_provider``.
        :raises AuthenticationError:
            If the header is missing or malformed. ``token_auth_provider``
            may also raise it for an unrecognised token.
        """
        auth_header = conn.headers.get("Authorization", None)
        if not auth_header:
            raise AuthenticationError("The Authorization header is missing.")

        token = self.extract_token(auth_header)
        user = await self.token_auth_provider.get_user(token=token)
        return (AuthCredentials(scopes=["authenticated"]), user)