from __future__ import annotations
import enum
import typing as t
import jwt
from piccolo.apps.user.tables import BaseUser
from starlette.exceptions import HTTPException
from starlette.types import ASGIApp
class JWTBlacklist:
    """
    Base class for token blacklists consumed by :class:`JWTMiddleware`.

    Subclass it and override :meth:`in_blacklist` to decide which tokens
    get rejected; :class:`StaticJWTBlacklist` is a ready-made example.
    """

    async def in_blacklist(self, token: str) -> bool:
        """
        Report whether ``token`` is blacklisted.

        This default implementation blacklists nothing, so it always
        answers ``False``.
        """
        blacklisted = False
        return blacklisted
class StaticJWTBlacklist(JWTBlacklist):
    """
    A :class:`JWTBlacklist <JWTBlacklist>` backed by a fixed collection of
    tokens: any token found in ``blacklist`` is rejected.
    """

    def __init__(self, blacklist: t.List[str]):
        # Stored by reference rather than copied, so later changes made to
        # the caller's list are seen by in_blacklist.
        self.blacklist = blacklist

    async def in_blacklist(self, token: str) -> bool:
        """Return ``True`` when ``token`` appears in ``self.blacklist``."""
        is_listed = token in self.blacklist
        return is_listed
def extend_scope(scope: t.Dict, extra: t.Dict) -> t.Dict:
    """
    Return a shallow copy of ``scope`` with the entries of ``extra`` merged
    in; keys from ``extra`` take precedence.

    The incoming ``scope`` is never modified - building a new dict is safer
    than mutating the ASGI scope in place.
    """
    return {**scope, **extra}
class JWTError(str, enum.Enum):
    """
    Every error :class:`JWTMiddleware` can report.

    When the middleware rejects a request, the member's value becomes the
    ``detail`` of the 403 error. With ``allow_unauthenticated=True`` the
    request is let through instead, and the value is placed in the ASGI
    scope under ``jwt_error``.
    """

    # No bearer token could be extracted from the request headers.
    token_not_found = "Token not found"
    # The token was reported by the configured blacklist.
    token_revoked = "Token revoked"
    # Decoding failed because the token's expiry has passed.
    token_expired = "Token has expired"
    # The token decoded, but no matching user was found for it.
    user_not_found = "User not found"
    # The token failed verification when it was decoded.
    token_invalid = "Token is invalid"
class JWTMiddleware:
    """
    Protects ASGI endpoints - only allows access if a valid JWT token is
    present in the ``authorization`` HTTP header, as ``Bearer <token>``.

    On success the wrapped app is called with ``user_id`` added to the ASGI
    scope. On failure a 403 :class:`HTTPException` is raised, unless
    ``allow_unauthenticated=True``, in which case the wrapped app is still
    called, with ``user_id`` set to ``None`` and the reason (a
    :class:`JWTError` value) stored under ``jwt_error``.
    """

    def __init__(
        self,
        asgi: ASGIApp,
        secret: str,
        auth_table: t.Type[BaseUser] = BaseUser,
        blacklist: t.Optional[JWTBlacklist] = None,
        allow_unauthenticated: bool = False,
    ) -> None:
        """
        :param asgi:
            The ASGI app to protect.
        :param secret:
            The secret used to decode the JWT token (HS256).
        :param auth_table:
            The Piccolo table containing users - either
            :class:`BaseUser <piccolo.apps.user.tables.BaseUser>` or a
            subclass.
        :param blacklist:
            Any tokens in this blacklist will be rejected. Defaults to a plain
            :class:`JWTBlacklist`, which rejects nothing.
        :param allow_unauthenticated:
            By default the middleware rejects any requests with an invalid
            token by raising a 403 error. If ``True``, such requests are
            passed on to ``asgi`` with ``user_id=None`` and the reason under
            ``jwt_error`` in the scope.
        """
        self.asgi = asgi
        self.secret = secret
        self.auth_table = auth_table
        # Build the default per instance, instead of sharing one instance
        # created when the function was defined.
        if blacklist is None:
            blacklist = JWTBlacklist()
        self.blacklist = blacklist
        self.allow_unauthenticated = allow_unauthenticated

    def get_token(self, headers: dict) -> t.Optional[str]:
        """
        Try and extract the JWT token from the request headers.

        :param headers:
            Request headers keyed by raw ``bytes`` name (built from the ASGI
            scope's ``headers``).
        :returns:
            The token from an ``authorization: Bearer <token>`` header, or
            ``None`` if the header is missing, uses another scheme, or
            carries no token.
        """
        auth_header = headers.get(b"authorization", None)
        if not auth_header:
            return None
        # latin-1 maps every byte value, so odd header bytes can't crash the
        # request with UnicodeDecodeError; real JWTs are ASCII anyway.
        auth_str = auth_header.decode("latin-1")
        scheme, _, credentials = auth_str.partition(" ")
        if scheme != "Bearer":
            return None
        # Tolerate extra whitespace between the scheme and the token.
        return credentials.strip() or None

    async def get_user(
        self, token_dict: t.Dict[str, t.Any]
    ) -> t.Optional[BaseUser]:
        """
        Extract the user_id from the token, and return a matching user.

        :param token_dict:
            The decoded JWT payload.
        :returns:
            ``None`` if the payload has no truthy ``user_id``; otherwise the
            result of looking that id up by ``auth_table``'s primary key
            (expected to be ``None`` when no row matches).
        """
        user_id = token_dict.get("user_id", None)
        if not user_id:
            return None
        return await self.auth_table.objects().get(
            self.auth_table._meta.primary_key == user_id
        )

    async def _authenticate(
        self, headers: dict
    ) -> t.Tuple[t.Optional[BaseUser], t.Optional[JWTError]]:
        """
        Run each check on the request's token, stopping at the first failure.

        :returns:
            ``(user, None)`` on success, or ``(None, error)`` naming the check
            that failed.
        """
        token = self.get_token(headers)
        if not token:
            return None, JWTError.token_not_found

        if await self.blacklist.in_blacklist(token):
            return None, JWTError.token_revoked

        try:
            token_dict = jwt.decode(token, self.secret, algorithms=["HS256"])
        except jwt.exceptions.ExpiredSignatureError:
            # Must come first: it is itself a subclass of InvalidTokenError.
            return None, JWTError.token_expired
        except jwt.exceptions.InvalidTokenError:
            # Base class of PyJWT's validation errors. Catching only
            # InvalidSignatureError would let e.g. a malformed token escape
            # as DecodeError and surface as a 500 rather than a 403.
            return None, JWTError.token_invalid

        user = await self.get_user(token_dict)
        if user is None:
            return None, JWTError.user_not_found
        return user, None

    async def __call__(self, scope, receive, send):
        """
        Add the user_id to the scope if a JWT token is available, and the user
        is recognised, otherwise raise a 403 HTTP error - or, when
        ``allow_unauthenticated`` is set, call the wrapped app with
        ``user_id=None`` and the reason under ``jwt_error``.
        """
        headers = dict(scope["headers"])
        user, error = await self._authenticate(headers)

        if error is None:
            await self.asgi(
                extend_scope(scope, {"user_id": user.id}), receive, send
            )
        elif self.allow_unauthenticated:
            await self.asgi(
                extend_scope(
                    scope, {"user_id": None, "jwt_error": error.value}
                ),
                receive,
                send,
            )
        else:
            raise HTTPException(status_code=403, detail=error.value)