Source code for piccolo_api.mfa.endpoints
import os
from abc import ABCMeta, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from jinja2 import Environment, FileSystemLoader
from piccolo.apps.user.tables import BaseUser
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED
from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.shared.auth.styles import Styles
TEMPLATE_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"templates",
)
environment = Environment(
loader=FileSystemLoader(TEMPLATE_PATH), autoescape=True
)
class MFASetupEndpoint(HTTPEndpoint, metaclass=ABCMeta):
@property
@abstractmethod
def _provider(self) -> MFAProvider:
raise NotImplementedError
@property
@abstractmethod
def _auth_table(self) -> type[BaseUser]:
raise NotImplementedError
@property
@abstractmethod
def _styles(self) -> Styles:
raise NotImplementedError
def _render_register_template(
self,
request: Request,
extra_context: Optional[dict] = None,
status_code: int = 200,
):
template = environment.get_template("mfa_setup.html")
return HTMLResponse(
status_code=status_code,
content=template.render(
styles=self._styles,
csrftoken=request.scope.get("csrftoken"),
**(extra_context or {}),
),
)
def _render_cancel_template(
self,
request: Request,
):
template = environment.get_template("mfa_cancel.html")
return HTMLResponse(
status_code=HTTP_400_BAD_REQUEST,
content=template.render(
styles=self._styles,
csrftoken=request.scope.get("csrftoken"),
),
)
async def get(self, request: Request):
piccolo_user: BaseUser = request.user.user
if await self._provider.is_user_enrolled(user=piccolo_user):
return self._render_cancel_template(request=request)
else:
return self._render_register_template(request=request)
async def post(self, request: Request):
piccolo_user: BaseUser = request.user.user
# Some middleware (for example CSRF) has already awaited the request
# body, and adds it to the request.
body: Any = request.scope.get("form")
if not body:
try:
body = await request.json()
except JSONDecodeError:
body = await request.form()
if action := body.get("action"):
if action == "register":
###############################################################
# If the user is already enrolled, don't proceed.
if await self._provider.is_user_enrolled(user=piccolo_user):
return self._render_cancel_template(request=request)
###############################################################
# Make sure the password is correct.
password = body.get("password")
if not password or not await self._auth_table.login(
username=piccolo_user.username, password=password
):
return self._render_register_template(
request=request,
status_code=HTTP_401_UNAUTHORIZED,
extra_context={"error": "Incorrect password"},
)
###############################################################
# Return the content
if body.get("format") == "json":
json_content = await self._provider.get_registration_json(
user=piccolo_user
)
return JSONResponse(content=json_content)
else:
html_content = await self._provider.get_registration_html(
user=piccolo_user
)
return HTMLResponse(content=html_content)
elif action == "revoke":
if password := body.get("password"):
if await self._auth_table.login(
username=piccolo_user.username, password=password
):
await self._provider.delete_registration(
user=piccolo_user
)
template = environment.get_template(
"mfa_disabled.html",
)
return HTMLResponse(
content=template.render(
styles=self._styles,
)
)
return HTMLResponse(content="<p>Error</p>")
[docs]
def mfa_setup(
provider: MFAProvider,
auth_table: type[BaseUser] = BaseUser,
styles: Optional[Styles] = None,
) -> type[HTTPEndpoint]:
"""
This endpoint needs to be protected ``SessionAuthMiddleware``, ensuring
that only logged in users can access it.
We also recommend protecting it with ``RateLimitingMiddleware``, because:
* Some of the forms accept a password, and we want to protect against brute
forcing.
* Generating secrets and refresh tokens is somewhat expensive, so we want
to protect against abuse.
Users can setup and manage their MFA setup using this endpoint.
"""
class _MFARegisterEndpoint(MFASetupEndpoint):
_auth_table = auth_table
_provider = provider
_styles = styles or Styles()
return _MFARegisterEndpoint