# Source code for sheraf.attributes.files

import os

import sheraf
from sheraf.attributes import Attribute
from sheraf.models.base import BaseModel

from ..types import Set

# Directory under which file attribute contents are written. The paths recorded
# on the models are relative to it.
FILES_ROOT_DIR = "files/"


def set_files_root_dir(path):
    """Change the directory under which file attribute contents are stored.

    :param path: The new root directory, as a string or any path-like object
        (e.g. :class:`pathlib.Path`). It is normalized with :func:`os.fspath`
        so that string operations on ``FILES_ROOT_DIR`` (such as
        ``str.startswith``) keep working.

    Only the setting changes: files already written are not moved.
    """
    global FILES_ROOT_DIR
    FILES_ROOT_DIR = os.fspath(path)  # pragma: no cover


class FileObject:
    """A file attached to a model attribute and stored under ``FILES_ROOT_DIR``.

    The content is kept in memory once known. For an object built by
    :meth:`read`, it is loaded lazily from disk on first access of
    :attr:`stream`.
    """

    def __init__(self, stream=None, extension=""):
        """
        :param stream: The file content (written to disk in binary mode, so
            presumably ``bytes``), or ``None`` to load it lazily from disk
            once the object is associated with a model.
        :param extension: The file extension, without the leading dot.
        """
        self.model = None
        self.attribute_name = None
        self.extension = extension
        self._content = stream

    def __getitem__(self, name):
        # Together with ``keys``, this lets the object be used as a mapping
        # of its ``stream`` and ``extension``.
        return getattr(self, name)

    def __setitem__(self, name, value):
        setattr(self, name, value)

    def __eq__(self, other):
        # Comparing may trigger a lazy read of the content from disk.
        return (
            isinstance(other, FileObject)
            and self.stream == other.stream
            and self.extension == other.extension
        )

    def __repr__(self):
        return "<FileObject model='{}' attribute_name='{}'>".format(
            self.model.__class__.__name__, self.attribute_name
        )

    def exists(self):
        """Tell whether the file is present on disk.

        The object must be associated with a model, otherwise there is no
        path to check.
        """
        return os.path.exists(self.absolute_path())

    @property
    def stream(self):
        """The file content, read from disk on first access if not yet known.

        Nothing is read, and ``None`` is returned, when the object is not
        associated with a model (there is no path to read from) or when its
        path has been handed to the garbage collector.
        """
        # ``relative_path`` has no path to offer without a model, so there is
        # nothing to load in that case.
        if self._content is None and self.model:
            path = self.absolute_path()
            if path not in FilesGarbageCollector.instance():
                with open(path, "rb") as f:
                    self._content = f.read()

        return self._content

    @stream.setter
    def stream(self, stream):
        self._content = stream

    @classmethod
    def read(cls, model, attribute_name):
        """Build a file object for the path recorded on ``model``.

        :return: ``None`` if no path is recorded for ``attribute_name``,
            otherwise an associated instance whose extension is taken from
            the recorded path. Its content is not loaded yet.
        """
        persisted_file_path = model.mapping.get(attribute_name)
        if not persisted_file_path:
            return None

        extension = os.path.splitext(persisted_file_path)[1].lstrip(".")
        file_instance = cls(extension=extension)
        file_instance.associate(model, attribute_name)
        return file_instance

    def relative_path(self):
        """Path of the file relative to ``FILES_ROOT_DIR``.

        :return: ``None`` when the object is not associated with a model.
            Otherwise the path recorded on the model, if any, or else a path
            built from the model table, the attribute name, the model
            identifier and the extension.
        """
        if not self.model:
            return None

        if self.attribute_name in self.model.mapping:
            return self.model.mapping[self.attribute_name]

        file_name = f"{self.model.identifier}.{self.extension}"
        path = os.path.join(self.directory(), file_name)

        return os.path.relpath(path, FILES_ROOT_DIR)

    def absolute_path(self):
        """Path of the file, prefixed with ``FILES_ROOT_DIR``.

        The object must be associated with a model.
        """
        return os.path.join(FILES_ROOT_DIR, self.relative_path())

    def keys(self):
        # These names match the ``__init__`` parameters, so
        # ``FileObject(**other_file_object)`` copies a file.
        return ["stream", "extension"]

    def associate(self, model, attribute_name):
        """Attach this file to ``attribute_name`` of ``model``."""
        self.model = model
        self.attribute_name = attribute_name

    def directory(self):
        """Directory of this attribute's files: ``FILES_ROOT_DIR/<table>/<attribute>``."""
        return os.path.join(FILES_ROOT_DIR, self.model.table, self.attribute_name)

    def write(self):
        """Write the content to disk and record its relative path on the model.

        A previously recorded file for this attribute is handed to the garbage
        collector. The newly written path is withdrawn from it.
        """
        # Load the content *before* releasing the previous file. Once
        # ``delete`` has scheduled the old path for collection, ``stream``
        # refuses to read it. Once the target is opened with "wb", it is
        # truncated. Either way, lazily-loaded content would be lost.
        content = self.stream

        if self.attribute_name in self.model.mapping:
            self.delete()

        path = self.absolute_path()
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(path, "wb") as f:
            f.write(content)

        try:
            FilesGarbageCollector.instance().remove(self.relative_path())
        except KeyError:
            # The path was not scheduled for removal.
            pass

        self.model.mapping[self.attribute_name] = self.relative_path()

    def delete(self):
        """Detach the recorded file from the model and schedule its removal.

        The file itself stays on disk. Its path is added to the garbage
        collector, and the path recorded on the model is dropped. Nothing
        happens when the object is unassociated or no path is recorded.
        """
        if not self.model or self.attribute_name not in self.model.mapping:
            return

        FilesGarbageCollector.instance().add(self.model.mapping[self.attribute_name])
        del self.model.mapping[self.attribute_name]


class FileAttribute(Attribute):
    """Attribute whose value is a file kept on disk.

    The model only records where the file lives. Reading, writing and deleting
    the content are delegated to ``file_object_class`` instances.
    """

    def __init__(self, file_object_class=FileObject, **kwargs):
        self.FileObjectClass = file_object_class
        # ``read_memoization`` defaults to True here unless the caller set it.
        options = {"read_memoization": True, **kwargs}
        super().__init__(**options)

    def read(self, parent):
        # TODO: Implement deserialize instead of read
        attribute_name = self.key(parent)
        return self.FileObjectClass.read(parent, attribute_name)

    def write(self, parent, value):
        # TODO: Implement serialize instead of write
        if not value:
            return value

        if isinstance(value, dict):
            file_object = self.FileObjectClass(**value)
        else:
            owner = value.model
            owned_by_another = isinstance(owner, BaseModel) and not BaseModel.__eq__(
                owner, parent
            )
            # A file already attached to a different model is copied rather
            # than moved, so that model keeps its own file object.
            file_object = (
                self.FileObjectClass(stream=value.stream, extension=value.extension)
                if owned_by_another
                else value
            )

        file_object.associate(parent, self.key(parent))
        return file_object

    def save(self, parent):
        attribute_name = self.key(parent)
        current = getattr(parent, attribute_name)
        if current:
            current.write()
            return

        # No value anymore: release whatever file is still recorded.
        persisted = self.FileObjectClass.read(parent, attribute_name)
        if persisted:
            persisted.delete()

    def delete(self, parent):
        attribute_name = self.key(parent)
        current = getattr(parent, attribute_name)
        if current:
            current.delete()
            return

        # Without a current value, an empty associated object still finds and
        # releases the path recorded on the model, if any.
        placeholder = self.FileObjectClass()
        placeholder.associate(parent, attribute_name)
        placeholder.delete()
class FilesGarbageCollector:
    """Track file paths that are no longer referenced by a model.

    The paths are kept in a collection stored in the root of the current
    database connection, under ``table``. The files are only removed from disk
    when :meth:`clear` is called.
    """

    DEFAULT_TABLE = "__sheraf_collected_files_to_remove"
    _instance = None
    default_mapping = Set

    @classmethod
    def instance(cls):
        """Return the shared collector, creating it on first use."""
        # Compare with ``None`` instead of testing truthiness. ``__bool__``
        # reports whether any path is collected, so an empty collector would
        # be falsy. It would then be recreated on every call, and every call
        # would query the database.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self, table=DEFAULT_TABLE):
        self.table = table

    def _collected_paths(self):
        """Return the collection of paths, creating it in the database root if missing."""
        root = sheraf.Database.current_connection().root()
        if self.table not in root:
            root[self.table] = self.default_mapping()
        return root[self.table]

    @staticmethod
    def _relpath(path):
        """Express ``path`` relative to ``FILES_ROOT_DIR`` when it starts with it."""
        # NOTE(review): this is a plain prefix test, so an already-relative
        # path that happens to start with FILES_ROOT_DIR (e.g. a table named
        # like the root directory) is rewritten too — confirm this is safe.
        if path.startswith(FILES_ROOT_DIR):
            return os.path.relpath(path, FILES_ROOT_DIR)
        return path

    def add(self, path):
        self._collected_paths().add(self._relpath(path))

    def remove(self, path):
        # Raises KeyError if ``path`` is not collected, as callers expect.
        self._collected_paths().remove(self._relpath(path))

    def __contains__(self, path):
        return self._relpath(path) in self._collected_paths()

    def __bool__(self):
        return bool(self._collected_paths())

    def clear(self):
        """Remove every collected file from disk, then forget all the paths.

        Removal is best effort: a file that cannot be removed (already gone,
        permission denied, ...) is skipped silently.
        """
        collected = self._collected_paths()
        for path in collected:
            try:
                os.remove(os.path.join(FILES_ROOT_DIR, path))
            except OSError:
                pass
        collected.clear()