From 50aff05614495c270cfcdda17e6e84e5b5dc664d Mon Sep 17 00:00:00 2001 From: binaryDiv Date: Fri, 23 Sep 2022 20:38:34 +0200 Subject: [PATCH] Implement task API with input validation; implement error handling; refactoring --- api_tests/http-client.env.json | 5 ++ api_tests/tasks.http | 28 ++++++ tofu_api/api/rest_api.py | 4 +- tofu_api/api/tasks/__init__.py | 3 +- tofu_api/api/tasks/task_api.py | 104 ----------------------- tofu_api/api/tasks/task_blueprint.py | 30 +++++++ tofu_api/api/tasks/task_handler.py | 50 +++++++++++ tofu_api/api/tasks/task_views.py | 89 +++++++++++++++++++ tofu_api/api/tasks/validators.py | 23 +++++ tofu_api/app.py | 24 ++++++ tofu_api/common/exceptions/__init__.py | 1 + tofu_api/common/exceptions/base.py | 23 +++++ tofu_api/common/rest/__init__.py | 2 + tofu_api/common/rest/base_blueprint.py | 15 +--- tofu_api/common/rest/base_method_view.py | 32 +++++++ tofu_api/common/rest/error_handler.py | 73 ++++++++++++++++ tofu_api/common/rest/exceptions.py | 30 +++++++ tofu_api/common/string_utils.py | 28 ++++++ tofu_api/dependencies.py | 38 ++++++++- tofu_api/models/base.py | 27 ++++-- tofu_api/repositories/__init__.py | 2 + tofu_api/repositories/base_repository.py | 94 ++++++++++++++++++++ tofu_api/repositories/exceptions.py | 9 ++ tofu_api/repositories/task_repository.py | 9 ++ 24 files changed, 612 insertions(+), 131 deletions(-) create mode 100644 api_tests/http-client.env.json create mode 100644 api_tests/tasks.http delete mode 100644 tofu_api/api/tasks/task_api.py create mode 100644 tofu_api/api/tasks/task_blueprint.py create mode 100644 tofu_api/api/tasks/task_handler.py create mode 100644 tofu_api/api/tasks/task_views.py create mode 100644 tofu_api/api/tasks/validators.py create mode 100644 tofu_api/common/exceptions/__init__.py create mode 100644 tofu_api/common/exceptions/base.py create mode 100644 tofu_api/common/rest/base_method_view.py create mode 100644 tofu_api/common/rest/error_handler.py create mode 100644 tofu_api/common/rest/exceptions.py create mode 100644 tofu_api/common/string_utils.py create mode 100644 tofu_api/repositories/__init__.py create mode 100644 tofu_api/repositories/base_repository.py create mode 100644 tofu_api/repositories/exceptions.py create mode 100644 tofu_api/repositories/task_repository.py diff --git a/api_tests/http-client.env.json b/api_tests/http-client.env.json new file mode 100644 index 0000000..6e4bf5e --- /dev/null +++ b/api_tests/http-client.env.json @@ -0,0 +1,5 @@ +{ + "dev": { + "api_host": "http://localhost:5000" + } +} diff --git a/api_tests/tasks.http b/api_tests/tasks.http new file mode 100644 index 0000000..7e6699e --- /dev/null +++ b/api_tests/tasks.http @@ -0,0 +1,28 @@ +### Fetch all tasks +GET {{api_host}}/api/tasks + + +### Fetch one task +GET {{api_host}}/api/tasks/1 + + +### Create new task +POST {{api_host}}/api/tasks +Content-Type: application/json + +{ + "title": "Some test task" +} + + +### Update task +PATCH {{api_host}}/api/tasks/1 +Content-Type: application/json + +{ + "description": "Update!" +} + + +### Delete task +DELETE {{api_host}}/api/tasks/10 diff --git a/tofu_api/api/rest_api.py b/tofu_api/api/rest_api.py index fea8814..f708aae 100644 --- a/tofu_api/api/rest_api.py +++ b/tofu_api/api/rest_api.py @@ -1,5 +1,5 @@ from tofu_api.common.rest import BaseBlueprint -from .tasks import TaskApiBlueprint +from .tasks import TaskBlueprint class TofuApiBlueprint(BaseBlueprint): @@ -12,4 +12,4 @@ class TofuApiBlueprint(BaseBlueprint): url_prefix = '/api' def init_blueprint(self) -> None: - self.register_blueprint(TaskApiBlueprint(self.app)) + self.register_blueprint(TaskBlueprint(self.app)) diff --git a/tofu_api/api/tasks/__init__.py b/tofu_api/api/tasks/__init__.py index 021af83..ce3b464 100644 --- a/tofu_api/api/tasks/__init__.py +++ b/tofu_api/api/tasks/__init__.py @@ -1 +1,2 @@ -from .task_api import TaskApiBlueprint +from .task_blueprint import TaskBlueprint +from .task_handler import TaskHandler diff --git a/tofu_api/api/tasks/task_api.py b/tofu_api/api/tasks/task_api.py deleted file mode 100644 index a4fff4a..0000000 --- a/tofu_api/api/tasks/task_api.py +++ /dev/null @@ -1,104 +0,0 @@ -from flask import jsonify -from flask.views import MethodView -from sqlalchemy.orm import Session -from werkzeug.exceptions import NotFound - -from tofu_api.common.rest import BaseBlueprint -from tofu_api.models import Task - - -class TaskApiBlueprint(BaseBlueprint): - """ - Blueprint for the tasks REST API. - """ - - # Blueprint settings - name = 'rest_api_tasks' - import_name = __name__ - url_prefix = '/tasks' - - def init_blueprint(self) -> None: - """ - Register URL rules. - """ - db_session = self.app.dependencies.get_db_session() - - self.add_url_rule( - '', - view_func=self.create_view_func(TaskCollectionView, db_session=db_session), - methods=['GET', 'POST'], - ) - self.add_url_rule( - '/', - view_func=self.create_view_func(TaskItemView, db_session=db_session), - methods=['GET', 'PATCH', 'DELETE'], - ) - - -class TaskBaseView(MethodView): - """ - Base class for view classes for the `/tasks` endpoint. - """ - - # TODO: Use a handler class instead of accessing the database session directly - db_session: Session - - def __init__(self, *, db_session: Session): - self.db_session = db_session - - -class TaskCollectionView(TaskBaseView): - """ - View class for `/tasks` endpoint. - """ - - def get(self): - """ - Get list of all tasks. - """ - task_list = self.db_session.query(Task).all() - return jsonify({ - 'count': len(task_list), - 'items': [task.to_dict() for task in task_list], - }), 200 - - def post(self): - """ - Create a new task. - """ - # TODO: Parse request data and create real data - new_task = Task( - title='Do stuff' - ) - self.db_session.add(new_task) - self.db_session.commit() - return jsonify(new_task.to_dict()), 201 - - -class TaskItemView(TaskBaseView): - """ - View class for `/tasks/` endpoint. - """ - - def get(self, task_id: int): - """ - Get a single task by ID. - """ - task = self.db_session.query(Task).get(task_id) - if task is None: - raise NotFound(f'Task with ID {task_id} not found!') - return jsonify(task.to_dict()), 200 - - def patch(self, task_id: int): - """ - Update a single task by ID. - """ - # TODO: Implement - raise NotImplementedError - - def delete(self, task_id: int): - """ - Delete a single task by ID. - """ - # TODO: Implement - raise NotImplementedError diff --git a/tofu_api/api/tasks/task_blueprint.py b/tofu_api/api/tasks/task_blueprint.py new file mode 100644 index 0000000..d73ee15 --- /dev/null +++ b/tofu_api/api/tasks/task_blueprint.py @@ -0,0 +1,30 @@ +from tofu_api.common.rest import BaseBlueprint +from .task_views import TaskCollectionView, TaskItemView + + +class TaskBlueprint(BaseBlueprint): + """ + Blueprint for the tasks REST API. + """ + + # Blueprint settings + name = 'rest_api_tasks' + import_name = __name__ + url_prefix = '/tasks' + + def init_blueprint(self) -> None: + """ + Register URL rules. + """ + task_handler = self.app.dependencies.get_task_handler() + + self.add_url_rule( + '', + view_func=TaskCollectionView.as_view(task_handler=task_handler), + methods=['GET', 'POST'], + ) + self.add_url_rule( + '/', + view_func=TaskItemView.as_view(task_handler=task_handler), + methods=['GET', 'PATCH', 'DELETE'], + ) diff --git a/tofu_api/api/tasks/task_handler.py b/tofu_api/api/tasks/task_handler.py new file mode 100644 index 0000000..7cb5fed --- /dev/null +++ b/tofu_api/api/tasks/task_handler.py @@ -0,0 +1,50 @@ +from tofu_api.models import Task +from tofu_api.repositories import TaskRepository +from .validators import TaskCreateData, TaskUpdateData + + +class TaskHandler: + """ + Handles operations on tasks. + """ + + task_repository: TaskRepository + + def __init__(self, *, task_repository: TaskRepository): + self.task_repository = task_repository + + def fetch_task(self, task_id: int) -> Task: + """ + Fetches a single task by its ID from the database. + Raises an ObjectNotFoundError if the task was not found. + """ + return self.task_repository.fetch_by_id(task_id) + + def fetch_all_tasks(self) -> list[Task]: + """ + Fetches a list of all tasks. + """ + return self.task_repository.fetch_all() + + def create_task(self, create_data: TaskCreateData) -> Task: + """ + Creates a new task, saves it to the database and returns it. + """ + task = Task() + task.update_from(create_data) + self.task_repository.save_resource(task) + return task + + def update_task(self, task: Task, update_data: TaskUpdateData) -> Task: + """ + Updates a Task object with new data. + """ + task.update_from(update_data) + self.task_repository.save_resource(task) + return task + + def delete_task(self, task: Task) -> None: + """ + Deletes a task from the database. + """ + self.task_repository.delete_resource(task) diff --git a/tofu_api/api/tasks/task_views.py b/tofu_api/api/tasks/task_views.py new file mode 100644 index 0000000..3f9a2c3 --- /dev/null +++ b/tofu_api/api/tasks/task_views.py @@ -0,0 +1,89 @@ +from flask import jsonify +from validataclass.validators import DataclassValidator + +from tofu_api.common.rest import BaseMethodView +from .task_handler import TaskHandler +from .validators import TaskCreateData, TaskUpdateData + + +class TaskBaseView(BaseMethodView): + """ + Base class for view classes for the `/tasks` endpoint. + """ + + task_handler: TaskHandler + + def __init__(self, *, task_handler: TaskHandler): + self.task_handler = task_handler + + +class TaskCollectionView(TaskBaseView): + """ + View class for `/api/tasks` endpoint. + """ + + # Validators + task_create_validator = DataclassValidator(TaskCreateData) + + def get(self): + """ + Get list of all tasks. + """ + task_list = self.task_handler.fetch_all_tasks() + return jsonify({ + 'items': [task.to_dict() for task in task_list], + 'total_count': len(task_list), + }), 200 + + def post(self): + """ + Create a new task. + """ + # Parse request data + create_data: TaskCreateData = self.validate_request_data(self.task_create_validator) + + # Create new task + new_task = self.task_handler.create_task(create_data) + + # Return new task as JSON + return jsonify(new_task.to_dict()), 201 + + +class TaskItemView(TaskBaseView): + """ + View class for `/api/tasks/` endpoint. + """ + + # Validators + task_update_validator = DataclassValidator(TaskUpdateData) + + def get(self, task_id: int): + """ + Get a single task by ID. + """ + task = self.task_handler.fetch_task(task_id) + return jsonify(task.to_dict()), 200 + + def patch(self, task_id: int): + """ + Update a single task by ID. + """ + # Parse request data + update_data: TaskUpdateData = self.validate_request_data(self.task_update_validator) + + # Fetch task and update + task = self.task_handler.fetch_task(task_id) + task = self.task_handler.update_task(task, update_data) + + # Return updated task as JSON + return jsonify(task.to_dict()), 200 + + def delete(self, task_id: int): + """ + Delete a single task by ID. + """ + # Fetch task and delete + task = self.task_handler.fetch_task(task_id) + self.task_handler.delete_task(task) + + return self.empty_response() diff --git a/tofu_api/api/tasks/validators.py b/tofu_api/api/tasks/validators.py new file mode 100644 index 0000000..ca8f3c4 --- /dev/null +++ b/tofu_api/api/tasks/validators.py @@ -0,0 +1,23 @@ +from validataclass.dataclasses import Default, DefaultUnset, ValidataclassMixin, validataclass +from validataclass.helpers import OptionalUnset +from validataclass.validators import StringValidator + + +@validataclass +class TaskCreateData(ValidataclassMixin): + """ + Dataclass for "create task" request data. + """ + + title: str = StringValidator(min_length=1, max_length=200) + description: str = StringValidator(max_length=2000), Default('') + + +@validataclass +class TaskUpdateData(TaskCreateData): + """ + Dataclass for "update task" request data. + """ + + title: OptionalUnset[str] = DefaultUnset + description: OptionalUnset[str] = DefaultUnset diff --git a/tofu_api/app.py b/tofu_api/app.py index 232ee31..d8d8606 100644 --- a/tofu_api/app.py +++ b/tofu_api/app.py @@ -1,3 +1,4 @@ +import logging import os import sys import warnings @@ -7,6 +8,7 @@ from flask import Flask from tofu_api.api import TofuApiBlueprint from tofu_api.common.config import Config from tofu_api.common.json import JSONProvider +from tofu_api.common.rest import RestApiErrorHandler from tofu_api.dependencies import Dependencies # Enable deprecation warnings in dev environment @@ -42,6 +44,10 @@ class App(Flask): # Load app configuration from YAML file self.config.from_yaml(os.getenv('FLASK_CONFIG_FILE', default='config.yml')) + # Configure logging and error handling + self.configure_logging() + self.configure_error_handling() + # Initialize DI container self.dependencies = Dependencies() @@ -51,6 +57,24 @@ class App(Flask): # Register blueprints self.register_blueprint(TofuApiBlueprint(self)) + def configure_logging(self) -> None: + """ + Configures the logging system. + """ + logging.basicConfig( + level=logging.DEBUG if self.debug else logging.INFO, + format='{asctime}.{msecs:03.0f} {levelname:>8} [{name}] {message}', + datefmt='%Y-%m-%d %H:%M:%S', + style='{', + ) + + def configure_error_handling(self) -> None: + """ + Registers error handlers to the app. + """ + error_handler = RestApiErrorHandler(debug_mode=self.debug) + error_handler.register_error_handlers(self) + def init_database(self) -> None: """ Initialize database connection and models. diff --git a/tofu_api/common/exceptions/__init__.py b/tofu_api/common/exceptions/__init__.py new file mode 100644 index 0000000..96105e0 --- /dev/null +++ b/tofu_api/common/exceptions/__init__.py @@ -0,0 +1 @@ +from .base import AppException diff --git a/tofu_api/common/exceptions/base.py b/tofu_api/common/exceptions/base.py new file mode 100644 index 0000000..ee77f61 --- /dev/null +++ b/tofu_api/common/exceptions/base.py @@ -0,0 +1,23 @@ +from typing import Optional + + +class AppException(Exception): + """ + Base class for application specific exceptions that can also be used as API error responses. + """ + code: str = 'unspecified_error' + status_code: int = 400 + message: str + + def __init__(self, message: str, *, code: Optional[str] = None, status_code: Optional[int] = None): + if code is not None: + self.code = code + if status_code is not None: + self.status_code = status_code + self.message = message + + def to_dict(self) -> dict: + return { + 'code': self.code, + 'message': self.message, + } diff --git a/tofu_api/common/rest/__init__.py b/tofu_api/common/rest/__init__.py index 91d7bc8..e3cfe4c 100644 --- a/tofu_api/common/rest/__init__.py +++ b/tofu_api/common/rest/__init__.py @@ -1 +1,3 @@ from .base_blueprint import BaseBlueprint +from .base_method_view import BaseMethodView +from .error_handler import RestApiErrorHandler diff --git a/tofu_api/common/rest/base_blueprint.py b/tofu_api/common/rest/base_blueprint.py index 9d46569..21753e6 100644 --- a/tofu_api/common/rest/base_blueprint.py +++ b/tofu_api/common/rest/base_blueprint.py @@ -1,16 +1,11 @@ from abc import ABC, abstractmethod -from typing import Callable, Type, TYPE_CHECKING +from typing import TYPE_CHECKING from flask import Blueprint -from flask.views import View if TYPE_CHECKING: from tofu_api.app import App -__all__ = [ - 'BaseBlueprint', -] - class BaseBlueprint(Blueprint, ABC): """ @@ -62,11 +57,3 @@ class BaseBlueprint(Blueprint, ABC): Register child blueprints and URL rules. """ raise NotImplementedError - - @staticmethod - def create_view_func(view_cls: Type[View], *args, **kwargs) -> Callable: - """ - Helper function to create a view function from a `View` class using `view_cls.as_view()`. - All arguments are passed to the constructor of the view class. - """ - return view_cls.as_view(view_cls.__name__, *args, **kwargs) diff --git a/tofu_api/common/rest/base_method_view.py b/tofu_api/common/rest/base_method_view.py new file mode 100644 index 0000000..ba9c0e7 --- /dev/null +++ b/tofu_api/common/rest/base_method_view.py @@ -0,0 +1,32 @@ +from typing import Any + +import flask +from flask.typing import RouteCallable +from flask.views import MethodView +from validataclass.validators import Validator + + +class BaseMethodView(MethodView): + """ + Base class for REST API views. + """ + + @property + def request(self) -> flask.Request: + return flask.request + + @classmethod + def as_view(cls, *args, **kwargs) -> RouteCallable: + return super().as_view(cls.__name__, *args, **kwargs) + + @staticmethod + def empty_response(code: int = 204) -> tuple[str, int]: + return '', code + + def validate_request_data(self, validator: Validator) -> Any: + """ + Parses request data as JSON and validates it using a validataclass validator. + """ + # TODO error handling: wrong content type; empty body; invalid json; validation errors + parsed_json = self.request.json + return validator.validate(parsed_json) diff --git a/tofu_api/common/rest/error_handler.py b/tofu_api/common/rest/error_handler.py new file mode 100644 index 0000000..2b2350b --- /dev/null +++ b/tofu_api/common/rest/error_handler.py @@ -0,0 +1,73 @@ +import logging +from typing import Union + +from flask import Flask, Response, jsonify +from werkzeug.exceptions import HTTPException +from werkzeug.http import HTTP_STATUS_CODES + +from tofu_api.common import string_utils +from tofu_api.common.exceptions import AppException +from .exceptions import InternalServerError + +T_Response = Union[Response, tuple[Response, int]] + + +class RestApiErrorHandler: + """ + Error handler class for REST API errors. + """ + + # Dependencies + logger: logging.Logger + + # Options + debug_mode: bool = False + + # Lookup table for HTTP status codes to API error codes + _http_to_api_error_codes: dict[int, str] + + def __init__(self, *, debug_mode: bool = False): + self.logger = logging.getLogger(type(self).__name__) + self.debug_mode = debug_mode + + # Generate lookup table for HTTP status codes + self._http_to_api_error_codes = { + http_status: string_utils.str_to_snake_case(name) + for http_status, name in HTTP_STATUS_CODES.items() + } + + def register_error_handlers(self, app: Flask) -> None: + """ + Registers error handlers for different types of exceptions to the app. + """ + app.register_error_handler(AppException, self.handle_app_exception) + app.register_error_handler(HTTPException, self.handle_http_exception) + app.register_error_handler(Exception, self.handle_generic_exception) + + @staticmethod + def handle_app_exception(exception: AppException) -> T_Response: + """ + Handles exceptions of type `AppException` that were not handled by any more specific handler. + """ + return jsonify(exception.to_dict()), exception.status_code + + def handle_http_exception(self, exception: HTTPException) -> T_Response: + """ + Handles exceptions of type `HTTPException`, i.e. any werkzeug HTTP exceptions. + """ + if exception.code >= 500: + self.logger.exception('HTTP exception with status code %s: %s', exception.code, type(exception).__name__) + + return jsonify({ + 'code': self._http_to_api_error_codes.get(exception.code, 'unknown_http_error'), + 'message': exception.description, + }), exception.code + + def handle_generic_exception(self, exception: Exception) -> T_Response: + """ + Fallback handler for any exceptions not handled by any other handler. + """ + self.logger.exception('Uncaught exception: %s', type(exception).__name__) + + wrapped_exception = InternalServerError('There was an uncaught error on the server.', inner_exception=exception) + return jsonify(wrapped_exception.to_dict(debug=self.debug_mode)), wrapped_exception.status_code diff --git a/tofu_api/common/rest/exceptions.py b/tofu_api/common/rest/exceptions.py new file mode 100644 index 0000000..ae54727 --- /dev/null +++ b/tofu_api/common/rest/exceptions.py @@ -0,0 +1,30 @@ +__all__ = [ + 'InternalServerError' +] + +import traceback +from typing import Optional + +from tofu_api.common.exceptions import AppException + + +class InternalServerError(AppException): + """ + Wrapper exception for any uncaught exception. + """ + status_code = 500 + code = 'internal_server_error' + inner_exception: Optional[Exception] = None + + def __init__(self, message: str, *, inner_exception: Optional[Exception] = None): + super().__init__(message) + self.inner_exception = inner_exception + + def to_dict(self, *, debug: bool = False) -> dict: + data = super().to_dict() + if debug: + data['_debug'] = { + 'exception': str(self.inner_exception), + 'traceback': traceback.format_exception(self.inner_exception), + } + return data diff --git a/tofu_api/common/string_utils.py b/tofu_api/common/string_utils.py new file mode 100644 index 0000000..c15e95c --- /dev/null +++ b/tofu_api/common/string_utils.py @@ -0,0 +1,28 @@ +__all__ = [ + 'SNAKE_CASE_CHARACTERS', + 'is_snake_case', + 'str_to_snake_case', +] + +import string + +SNAKE_CASE_CHARACTERS = string.ascii_lowercase + string.digits + '_' + + +def is_snake_case(input_str: str) -> bool: + """ + Returns True if the input string only consists of snake case characters (lowercase letters, digits, underscore). + """ + return all(c in SNAKE_CASE_CHARACTERS for c in input_str) + + +def str_to_snake_case(input_str: str) -> str: + """ + Converts any string to a snake case string: Whitespaces are replaced with an underscore, uppercase letters are + converted to lowercase, and any non-alphanumeric character is removed. + """ + # First, lowercase string and replace any consecutive whitespaces with a single underscore + almost_snake_case = '_'.join(input_str.lower().split()) + + # Now, remove all characters that are neither letters, digits, nor underscores + return ''.join(filter(lambda c: c in SNAKE_CASE_CHARACTERS, almost_snake_case)) diff --git a/tofu_api/dependencies.py b/tofu_api/dependencies.py index b2698c6..59c11b5 100644 --- a/tofu_api/dependencies.py +++ b/tofu_api/dependencies.py @@ -1,6 +1,26 @@ +from typing import TypeVar + from sqlalchemy.orm import Session +from tofu_api.api.tasks import TaskHandler from tofu_api.common.database import SQLAlchemy +from tofu_api.repositories import TaskRepository + +T_Dep_Callable = TypeVar('T_Dep_Callable') + + +def cache_dependency(func: T_Dep_Callable) -> T_Dep_Callable: + """ + Decorator to be used in `Dependencies` to cache dependencies inside the Dependencies instance. + """ + dep_name = func.__name__ + + def wrapped_func(self: 'Dependencies'): + if dep_name not in self._dependency_cache: + self._dependency_cache[dep_name] = func(self) + return self._dependency_cache.get(dep_name) + + return wrapped_func class Dependencies: @@ -15,10 +35,22 @@ class Dependencies: # Database dependencies + @cache_dependency def get_sqlalchemy(self) -> SQLAlchemy: - if SQLAlchemy not in self._dependency_cache: - self._dependency_cache[SQLAlchemy] = SQLAlchemy() - return self._dependency_cache[SQLAlchemy] + return SQLAlchemy() + # No caching necessary here def get_db_session(self) -> Session: return self.get_sqlalchemy().session + + # Repository classes + + @cache_dependency + def get_task_repository(self) -> TaskRepository: + return TaskRepository(session=self.get_db_session()) + + # API Handler classes + + @cache_dependency + def get_task_handler(self) -> TaskHandler: + return TaskHandler(task_repository=self.get_task_repository()) diff --git a/tofu_api/models/base.py b/tofu_api/models/base.py index 2a22c89..b063539 100644 --- a/tofu_api/models/base.py +++ b/tofu_api/models/base.py @@ -1,14 +1,15 @@ -from typing import Any, Iterable, Optional - -from sqlalchemy import Column, Integer, inspect -from sqlalchemy.orm import InstanceState, as_declarative - -from tofu_api.common.database import Col, MetaData - __all__ = [ 'BaseModel', ] +from typing import Any, Iterable, Optional, Union + +from sqlalchemy import Column, Integer, inspect +from sqlalchemy.orm import InstanceState, as_declarative +from validataclass.dataclasses import ValidataclassMixin + +from tofu_api.common.database import Col, MetaData + @as_declarative(metadata=MetaData()) class BaseModel: @@ -60,3 +61,15 @@ class BaseModel: return { field: getattr(self, field) for field in included_fields } + + def update_from(self, data: Union[dict, ValidataclassMixin]) -> None: + """ + Updates the object with data from either a dictionary or a validataclass object (requires the ValidataclassMixin). + """ + if isinstance(data, ValidataclassMixin): + data = data.to_dict() + + # TODO: Is it a good idea to just iterate over data and setattr? Or should we check __table__.columns? + for key, value in data.items(): + if hasattr(self, key): + setattr(self, key, value) diff --git a/tofu_api/repositories/__init__.py b/tofu_api/repositories/__init__.py new file mode 100644 index 0000000..2ee789a --- /dev/null +++ b/tofu_api/repositories/__init__.py @@ -0,0 +1,2 @@ +from .base_repository import BaseRepository +from .task_repository import TaskRepository diff --git a/tofu_api/repositories/base_repository.py b/tofu_api/repositories/base_repository.py new file mode 100644 index 0000000..d8dae41 --- /dev/null +++ b/tofu_api/repositories/base_repository.py @@ -0,0 +1,94 @@ +__all__ = [ + 'BaseRepository', + 'T_Model', +] + +from abc import ABC, abstractmethod +from typing import Generic, Optional, Type, TypeVar + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from tofu_api.models import BaseModel +from .exceptions import ObjectNotFoundException + +T_Model = TypeVar('T_Model', bound=BaseModel) + + +class BaseRepository(Generic[T_Model], ABC): + """ + Base class for repositories. + """ + + # Database session + session: Session + + @property + @abstractmethod + def model_cls(self) -> Type[T_Model]: + """ + Set this to the model class. + """ + raise NotImplementedError + + def __init__(self, *, session: Session): + self.session = session + + @staticmethod + def _or_raise(resource: Optional[T_Model], exception_msg: Optional[str] = None) -> T_Model: + if resource is None: + raise ObjectNotFoundException(exception_msg) + return resource + + def fetch_by_id(self, resource_id: int) -> T_Model: + """ + Fetches a resource by ID. + + Raises an ObjectNotFoundException if no resource with the ID was found. + """ + resource = self.session.get(self.model_cls, resource_id) + return self._or_raise(resource, f'Resource with ID {resource_id} was not found.') + + def fetch_all(self) -> list[T_Model]: + """ + Fetches all resources of the repository type. + """ + return self.session.scalars( + select(self.model_cls) + ).all() + + def commit_session(self) -> None: + """ + Commits the current database session. + """ + self.commit_session() + + def rollback_session(self) -> None: + """ + Rolls back the current database session. + """ + self.rollback_session() + + def save_resource(self, *resources: T_Model, commit: bool = True) -> None: + """ + Saves one or multiple resources to the database by adding them to the session and committing the session. + Set `commit` to False to skip committing (the session will still be flushed, though). + """ + for resource in resources: + self.session.add(resource) + self.session.flush() + + if commit: + self.session.commit() + + def delete_resource(self, *resources: T_Model, commit: bool = True) -> None: + """ + Deletes one or multiple resources from the database. + Set `commit` to False to skip committing (the session will still be flushed, though). + """ + for resource in resources: + self.session.delete(resource) + self.session.flush() + + if commit: + self.session.commit() diff --git a/tofu_api/repositories/exceptions.py b/tofu_api/repositories/exceptions.py new file mode 100644 index 0000000..b3e5919 --- /dev/null +++ b/tofu_api/repositories/exceptions.py @@ -0,0 +1,9 @@ +from tofu_api.common.exceptions import AppException + + +class ObjectNotFoundException(AppException): + """ + Exception raised when a database object was not found, i.e. does not exist or is inaccessible for the user. + """ + status_code = 404 + code = 'not_found' diff --git a/tofu_api/repositories/task_repository.py b/tofu_api/repositories/task_repository.py new file mode 100644 index 0000000..0be0741 --- /dev/null +++ b/tofu_api/repositories/task_repository.py @@ -0,0 +1,9 @@ +from tofu_api.models import Task +from .base_repository import BaseRepository + + +class TaskRepository(BaseRepository[Task]): + """ + Repository for tasks. + """ + model_cls = Task