# Source code for piccolo_api.session_auth.tables

from __future__ import annotations

import secrets
from datetime import datetime, timedelta
from typing import Optional, cast

from piccolo.columns import Integer, Serial, Timestamp, Varchar
from piccolo.columns.defaults.timestamp import TimestampOffset
from piccolo.table import Table
from piccolo.utils.sync import run_sync


class SessionsBase(Table, tablename="sessions"):
    """
    Use this table, or inherit from it, to create a session store.
    """

    id: Serial

    #: Stores the session token.
    token = Varchar(length=100, null=False, secret=True)

    #: Stores the user ID.
    user_id = Integer(null=False)

    #: Stores the expiry date for this session.
    expiry_date = Timestamp(default=TimestampOffset(hours=1), null=False)

    #: We set a hard limit on the expiry date - it can keep on getting extended
    #: up until this value, after which it's best to invalidate it, and either
    #: require login again, or just create a new session token.
    max_expiry_date = Timestamp(default=TimestampOffset(days=7), null=False)

    @classmethod
    async def create_session(
        cls,
        user_id: int,
        expiry_date: Optional[datetime] = None,
        max_expiry_date: Optional[datetime] = None,
    ) -> SessionsBase:
        """
        Creates a session in the database.

        :param user_id:
            The ID of the user this session belongs to.
        :param expiry_date:
            If given, overrides the column default for ``expiry_date``
            (``TimestampOffset(hours=1)``).
        :param max_expiry_date:
            If given, overrides the column default for ``max_expiry_date``
            (``TimestampOffset(days=7)``).
        :returns:
            The newly saved session row.
        """
        # Generate random tokens until one is found that isn't already in the
        # table. The check and the later insert are not atomic, but with 32
        # random bytes a collision is vanishingly unlikely.
        while True:
            token = secrets.token_urlsafe(nbytes=32)
            if not await cls.exists().where(cls.token == token).run():
                break

        session = cls(token=token, user_id=user_id)
        if expiry_date is not None:
            session.expiry_date = expiry_date
        if max_expiry_date is not None:
            session.max_expiry_date = max_expiry_date
        await session.save().run()
        return session

    @classmethod
    def create_session_sync(
        cls,
        user_id: int,
        expiry_date: Optional[datetime] = None,
        max_expiry_date: Optional[datetime] = None,
    ) -> SessionsBase:
        """
        A sync equivalent of :meth:`create_session`.

        All arguments, including ``max_expiry_date``, are forwarded unchanged.
        """
        return run_sync(
            cls.create_session(
                user_id=user_id,
                expiry_date=expiry_date,
                max_expiry_date=max_expiry_date,
            )
        )

    @classmethod
    async def get_user_id(
        cls, token: str, increase_expiry: Optional[timedelta] = None
    ) -> Optional[int]:
        """
        Returns the ``user_id`` if the given token is valid, otherwise
        ``None``.

        A token is valid when a matching row exists and both its
        ``expiry_date`` and ``max_expiry_date`` are still in the future.

        :param token:
            The session token to look up.
        :param increase_expiry:
            If set, the ``expiry_date`` will be increased by the given amount
            if it's close to expiring (i.e. less than ``increase_expiry``
            remains). If it has already expired, nothing happens. The
            ``max_expiry_date`` remains the same, so there's a hard limit on
            how long a session can be used for.
        :returns:
            The session's ``user_id``, or ``None`` if the token is unknown or
            expired.
        """
        session = await cls.objects().where(cls.token == token).first().run()
        if not session:
            return None

        # NOTE(review): this compares against a naive local datetime; it
        # assumes the Timestamp columns hold naive local datetimes - confirm.
        now = datetime.now()
        expiry_date = cast(datetime, session.expiry_date)
        max_expiry_date = cast(datetime, session.max_expiry_date)

        # Reject the session if either the rolling expiry or the hard limit
        # has been reached.
        if expiry_date <= now or max_expiry_date <= now:
            return None

        # Slide the expiry forward only when it is about to run out, to avoid
        # writing to the database on every request.
        if increase_expiry and (expiry_date - now < increase_expiry):
            session.expiry_date = expiry_date + increase_expiry
            await session.save().run()

        return cast(Optional[int], session.user_id)

    @classmethod
    def get_user_id_sync(
        cls, token: str, increase_expiry: Optional[timedelta] = None
    ) -> Optional[int]:
        """
        A sync wrapper around :meth:`get_user_id`.

        ``increase_expiry`` is forwarded unchanged, so sync callers can also
        extend sessions that are close to expiring.
        """
        return run_sync(
            cls.get_user_id(token, increase_expiry=increase_expiry)
        )

    @classmethod
    async def remove_session(cls, token: str) -> None:
        """
        Deletes a matching session from the database.

        :param token:
            The session token whose row(s) should be deleted.
        """
        await cls.delete().where(cls.token == token).run()

    @classmethod
    def remove_session_sync(cls, token: str) -> None:
        """
        A sync wrapper around :meth:`remove_session`.
        """
        return run_sync(cls.remove_session(token))