Source code for piccolo_api.media.local
from __future__ import annotations
import asyncio
import functools
import logging
import os
import pathlib
import shutil
import typing as t
from concurrent.futures import ThreadPoolExecutor
from piccolo.apps.user.tables import BaseUser
from piccolo.columns.column_types import Array, Text, Varchar
from piccolo.utils.sync import run_sync
from .base import ALLOWED_CHARACTERS, ALLOWED_EXTENSIONS, MediaStorage
if t.TYPE_CHECKING: # pragma: no cover
from concurrent.futures._base import Executor
logger = logging.getLogger(__name__)
[docs]
class LocalMediaStorage(MediaStorage):
def __init__(
self,
column: t.Union[Text, Varchar, Array],
media_path: str,
executor: t.Optional[Executor] = None,
allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS,
allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS,
file_permissions: t.Optional[int] = 0o600,
):
"""
Stores media files on a local path. This is good for simple
applications, where you're happy with the media files being stored
on a single server.
:param column:
The Piccolo ``Column`` which the storage is for.
:param media_path:
This is the local folder where the media files will be stored. It
should be an absolute path. For example,
``'/srv/piccolo-media/poster/'``.
:param executor:
An executor, which file save operations are run in, to avoid
blocking the event loop. If not specified, we use a sensibly
configured :class:`ThreadPoolExecutor <concurrent.futures.ThreadPoolExecutor>`.
:param allowed_extensions:
Which file extensions are allowed. If ``None``, then all extensions
are allowed (not recommended unless the users are trusted).
:param allowed_characters:
Which characters are allowed in the file name. By default, it's
very strict. If set to ``None`` then all characters are allowed.
:param file_permissions:
If set to a value other than ``None``, then all uploaded files are
given these file permissions.
""" # noqa: E501
self.media_path = media_path
self.executor = executor or ThreadPoolExecutor(max_workers=10)
self.file_permissions = file_permissions
if not os.path.exists(media_path):
os.mkdir(self.media_path)
super().__init__(
column=column,
allowed_extensions=allowed_extensions,
allowed_characters=allowed_characters,
)
async def store_file(
self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None
) -> str:
# If the file_name includes the entire path (e.g. /foo/bar.jpg) - we
# just want bar.jpg.
file_name = pathlib.Path(file_name).name
file_key = self.generate_file_key(file_name=file_name, user=user)
loop = asyncio.get_running_loop()
file_permissions = self.file_permissions
def save():
path = os.path.join(self.media_path, file_key)
if os.path.exists(path):
logger.error(
"A file name clash has occurred - the chances are very "
"low. Could be malicious, or a serious bug."
)
raise IOError("Unable to save the file")
with open(path, "wb") as new_file:
shutil.copyfileobj(file, new_file)
if file_permissions is not None:
os.chmod(path, file_permissions)
await loop.run_in_executor(self.executor, save)
return file_key
def store_file_sync(
self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None
) -> str:
"""
A sync wrapper around :meth:`store_file`.
"""
return run_sync(
self.store_file(file_name=file_name, file=file, user=user)
)
async def generate_file_url(
self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None
) -> str:
"""
This retrieves an absolute URL for the file.
"""
return "/".join((root_url.rstrip("/"), file_key))
def generate_file_url_sync(
self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None
) -> str:
"""
A sync wrapper around :meth:`generate_file_url`.
"""
return run_sync(
self.generate_file_url(
file_key=file_key, root_url=root_url, user=user
)
)
###########################################################################
async def get_file(self, file_key: str) -> t.Optional[t.IO]:
"""
Returns the file object matching the ``file_key``.
"""
loop = asyncio.get_running_loop()
func = functools.partial(self.get_file_sync, file_key=file_key)
return await loop.run_in_executor(self.executor, func)
def get_file_sync(self, file_key: str) -> t.Optional[t.IO]:
"""
A sync wrapper around :meth:`get_file`.
"""
path = os.path.join(self.media_path, file_key)
return open(path, "rb")
async def delete_file(self, file_key: str):
"""
Deletes the file object matching the ``file_key``.
"""
loop = asyncio.get_running_loop()
func = functools.partial(self.delete_file_sync, file_key=file_key)
return await loop.run_in_executor(self.executor, func)
def delete_file_sync(self, file_key: str):
"""
A sync wrapper around :meth:`delete_file`.
"""
path = os.path.join(self.media_path, file_key)
os.unlink(path)
async def bulk_delete_files(self, file_keys: t.List[str]):
media_path = self.media_path
for file_key in file_keys:
os.unlink(os.path.join(media_path, file_key))
async def get_file_keys(self) -> t.List[str]:
"""
Returns the file key for each file we have stored.
"""
file_keys = []
for _, _, filenames in os.walk(self.media_path):
file_keys.extend(filenames)
break
return file_keys
def __hash__(self):
return hash(("local", self.media_path))
def __eq__(self, value):
if not isinstance(value, LocalMediaStorage):
return False
return value.__hash__() == self.__hash__()