# Source code for piccolo_api.session_auth.tables
from __future__ import annotations
import secrets
from datetime import datetime, timedelta
from typing import Optional, cast
from piccolo.columns import Integer, Serial, Timestamp, Varchar
from piccolo.columns.defaults.timestamp import TimestampOffset
from piccolo.table import Table
from piccolo.utils.sync import run_sync
class SessionsBase(Table, tablename="sessions"):
    """
    Use this table, or inherit from it, to create a session store.
    """

    id: Serial

    #: Stores the session token.
    token = Varchar(length=100, null=False, secret=True)

    #: Stores the user ID.
    user_id = Integer(null=False)

    #: Stores the expiry date for this session.
    expiry_date = Timestamp(default=TimestampOffset(hours=1), null=False)

    #: We set a hard limit on the expiry date - it can keep on getting extended
    #: up until this value, after which it's best to invalidate it, and either
    #: require login again, or just create a new session token.
    max_expiry_date = Timestamp(default=TimestampOffset(days=7), null=False)

    @classmethod
    async def create_session(
        cls,
        user_id: int,
        expiry_date: Optional[datetime] = None,
        max_expiry_date: Optional[datetime] = None,
    ) -> SessionsBase:
        """
        Creates a session in the database.

        :param user_id:
            The ID of the user this session belongs to.
        :param expiry_date:
            If given, overrides the column default for ``expiry_date``.
        :param max_expiry_date:
            If given, overrides the column default for ``max_expiry_date``.
        :returns:
            The newly saved session row.
        """
        # Keep drawing random tokens until we find one that isn't already
        # stored, so the new session never collides with an existing one.
        while True:
            token = secrets.token_urlsafe(nbytes=32)
            if not await cls.exists().where(cls.token == token).run():
                break

        session = cls(token=token, user_id=user_id)
        if expiry_date is not None:
            session.expiry_date = expiry_date
        if max_expiry_date is not None:
            session.max_expiry_date = max_expiry_date
        await session.save().run()
        return session

    @classmethod
    def create_session_sync(
        cls,
        user_id: int,
        expiry_date: Optional[datetime] = None,
        max_expiry_date: Optional[datetime] = None,
    ) -> SessionsBase:
        """
        A sync equivalent of :meth:`create_session`.

        Accepts the same arguments as :meth:`create_session`, including
        ``max_expiry_date``.
        """
        return run_sync(
            cls.create_session(user_id, expiry_date, max_expiry_date)
        )

    @classmethod
    async def get_user_id(
        cls, token: str, increase_expiry: Optional[timedelta] = None
    ) -> Optional[int]:
        """
        Returns the ``user_id`` if the given token is valid, otherwise
        ``None``.

        A token is valid when a matching session exists and both its
        ``expiry_date`` and ``max_expiry_date`` are in the future.

        :param token:
            The session token to look up.
        :param increase_expiry:
            If set, the ``expiry_date`` will be increased by the given amount
            if it's close to expiring. If it has already expired, nothing
            happens. The ``max_expiry_date`` remains the same, so there's a
            hard limit on how long a session can be used for.
        """
        session = await cls.objects().where(cls.token == token).first().run()
        if not session:
            return None

        # NOTE(review): naive local time; presumably this matches how the
        # TimestampOffset column defaults are generated - confirm before
        # switching to timezone-aware datetimes.
        now = datetime.now()

        # Both limits must still be in the future for the session to count.
        if not (session.expiry_date > now and session.max_expiry_date > now):
            return None

        expiry_date = cast(datetime, session.expiry_date)
        # Only extend when the remaining lifetime has dropped below the
        # requested increment, to avoid a write on every request.
        if increase_expiry and (expiry_date - now < increase_expiry):
            session.expiry_date = expiry_date + increase_expiry
            await session.save().run()

        return cast(Optional[int], session.user_id)

    @classmethod
    def get_user_id_sync(
        cls, token: str, increase_expiry: Optional[timedelta] = None
    ) -> Optional[int]:
        """
        A sync wrapper around :meth:`get_user_id`.

        Accepts the same arguments as :meth:`get_user_id`, including
        ``increase_expiry``.
        """
        return run_sync(cls.get_user_id(token, increase_expiry))

    @classmethod
    async def remove_session(cls, token: str) -> None:
        """
        Deletes a matching session from the database.

        :param token:
            The session token whose row should be deleted.
        """
        await cls.delete().where(cls.token == token).run()

    @classmethod
    def remove_session_sync(cls, token: str) -> None:
        """
        A sync wrapper around :meth:`remove_session`.
        """
        return run_sync(cls.remove_session(token))