76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
# Public API of this module: only the declarative base model is exported by a star import.
__all__ = ['BaseModel']
|
|
|
|
from typing import Any, Iterable, Optional, Union
|
|
|
|
from sqlalchemy import Column, Integer, inspect
|
|
from sqlalchemy.orm import InstanceState, as_declarative
|
|
from validataclass.dataclasses import ValidataclassMixin
|
|
|
|
from tofu_api.common.database import Col, MetaData
|
|
|
|
|
|
@as_declarative(metadata=MetaData())
class BaseModel:
    """
    Declarative base class for database models.

    Every model derived from this class gets a default integer primary key column `id`, a `__repr__` that shows the
    object's identity and ORM state, and the helpers `to_dict()` and `update_from()`.
    """

    # Default primary key
    id: Col[int] = Column(Integer, nullable=False, primary_key=True)

    def __repr__(self) -> str:
        """
        Return a string representation of this object.

        Includes the `id` attribute if the object has one, otherwise the object's identity as tracked by SQLAlchemy.
        """
        return self._repr(id=self.id) if hasattr(self, 'id') else self._repr()

    def _repr(self, **fields) -> str:
        """
        Helper method for implementing __repr__.

        Builds a string of the form `<ClassName(key=value, ...)>` followed by a state suffix. The parameters are taken
        from `fields`. If no fields are given, the primary key values of the object's identity are shown instead; this
        part is empty if the object has no identity yet.

        The state suffix is one of `[transient <object id>]`, `[pending <object id>]`, `[deleted]` or `[detached]`.
        In every other case (i.e. for persistent objects) it is empty.
        """
        state: InstanceState = inspect(self)

        # Transient and pending objects have no database identity yet, so show the Python object id to tell them apart
        if state.transient:
            state_str = f' [transient {id(self)}]'
        elif state.pending:
            state_str = f' [pending {id(self)}]'
        elif state.deleted:
            state_str = ' [deleted]'
        elif state.detached:
            state_str = ' [detached]'
        else:
            state_str = ''

        if fields:
            param_str = ', '.join(f'{key}={value!r}' for key, value in fields.items())
        else:
            # The identity is a tuple of primary key values (or None if there is none yet). These values are usually
            # not strings (e.g. integer ids), so they must be converted explicitly, otherwise str.join() raises a
            # TypeError.
            param_str = ', '.join(repr(value) for value in state.identity or ())

        return f'<{type(self).__name__}({param_str}){state_str}>'

    def to_dict(
        self,
        *,
        fields: Optional[Iterable[str]] = None,
        exclude: Optional[Iterable[str]] = None,
    ) -> dict[str, Any]:
        """
        Return the object's data as a dictionary.

        By default, the dictionary will contain all table columns (with their column name as key) defined in the model.
        This can be overridden by setting the `fields` and/or `exclude` parameters, in which case only fields that are
        listed in `fields` will be included in the dictionary, except for fields listed in `exclude`.

        The keys are ordered like the columns in the table definition. Names in `fields` or `exclude` that are not table
        columns are ignored.

        NOTE: Values are read with `getattr(self, column_name)`. This assumes that every column is mapped to a model
        attribute of the same name.
        """
        # Materialize the filters once (they may be one-shot iterators) as sets for cheap membership tests
        included = set(fields) if fields is not None else None
        excluded = set(exclude) if exclude is not None else set()

        # Iterate over the table columns (instead of a set of names) so that the key order is stable and predictable
        return {
            column.name: getattr(self, column.name)
            for column in self.__table__.columns
            if (included is None or column.name in included) and column.name not in excluded
        }

    def update_from(self, data: Union[dict, ValidataclassMixin]) -> None:
        """
        Updates the object with data from either a dictionary or a validataclass object (requires the ValidataclassMixin).

        Validataclass objects are first converted with their `to_dict()` method. Each key is then set as an attribute on
        this object if the object already has an attribute of that name. Other keys are silently ignored.
        """
        if isinstance(data, ValidataclassMixin):
            data = data.to_dict()

        # TODO: Is it a good idea to just iterate over data and setattr? Or should we check __table__.columns?
        # NOTE(review): hasattr() is also true for methods (e.g. "to_dict"), so such a key would shadow the method on
        # this instance.
        for key, value in data.items():
            if hasattr(self, key):
                setattr(self, key, value)
|