# Source code for piccolo_api.rate_limiting.middleware

from __future__ import annotations

import typing as t
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from time import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.types import ASGIApp

if t.TYPE_CHECKING:  # pragma: no cover
    from starlette.middleware.base import Request, RequestResponseEndpoint


class RateLimitError(Exception):
    """
    Signals that a client has gone over its request allowance.

    It's meant to be caught inside the rate limiting machinery (and turned
    into an error response there) rather than propagating out to the rest of
    the application.
    """


class RateLimitProvider(metaclass=ABCMeta):
    """
    Interface for rate limit providers - every concrete provider must
    subclass this and implement :meth:`increment`.
    """

    @abstractmethod
    def increment(self, identifier: str):
        """
        Record one request from a client.

        :param identifier:
            A unique identifier for the client (for example, IP address).
        :raises RateLimitError:
            If too many requests are received from a client with this
            identifier.
        """
        ...
class InMemoryLimitProvider(RateLimitProvider):
    """
    A very simple rate limiting provider - works fine when running a single
    application instance.

    Requests are counted per identifier in fixed windows of ``timespan``
    seconds; all counts are discarded when a new window starts.

    Time values are given in seconds, rather than a timedelta, for improved
    performance.
    """

    def __init__(
        self,
        timespan: int,
        limit: int = 1000,
        block_duration: t.Optional[int] = None,
    ):
        """
        :param timespan:
            The time in seconds between resetting the number of requests.
            Beware setting it too high, because memory usage will increase.
        :param limit:
            The number of requests in the timespan, before getting blocked.
        :param block_duration:
            If set, the number of seconds before a client is no longer
            blocked. Otherwise, they're only removed when the app is
            restarted.
        """
        # Maps a client identifier to the number of requests they have made
        # in the current window.
        self.request_dict: t.DefaultDict[str, int] = defaultdict(int)
        self.timespan = timespan
        self.last_reset = time()
        self.limit = limit
        # Maps a client identifier to the time (from ``time()``) at which it
        # was blocked.
        self.blocked: t.Dict[str, float] = {}
        self.block_duration = block_duration

    def _handle_blocked(self):
        raise RateLimitError()

    def _block_expired(self, blocked_at: float, now: float) -> bool:
        """
        Return whether a block which started at ``blocked_at`` has lapsed by
        ``now``. Blocks never lapse when ``block_duration`` is ``None``.
        """
        duration = self.block_duration
        # ``is not None`` rather than truthiness, so ``block_duration=0`` is
        # honoured as a zero-length block instead of a permanent one.
        return duration is not None and now - blocked_at >= duration

    def _prune_expired_blocks(self, now: float):
        """
        Drop blocks which have lapsed, so identifiers which never make
        another request don't accumulate in memory forever.
        """
        if self.block_duration is None:
            # Permanent blocks are kept until ``clear_blocked`` / restart.
            return
        expired = [
            identifier
            for identifier, blocked_at in self.blocked.items()
            if self._block_expired(blocked_at, now)
        ]
        for identifier in expired:
            del self.blocked[identifier]

    def is_already_blocked(self, identifier: str) -> bool:
        """
        Check whether the identifier is already blocked from previous
        requests. Remove the identifier if the block has expired.

        :param identifier:
            An identifier for the client making the request.
        :returns:
            ``True`` if the client is currently blocked, otherwise ``False``.
        """
        blocked_at = self.blocked.get(identifier)
        if blocked_at is None:
            return False

        if not self._block_expired(blocked_at, time()):
            return True

        del self.blocked[identifier]
        # Give the client a fresh allowance. Otherwise their count from the
        # window in which they were blocked is still over the limit, and
        # the very next request would block them again - stretching the
        # block until the window resets, regardless of block_duration.
        self.request_dict.pop(identifier, None)
        return False

    def add_to_blocked(self, identifier: str):
        """
        Record the identifier as blocked, starting now.
        """
        self.blocked[identifier] = time()

    def increment(self, identifier: str):
        """
        Increment the number of requests with this identifier. If too many
        requests are received during the interval then record them as
        blocked, and reject the request.

        :param identifier:
            An identifier for the client making the request, for example the
            IP address.
        :raises RateLimitError:
            If the client is already blocked, or this request takes it over
            the limit.
        """
        if self.is_already_blocked(identifier):
            self._handle_blocked()

        # Reset the request count if needed.
        now = time()
        if now - self.last_reset > self.timespan:
            self.last_reset = now
            self.request_dict = defaultdict(int)
            self._prune_expired_blocks(now)

        self.request_dict[identifier] += 1
        if self.request_dict[identifier] > self.limit:
            self.add_to_blocked(identifier)
            self._handle_blocked()

    def clear_blocked(self):
        """
        Resets the block list.
        """
        self.blocked = {}
class RateLimitingMiddleware(BaseHTTPMiddleware):
    """
    Middleware which rejects clients that send more requests than the
    configured rate limit provider allows.
    """

    def __init__(
        self,
        app: ASGIApp,
        provider: t.Optional[RateLimitProvider] = None,
    ):
        """
        :param app:
            The ASGI app to wrap.
        :param provider:
            Provides the logic around rate limiting. If not specified, it will
            default to a :class:`InMemoryLimitProvider`.
        """
        super().__init__(app)
        self.rate_limit = (
            provider
            if provider is not None
            else InMemoryLimitProvider(limit=1000, timespan=300)
        )

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        client = request.client
        if not client:
            # Without a client address there is nothing to identify the
            # caller by, so the request can't be rate limited - reject it.
            return Response(
                content="Client host can't be found.", status_code=400
            )

        try:
            self.rate_limit.increment(client.host)
        except RateLimitError:
            return Response(content="Too many requests", status_code=429)
        else:
            # Errors raised further down the stack are deliberately not
            # caught by the handler above.
            return await call_next(request)