from __future__ import annotations
import os
import re
from abc import ABCMeta, abstractmethod
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Optional, Union
from jinja2 import Environment, FileSystemLoader
from piccolo.apps.user.tables import BaseUser
from starlette.datastructures import URL
from starlette.endpoints import HTTPEndpoint, Request
from starlette.exceptions import HTTPException
from starlette.responses import (
HTMLResponse,
PlainTextResponse,
RedirectResponse,
)
from starlette.status import HTTP_303_SEE_OTHER
from piccolo_api.shared.auth.styles import Styles
if TYPE_CHECKING: # pragma: no cover
from jinja2 import Template
from starlette.responses import Response
from piccolo_api.shared.auth.captcha import Captcha
# Location of the default registration template: the ``templates`` directory
# which sits one level above the directory containing this module.
SIGNUP_TEMPLATE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "templates",
    "register.html",
)

# Deliberately simple shape check for email addresses (``local@domain.tld``).
# It is applied with ``fullmatch`` by the registration endpoint.
EMAIL_REGEX = re.compile(
    r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)"
)
class RegisterEndpoint(HTTPEndpoint, metaclass=ABCMeta):
    """
    Renders a registration form on GET, and creates a new user from the
    submitted data on POST.

    The endpoint is configured through the abstract properties below, which
    a concrete subclass must provide.
    """

    @property
    @abstractmethod
    def _auth_table(self) -> type[BaseUser]:
        """
        The table which is checked for existing users, and in which the new
        user is created.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _redirect_to(self) -> Union[str, URL]:
        """
        Where to redirect to after registration is successful.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _register_template(self) -> Template:
        """
        The template rendered by :meth:`render_template`.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _user_defaults(self) -> Optional[dict[str, Any]]:
        """
        Extra keyword arguments passed to ``create_user`` for each new user.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _captcha(self) -> Optional[Captcha]:
        """
        Optional CAPTCHA. When set, its token is validated on every POST.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _styles(self) -> Styles:
        """
        Styles object made available to the template.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def _read_only(self) -> bool:
        """
        If ``True``, POST requests are rejected with a 405 response.
        """
        raise NotImplementedError

    def render_template(
        self,
        request: Request,
        template_context: Optional[dict[str, Any]] = None,
    ) -> HTMLResponse:
        """
        Render the registration template.

        :param request:
            The current request. It is passed to the template, and its scope
            is read for the CSRF values.
        :param template_context:
            Extra variables passed to the template (for example ``error``).
            Defaults to no extra variables.
        :returns:
            The rendered template as an ``HTMLResponse``.

        """
        # ``None`` is used as the default rather than ``{}``, so a single
        # dictionary isn't shared between calls.
        context = {} if template_context is None else template_context

        # If CSRF middleware is present, we have to include a form field with
        # the CSRF token. It only works if CSRFMiddleware has
        # allow_form_param=True, otherwise it only looks for the token in the
        # header.
        csrftoken = request.scope.get("csrftoken")
        csrf_cookie_name = request.scope.get("csrf_cookie_name")

        return HTMLResponse(
            self._register_template.render(
                csrftoken=csrftoken,
                csrf_cookie_name=csrf_cookie_name,
                request=request,
                captcha=self._captcha,
                styles=self._styles,
                **context,
            )
        )

    def _validation_error(
        self, request: Request, body: Any, message: str
    ) -> HTMLResponse:
        """
        Report a validation failure in the format the client asked for.

        :param request:
            The current request.
        :param body:
            The parsed request body. If its ``format`` value is ``'html'``
            the form is re-rendered with the error message.
        :param message:
            The error message shown to the user.
        :returns:
            The re-rendered form, when ``format`` is ``'html'``.
        :raises HTTPException:
            With status code 422, for any other ``format``.

        """
        if body.get("format") == "html":
            return self.render_template(
                request, template_context={"error": message}
            )
        raise HTTPException(status_code=422, detail=message)

    async def get(self, request: Request) -> HTMLResponse:
        """
        Return the empty registration form.
        """
        return self.render_template(request)

    async def post(self, request: Request) -> Response:
        """
        Validate the submitted data, and create the user.

        The body can be JSON or form data. The fields read from it are
        ``username``, ``email``, ``password``, ``confirm_password``, the
        CAPTCHA token (if a CAPTCHA is configured), and ``format``.

        :returns:
            A 303 redirect to ``_redirect_to`` on success, the re-rendered
            form if validation fails and ``format`` is ``'html'``, or a 405
            plain text response when running in read only mode.
        :raises HTTPException:
            With status code 422 if validation fails and ``format`` isn't
            ``'html'``.

        """
        if self._read_only:
            return PlainTextResponse(
                content="Running in read only mode.", status_code=405
            )

        missing_fields_message = "Form is invalid. Missing one or more fields."

        # Some middleware (for example CSRF) has already awaited the request
        # body, and adds it to the request.
        body: Any = request.scope.get("form")

        if not body:
            try:
                body = await request.json()
            except JSONDecodeError:
                body = await request.form()
            else:
                # Valid JSON isn't necessarily an object (e.g. ``[1, 2]``).
                # Anything else has no fields to read, so reject it here
                # instead of failing later on ``body.get``.
                if not isinstance(body, dict):
                    raise HTTPException(
                        status_code=422, detail=missing_fields_message
                    )

        username = body.get("username", None)
        email = body.get("email", None)
        password = body.get("password", None)
        confirm_password = body.get("confirm_password", None)

        if self._captcha:
            token = body.get(self._captcha.token_field, None)
            response = await self._captcha.validate(token=token)
            # A string returned by the validator is treated as an error
            # message, which is always shown using the HTML template.
            if isinstance(response, str):
                return self.render_template(
                    request,
                    template_context={"error": response},
                )

        # Each field must be a non-empty string. Non-string values (which a
        # JSON body can contain) are treated the same as missing ones,
        # otherwise the regex and length checks below would raise a
        # ``TypeError``.
        required_fields = (username, email, password, confirm_password)
        if not all(
            isinstance(value, str) and value for value in required_fields
        ):
            return self._validation_error(
                request, body, missing_fields_message
            )

        if not EMAIL_REGEX.fullmatch(email):
            return self._validation_error(
                request, body, "Invalid email address."
            )

        if len(password) < 6:
            return self._validation_error(
                request,
                body,
                "Password must be at least 6 characters long.",
            )

        if confirm_password != password:
            return self._validation_error(
                request, body, "Passwords do not match."
            )

        # A clash on either column is a conflict (as the error message
        # states), so the conditions are combined with OR. Passing them as
        # separate arguments to ``where`` would combine them with AND, and
        # only detect users matching both the email and the username.
        existing_users = await self._auth_table.count().where(
            (self._auth_table.email == email)
            | (self._auth_table.username == username)
        )
        if existing_users:
            return self._validation_error(
                request,
                body,
                "User with email or username already exists.",
            )

        extra_params = self._user_defaults or {}

        await self._auth_table.create_user(
            username=username, password=password, email=email, **extra_params
        )

        return RedirectResponse(
            url=self._redirect_to, status_code=HTTP_303_SEE_OTHER
        )
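# NOTE(review): a stray "[docs]" marker stood on this line. It is not code
# (apparently residue from a rendered documentation page) - confirm and remove.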
def register(
    auth_table: type[BaseUser] = BaseUser,
    redirect_to: Union[str, URL] = "/login/",
    template_path: Optional[str] = None,
    user_defaults: Optional[dict[str, Any]] = None,
    captcha: Optional[Captcha] = None,
    styles: Optional[Styles] = None,
    read_only: bool = False,
) -> type[RegisterEndpoint]:
    """
    An endpoint for registering a user.

    :param auth_table:
        Which ``Table`` to create the user in. It defaults to
        :class:`BaseUser <piccolo.apps.user.tables.BaseUser>`.
    :param redirect_to:
        Where to redirect to after successful registration.
    :param template_path:
        If you want to override the default register HTML template, you can
        do so by specifying the absolute path to a custom template. For
        example ``'/some_directory/register.html'``. Refer to the default
        template at ``piccolo_api/templates/register.html`` as a basis for
        your custom template.
    :param user_defaults:
        These values are assigned to the new user. An example use case is
        setting ``active = True`` on each new user, so they can immediately
        login (not recommended for production, as it's better to verify their
        email address first, but OK for a prototype app)::

            register(user_defaults={'active': True})

    :param captcha:
        Integrate a CAPTCHA service, to provide protection against bots.
        See :class:`Captcha <piccolo_api.shared.auth.captcha.Captcha>`.
    :param styles:
        Modify the appearance of the HTML template using CSS.
    :param read_only:
        If ``True``, the endpoint only responds to GET requests. It's not
        commonly needed, except when running demos.
    :returns:
        A :class:`RegisterEndpoint` subclass configured with the given
        values.

    """
    # Fall back to the bundled template when no custom one was given.
    if template_path is None:
        template_path = SIGNUP_TEMPLATE_PATH

    # The loader is rooted at the template's directory, so the template is
    # looked up by its file name.
    template_dir, template_name = os.path.split(template_path)
    jinja_env = Environment(
        loader=FileSystemLoader(template_dir),
        autoescape=True,
    )
    loaded_template = jinja_env.get_template(template_name)

    # Plain class attributes stand in for the abstract properties declared
    # on ``RegisterEndpoint``.
    class _RegisterEndpoint(RegisterEndpoint):
        _auth_table = auth_table or BaseUser
        _redirect_to = redirect_to
        _register_template = loaded_template
        _user_defaults = user_defaults
        _captcha = captcha
        _styles = styles or Styles()
        _read_only = read_only

    return _RegisterEndpoint