# Source code for piccolo_api.fastapi.endpoints

"""
Enhancing Piccolo integration with FastAPI.
"""

from __future__ import annotations

import datetime
import typing as t
from collections import defaultdict
from decimal import Decimal
from enum import Enum
from inspect import Parameter, Signature, isclass

from fastapi import APIRouter, FastAPI, Request, status
from fastapi.params import Query
from pydantic import BaseModel as PydanticBaseModel
from pydantic.main import BaseModel

from piccolo_api.crud.endpoints import PiccoloCRUD
from piccolo_api.utils.types import get_type

# Module-level registry of dynamically generated Pydantic models, keyed by a
# per-wrapper alias (e.g. ``ANNOTATIONS[alias]["ModelIn"]``). String
# annotations such as ``"ANNOTATIONS['<alias>']['ModelIn']"`` on the
# generated endpoints are resolved through it. Looking up an unknown alias
# creates an empty inner dict on the fly.
ANNOTATIONS: t.DefaultDict[str, t.Dict[str, t.Any]] = defaultdict(dict)


class HTTPMethod(str, Enum):
    """
    The HTTP methods which ``FastAPIWrapper.modify_signature`` distinguishes
    between when building an endpoint's query parameters.

    Because ``str`` is mixed in, members compare equal to their plain string
    values, e.g. ``HTTPMethod.get == "GET"``.
    """

    # GET endpoints can additionally receive ordering / pagination params.
    get = "GET"
    # Used for the bulk delete endpoint - filter params only.
    delete = "DELETE"


class FastAPIKwargs:
    """
    Allows kwargs to be passed into ``FastAPI.add_api_route`` (or
    ``APIRouter.add_api_route``) for the routes created by
    ``FastAPIWrapper``.

    :param all_routes:
        Kwargs which are passed to every route.
    :param get:
        Kwargs for the root ``GET`` route. ``FastAPIWrapper`` also uses them
        for its other root-level ``GET`` routes (``/ids/``, ``/new/``,
        ``/count/``, ``/schema/`` and ``/references/``).
    :param delete:
        Kwargs for the root (bulk) ``DELETE`` route.
    :param post:
        Kwargs for the root ``POST`` route.
    :param put:
        Kwargs for the detail ``PUT`` route.
    :param patch:
        Kwargs for the detail ``PATCH`` route.
    :param get_single:
        Kwargs for the detail ``GET`` route.
    :param delete_single:
        Kwargs for the detail ``DELETE`` route.

    Route specific kwargs replace (rather than merge with) any ``all_routes``
    value under the same key.

    """

    def __init__(
        self,
        all_routes: t.Optional[t.Dict[str, t.Any]] = None,
        get: t.Optional[t.Dict[str, t.Any]] = None,
        delete: t.Optional[t.Dict[str, t.Any]] = None,
        post: t.Optional[t.Dict[str, t.Any]] = None,
        put: t.Optional[t.Dict[str, t.Any]] = None,
        patch: t.Optional[t.Dict[str, t.Any]] = None,
        get_single: t.Optional[t.Dict[str, t.Any]] = None,
        delete_single: t.Optional[t.Dict[str, t.Any]] = None,
    ):
        # ``None`` defaults (instead of ``{}``) give every instance its own
        # dicts. A mutable default is created once and shared by all
        # instances, so mutating one instance's kwargs would leak into every
        # other instance that relied on the default.
        self.all_routes = {} if all_routes is None else all_routes
        self.get = {} if get is None else get
        self.delete = {} if delete is None else delete
        self.post = {} if post is None else post
        self.put = {} if put is None else put
        self.patch = {} if patch is None else patch
        self.get_single = {} if get_single is None else get_single
        self.delete_single = {} if delete_single is None else delete_single

    def get_kwargs(self, endpoint_name: str) -> t.Dict[str, t.Any]:
        """
        Merges the arguments for all routes with arguments specific to the
        given route.

        :param endpoint_name:
            The attribute name of the route, e.g. ``'get'`` or
            ``'delete_single'``. An unknown name contributes no extra kwargs.
        :returns:
            A new dict - a copy of ``all_routes`` updated with the route
            specific kwargs. Neither source dict is modified.

        """
        merged = self.all_routes.copy()
        merged.update(getattr(self, endpoint_name, {}))
        return merged
# Response model for the ``/count/`` endpoint. (Documented with comments
# rather than docstrings, since Pydantic would copy a docstring into the
# generated OpenAPI schema.)
class CountModel(BaseModel):
    count: int
    page_size: int


# A single entry in the ``/references/`` response - one relationship with
# another table.
class ReferenceModel(BaseModel):
    tableName: str
    columnName: str


# Response model for the ``/references/`` endpoint.
class ReferencesModel(BaseModel):
    references: t.List[ReferenceModel]
class FastAPIWrapper:
    """
    Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
    ``PiccoloCRUD`` can be used with any ASGI framework, but this way you get
    some of the benefits of FastAPI - namely, the OpenAPI integration.

    You get more control by building your own endpoints by hand, but
    ``FastAPIWrapper`` works great for getting endpoints up and running very
    quickly, and reducing boilerplate code.

    :param root_url:
        The URL to mount the endpoint at - e.g. ``'/movies/'``.
    :param fastapi_app:
        The ``FastAPI`` instance (or ``APIRouter``) you want to attach the
        endpoints to.
    :param piccolo_crud:
        The ``PiccoloCRUD`` instance to wrap. ``FastAPIWrapper`` will obey the
        arguments passed into ``PiccoloCRUD``, for example ``read_only`` and
        ``allow_bulk_delete``.
    :param fastapi_kwargs:
        Specifies the extra kwargs to pass to FastAPI's ``add_api_route``.

    """

    def __init__(
        self,
        root_url: str,
        fastapi_app: t.Union[FastAPI, APIRouter],
        piccolo_crud: PiccoloCRUD,
        fastapi_kwargs: t.Optional[FastAPIKwargs] = None,
    ):
        fastapi_kwargs = fastapi_kwargs or FastAPIKwargs()

        self.root_url = root_url
        self.fastapi_app = fastapi_app
        self.piccolo_crud = piccolo_crud
        self.fastapi_kwargs = fastapi_kwargs

        self.ModelOut = piccolo_crud.pydantic_model_output
        self.ModelIn = piccolo_crud.pydantic_model
        self.ModelOptional = piccolo_crud.pydantic_model_optional
        self.ModelPlural = piccolo_crud.pydantic_model_plural()
        self.ModelFilters = piccolo_crud.pydantic_model_filters

        # This module uses ``from __future__ import annotations``, so endpoint
        # annotations are plain strings which FastAPI resolves using the
        # module's globals. The models above only exist on this instance, so
        # they are registered in the module-level ``ANNOTATIONS`` dict under
        # an alias unique to this wrapper (via ``id(self)``), and the
        # ``model`` annotations of the POST / PUT / PATCH endpoints below
        # point into that dict. (``ANNOTATIONS`` is only mutated, never
        # rebound, so no ``global`` statement is needed.)
        self.alias = f"{piccolo_crud.table._meta.tablename}__{id(self)}"

        ANNOTATIONS[self.alias]["ModelIn"] = self.ModelIn
        ANNOTATIONS[self.alias]["ModelOut"] = self.ModelOut
        ANNOTATIONS[self.alias]["ModelOptional"] = self.ModelOptional
        ANNOTATIONS[self.alias]["ModelPlural"] = self.ModelPlural

        #######################################################################
        # Root - GET

        # ``**kwargs`` accepts (and ignores) the query params which
        # ``modify_signature`` declares - ``PiccoloCRUD`` only needs
        # ``request``.
        async def get(request: Request, **kwargs):
            """
            Returns all rows matching the given query.
            """
            return await piccolo_crud.root(request=request)

        self.modify_signature(
            endpoint=get,
            model=self.ModelFilters,
            http_method=HTTPMethod.get,
            allow_ordering=True,
            allow_pagination=True,
        )

        fastapi_app.add_api_route(
            path=root_url,
            endpoint=get,
            methods=["GET"],
            response_model=self.ModelPlural,
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - IDs

        # ``search`` and ``limit`` aren't used here - they are declared so
        # they appear in the OpenAPI schema. Only ``request`` is forwarded.
        async def ids(
            request: Request,
            search: t.Optional[str] = None,
            limit: t.Optional[int] = None,
        ):
            """
            Returns a mapping of row IDs to a readable representation.
            """
            return await piccolo_crud.get_ids(request=request)

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/ids/"),
            endpoint=ids,
            methods=["GET"],
            response_model=t.Dict[str, str],
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - New

        async def new(request: Request):
            """
            Returns all of the default values for a new row, but doesn't save
            it.
            """
            return await piccolo_crud.get_new(request=request)

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/new/"),
            endpoint=new,
            methods=["GET"],
            response_model=t.Dict[str, str],
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - Count

        async def count(request: Request, **kwargs):
            """
            Returns the number of rows matching the given query.
            """
            return await piccolo_crud.get_count(request=request)

        self.modify_signature(
            endpoint=count, model=self.ModelFilters, http_method=HTTPMethod.get
        )

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/count/"),
            endpoint=count,
            methods=["GET"],
            response_model=CountModel,
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - Schema

        async def schema(request: Request):
            """
            Returns the JSON schema for the given table.
            """
            return await piccolo_crud.get_schema(request=request)

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/schema/"),
            endpoint=schema,
            methods=["GET"],
            response_model=t.Dict[str, t.Any],
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - References

        async def references(request: Request):
            """
            Returns a list of objects showing relationships with other tables.
            """
            return await piccolo_crud.get_references(request=request)

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/references/"),
            endpoint=references,
            methods=["GET"],
            response_model=ReferencesModel,
            **fastapi_kwargs.get_kwargs("get"),
        )

        #######################################################################
        # Root - DELETE

        if not piccolo_crud.read_only and piccolo_crud.allow_bulk_delete:

            async def delete(request: Request, **kwargs):
                """
                Deletes all rows matching the given query.
                """
                return await piccolo_crud.root(request=request)

            self.modify_signature(
                endpoint=delete,
                model=self.ModelFilters,
                http_method=HTTPMethod.delete,
            )

            fastapi_app.add_api_route(
                path=root_url,
                endpoint=delete,
                response_model=None,
                status_code=status.HTTP_204_NO_CONTENT,
                methods=["DELETE"],
                **fastapi_kwargs.get_kwargs("delete"),
            )

        #######################################################################
        # Root - POST

        if not piccolo_crud.read_only:

            async def post(request: Request, model):
                """
                Create a new row in the table.
                """
                return await piccolo_crud.root(request=request)

            post.__annotations__["model"] = (
                f"ANNOTATIONS['{self.alias}']['ModelIn']"
            )

            fastapi_app.add_api_route(
                path=root_url,
                endpoint=post,
                response_model=self.ModelOut,
                status_code=status.HTTP_201_CREATED,
                methods=["POST"],
                **fastapi_kwargs.get_kwargs("post"),
            )

        #######################################################################
        # Detail - GET

        # The ``/{row_id:str}/`` routes are deliberately registered after the
        # fixed root-level paths above (``/ids/``, ``/new/``, ``/count/``,
        # ...). Routes are matched in registration order, so registering
        # these first would let ``row_id`` capture e.g. ``ids``.
        async def get_single(row_id: str, request: Request):
            """
            Retrieve a single row from the table.
            """
            return await piccolo_crud.detail(request=request)

        fastapi_app.add_api_route(
            path=self.join_urls(root_url, "/{row_id:str}/"),
            endpoint=get_single,
            response_model=self.ModelOut,
            methods=["GET"],
            **fastapi_kwargs.get_kwargs("get_single"),
        )

        #######################################################################
        # Detail - DELETE

        if not piccolo_crud.read_only:

            async def delete_single(row_id: str, request: Request):
                """
                Delete a single row from the table.
                """
                return await piccolo_crud.detail(request=request)

            fastapi_app.add_api_route(
                path=self.join_urls(root_url, "/{row_id:str}/"),
                endpoint=delete_single,
                response_model=None,
                status_code=status.HTTP_204_NO_CONTENT,
                methods=["DELETE"],
                **fastapi_kwargs.get_kwargs("delete_single"),
            )

        #######################################################################
        # Detail - PUT

        if not piccolo_crud.read_only:

            async def put(row_id: str, request: Request, model):
                """
                Insert or update a single row.
                """
                return await piccolo_crud.detail(request=request)

            put.__annotations__["model"] = (
                f"ANNOTATIONS['{self.alias}']['ModelIn']"
            )

            fastapi_app.add_api_route(
                path=self.join_urls(root_url, "/{row_id:str}/"),
                endpoint=put,
                response_model=None,
                status_code=status.HTTP_204_NO_CONTENT,
                methods=["PUT"],
                **fastapi_kwargs.get_kwargs("put"),
            )

        #######################################################################
        # Detail - PATCH

        if not piccolo_crud.read_only:

            async def patch(row_id: str, request: Request, model):
                """
                Update a single row.
                """
                return await piccolo_crud.detail(request=request)

            # PATCH uses the optional model, so partial updates validate.
            patch.__annotations__["model"] = (
                f"ANNOTATIONS['{self.alias}']['ModelOptional']"
            )

            fastapi_app.add_api_route(
                path=self.join_urls(root_url, "/{row_id:str}/"),
                endpoint=patch,
                response_model=self.ModelOut,
                methods=["PATCH"],
                **fastapi_kwargs.get_kwargs("patch"),
            )

    @staticmethod
    def join_urls(url_1: str, url_2: str) -> str:
        """
        Combine two urls, and prevent double slashes (e.g. '/foo//bar').

        :param url_1:
            e.g. '/foo'
        :param url_2:
            e.g. '/bar'
        :returns:
            e.g. '/foo/bar'

        """
        return "/".join([url_1.rstrip("/"), url_2.lstrip("/")])

    @staticmethod
    def modify_signature(
        endpoint: t.Callable,
        model: t.Type[PydanticBaseModel],
        http_method: HTTPMethod,
        allow_pagination: bool = False,
        allow_ordering: bool = False,
    ):
        """
        Modify the endpoint's signature, so FastAPI can correctly extract the
        schema from it. GET endpoints are given more filters.

        :param endpoint:
            The endpoint function - its ``__signature__`` is replaced.
        :param model:
            A Pydantic model. Each field becomes an optional query param,
            together with a ``<field>__operator`` param, and a
            ``<field>__match`` param for string fields.
        :param http_method:
            Ordering, pagination and range header params are only added for
            ``HTTPMethod.get``.
        :param allow_pagination:
            Add ``__page_size``, ``__page`` and ``__visible_fields`` (GET
            only).
        :param allow_ordering:
            Add ``__order`` (GET only).

        """
        parameters = [
            Parameter(
                name="request",
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                annotation=Request,
            ),
        ]

        for field_name, _field in model.model_fields.items():
            annotation = _field.annotation
            assert annotation is not None
            type_ = get_type(annotation)

            parameters.append(
                Parameter(
                    name=field_name,
                    kind=Parameter.POSITIONAL_OR_KEYWORD,
                    default=Query(
                        default=None,
                        description=(f"Filter by the `{field_name}` column."),
                    ),
                    annotation=type_,
                ),
            )

            # Orderable types support comparison operators; everything else
            # only supports the null checks.
            if type_ in (
                int,
                float,
                Decimal,
                datetime.date,
                datetime.datetime,
                datetime.time,
                datetime.timedelta,
            ):
                parameters.append(
                    Parameter(
                        name=f"{field_name}__operator",
                        kind=Parameter.POSITIONAL_OR_KEYWORD,
                        default=Query(
                            default=None,
                            description=(
                                f"Which operator to use for `{field_name}`. "
                                "The options are `e` (equals - default) `lt`, "
                                "`lte`, `gt`, `gte`, `ne`, `is_null`, and "
                                "`not_null`."
                            ),
                        ),
                    )
                )
            else:
                parameters.append(
                    Parameter(
                        name=f"{field_name}__operator",
                        kind=Parameter.POSITIONAL_OR_KEYWORD,
                        default=Query(
                            default=None,
                            description=(
                                f"Which operator to use for `{field_name}`. "
                                "The options are `is_null`, and `not_null`."
                            ),
                        ),
                    )
                )

            # We have to check if it's a subclass of `str` for Varchar, which
            # uses Pydantic's `constr` (constrained string).
            if type_ is str or (isclass(type_) and issubclass(type_, str)):
                parameters.append(
                    Parameter(
                        name=f"{field_name}__match",
                        kind=Parameter.POSITIONAL_OR_KEYWORD,
                        default=Query(
                            default=None,
                            description=(
                                f"Specifies how `{field_name}` should be "
                                "matched - `contains` (default), `exact`, "
                                "`starts`, `ends`."
                            ),
                        ),
                    )
                )

        if http_method == HTTPMethod.get:
            if allow_ordering:
                parameters.extend(
                    [
                        Parameter(
                            name="__order",
                            kind=Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=str,
                            default=Query(
                                default=None,
                                description=(
                                    "Specifies which field to sort the "
                                    "results by. For example `id` to sort by "
                                    "id, and `-id` for descending."
                                ),
                            ),
                        )
                    ]
                )

            if allow_pagination:
                parameters.extend(
                    [
                        Parameter(
                            name="__page_size",
                            kind=Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=int,
                            default=Query(
                                default=None,
                                description=(
                                    "The number of results to return."
                                ),
                            ),
                        ),
                        Parameter(
                            name="__page",
                            kind=Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=int,
                            default=Query(
                                default=None,
                                description=(
                                    "Which page of results to return (default "
                                    "1)."
                                ),
                            ),
                        ),
                        Parameter(
                            name="__visible_fields",
                            kind=Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=str,
                            default=Query(
                                default=None,
                                description=(
                                    "The fields to return. It's a comma "
                                    "separated list - for example "
                                    "'name,address'. By default all fields "
                                    "are returned."
                                ),
                            ),
                        ),
                    ]
                )

            parameters.extend(
                [
                    Parameter(
                        name="__range_header",
                        kind=Parameter.POSITIONAL_OR_KEYWORD,
                        annotation=bool,
                        default=Query(
                            default=False,
                            description=(
                                "Set to 'true' to add the "
                                "Content-Range response header"
                            ),
                        ),
                    )
                ]
            )
            parameters.extend(
                [
                    Parameter(
                        name="__range_header_name",
                        kind=Parameter.POSITIONAL_OR_KEYWORD,
                        annotation=str,
                        default=Query(
                            default=None,
                            description=(
                                "Specify the object name in the Content-Range "
                                "response header (defaults to the table name)."
                            ),
                        ),
                    )
                ]
            )

        # The endpoint's own ``(request, **kwargs)`` signature is replaced by
        # this one, which is what FastAPI inspects to build the params.
        endpoint.__signature__ = Signature(  # type: ignore
            parameters=parameters
        )