Implement task API with input validation; implement error handling; refactoring
This commit is contained in:
parent
19d264a03b
commit
50aff05614
|
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
"dev": {
|
||||||
|
"api_host": "http://localhost:5000"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
### Fetch all tasks
|
||||||
|
GET {{api_host}}/api/tasks
|
||||||
|
|
||||||
|
|
||||||
|
### Fetch one task
|
||||||
|
GET {{api_host}}/api/tasks/1
|
||||||
|
|
||||||
|
|
||||||
|
### Create new task
|
||||||
|
POST {{api_host}}/api/tasks
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"title": "Some test task"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### Update task
|
||||||
|
PATCH {{api_host}}/api/tasks/1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"description": "Update!"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### Delete task
|
||||||
|
DELETE {{api_host}}/api/tasks/10
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from tofu_api.common.rest import BaseBlueprint
|
from tofu_api.common.rest import BaseBlueprint
|
||||||
from .tasks import TaskApiBlueprint
|
from .tasks import TaskBlueprint
|
||||||
|
|
||||||
|
|
||||||
class TofuApiBlueprint(BaseBlueprint):
|
class TofuApiBlueprint(BaseBlueprint):
|
||||||
|
|
@ -12,4 +12,4 @@ class TofuApiBlueprint(BaseBlueprint):
|
||||||
url_prefix = '/api'
|
url_prefix = '/api'
|
||||||
|
|
||||||
def init_blueprint(self) -> None:
|
def init_blueprint(self) -> None:
|
||||||
self.register_blueprint(TaskApiBlueprint(self.app))
|
self.register_blueprint(TaskBlueprint(self.app))
|
||||||
|
|
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
from .task_api import TaskApiBlueprint
|
from .task_blueprint import TaskBlueprint
|
||||||
|
from .task_handler import TaskHandler
|
||||||
|
|
|
||||||
|
|
@ -1,104 +0,0 @@
|
||||||
from flask import jsonify
|
|
||||||
from flask.views import MethodView
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from werkzeug.exceptions import NotFound
|
|
||||||
|
|
||||||
from tofu_api.common.rest import BaseBlueprint
|
|
||||||
from tofu_api.models import Task
|
|
||||||
|
|
||||||
|
|
||||||
class TaskApiBlueprint(BaseBlueprint):
|
|
||||||
"""
|
|
||||||
Blueprint for the tasks REST API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Blueprint settings
|
|
||||||
name = 'rest_api_tasks'
|
|
||||||
import_name = __name__
|
|
||||||
url_prefix = '/tasks'
|
|
||||||
|
|
||||||
def init_blueprint(self) -> None:
|
|
||||||
"""
|
|
||||||
Register URL rules.
|
|
||||||
"""
|
|
||||||
db_session = self.app.dependencies.get_db_session()
|
|
||||||
|
|
||||||
self.add_url_rule(
|
|
||||||
'',
|
|
||||||
view_func=self.create_view_func(TaskCollectionView, db_session=db_session),
|
|
||||||
methods=['GET', 'POST'],
|
|
||||||
)
|
|
||||||
self.add_url_rule(
|
|
||||||
'/<int:task_id>',
|
|
||||||
view_func=self.create_view_func(TaskItemView, db_session=db_session),
|
|
||||||
methods=['GET', 'PATCH', 'DELETE'],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskBaseView(MethodView):
|
|
||||||
"""
|
|
||||||
Base class for view classes for the `/tasks` endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: Use a handler class instead of accessing the database session directly
|
|
||||||
db_session: Session
|
|
||||||
|
|
||||||
def __init__(self, *, db_session: Session):
|
|
||||||
self.db_session = db_session
|
|
||||||
|
|
||||||
|
|
||||||
class TaskCollectionView(TaskBaseView):
|
|
||||||
"""
|
|
||||||
View class for `/tasks` endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
"""
|
|
||||||
Get list of all tasks.
|
|
||||||
"""
|
|
||||||
task_list = self.db_session.query(Task).all()
|
|
||||||
return jsonify({
|
|
||||||
'count': len(task_list),
|
|
||||||
'items': [task.to_dict() for task in task_list],
|
|
||||||
}), 200
|
|
||||||
|
|
||||||
def post(self):
|
|
||||||
"""
|
|
||||||
Create a new task.
|
|
||||||
"""
|
|
||||||
# TODO: Parse request data and create real data
|
|
||||||
new_task = Task(
|
|
||||||
title='Do stuff'
|
|
||||||
)
|
|
||||||
self.db_session.add(new_task)
|
|
||||||
self.db_session.commit()
|
|
||||||
return jsonify(new_task.to_dict()), 201
|
|
||||||
|
|
||||||
|
|
||||||
class TaskItemView(TaskBaseView):
|
|
||||||
"""
|
|
||||||
View class for `/tasks/<int:task_id>` endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get(self, task_id: int):
|
|
||||||
"""
|
|
||||||
Get a single task by ID.
|
|
||||||
"""
|
|
||||||
task = self.db_session.query(Task).get(task_id)
|
|
||||||
if task is None:
|
|
||||||
raise NotFound(f'Task with ID {task_id} not found!')
|
|
||||||
return jsonify(task.to_dict()), 200
|
|
||||||
|
|
||||||
def patch(self, task_id: int):
|
|
||||||
"""
|
|
||||||
Update a single task by ID.
|
|
||||||
"""
|
|
||||||
# TODO: Implement
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def delete(self, task_id: int):
|
|
||||||
"""
|
|
||||||
Delete a single task by ID.
|
|
||||||
"""
|
|
||||||
# TODO: Implement
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
from tofu_api.common.rest import BaseBlueprint
|
||||||
|
from .task_views import TaskCollectionView, TaskItemView
|
||||||
|
|
||||||
|
|
||||||
|
class TaskBlueprint(BaseBlueprint):
|
||||||
|
"""
|
||||||
|
Blueprint for the tasks REST API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Blueprint settings
|
||||||
|
name = 'rest_api_tasks'
|
||||||
|
import_name = __name__
|
||||||
|
url_prefix = '/tasks'
|
||||||
|
|
||||||
|
def init_blueprint(self) -> None:
|
||||||
|
"""
|
||||||
|
Register URL rules.
|
||||||
|
"""
|
||||||
|
task_handler = self.app.dependencies.get_task_handler()
|
||||||
|
|
||||||
|
self.add_url_rule(
|
||||||
|
'',
|
||||||
|
view_func=TaskCollectionView.as_view(task_handler=task_handler),
|
||||||
|
methods=['GET', 'POST'],
|
||||||
|
)
|
||||||
|
self.add_url_rule(
|
||||||
|
'/<int:task_id>',
|
||||||
|
view_func=TaskItemView.as_view(task_handler=task_handler),
|
||||||
|
methods=['GET', 'PATCH', 'DELETE'],
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
from tofu_api.models import Task
|
||||||
|
from tofu_api.repositories import TaskRepository
|
||||||
|
from .validators import TaskCreateData, TaskUpdateData
|
||||||
|
|
||||||
|
|
||||||
|
class TaskHandler:
|
||||||
|
"""
|
||||||
|
Handles operations on tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_repository: TaskRepository
|
||||||
|
|
||||||
|
def __init__(self, *, task_repository: TaskRepository):
|
||||||
|
self.task_repository = task_repository
|
||||||
|
|
||||||
|
def fetch_task(self, task_id: int) -> Task:
|
||||||
|
"""
|
||||||
|
Fetches a single task by its ID from the database.
|
||||||
|
Raises an ObjectNotFoundError if the task was not found.
|
||||||
|
"""
|
||||||
|
return self.task_repository.fetch_by_id(task_id)
|
||||||
|
|
||||||
|
def fetch_all_tasks(self) -> list[Task]:
|
||||||
|
"""
|
||||||
|
Fetches a list of all tasks.
|
||||||
|
"""
|
||||||
|
return self.task_repository.fetch_all()
|
||||||
|
|
||||||
|
def create_task(self, create_data: TaskCreateData) -> Task:
|
||||||
|
"""
|
||||||
|
Creates a new task, saves it to the database and returns it.
|
||||||
|
"""
|
||||||
|
task = Task()
|
||||||
|
task.update_from(create_data)
|
||||||
|
self.task_repository.save_resource(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def update_task(self, task: Task, update_data: TaskUpdateData) -> Task:
|
||||||
|
"""
|
||||||
|
Updates a Task object with new data.
|
||||||
|
"""
|
||||||
|
task.update_from(update_data)
|
||||||
|
self.task_repository.save_resource(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def delete_task(self, task: Task) -> None:
|
||||||
|
"""
|
||||||
|
Deletes a task from the database.
|
||||||
|
"""
|
||||||
|
self.task_repository.delete_resource(task)
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
from flask import jsonify
|
||||||
|
from validataclass.validators import DataclassValidator
|
||||||
|
|
||||||
|
from tofu_api.common.rest import BaseMethodView
|
||||||
|
from .task_handler import TaskHandler
|
||||||
|
from .validators import TaskCreateData, TaskUpdateData
|
||||||
|
|
||||||
|
|
||||||
|
class TaskBaseView(BaseMethodView):
|
||||||
|
"""
|
||||||
|
Base class for view classes for the `/tasks` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_handler: TaskHandler
|
||||||
|
|
||||||
|
def __init__(self, *, task_handler: TaskHandler):
|
||||||
|
self.task_handler = task_handler
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCollectionView(TaskBaseView):
|
||||||
|
"""
|
||||||
|
View class for `/api/tasks` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Validators
|
||||||
|
task_create_validator = DataclassValidator(TaskCreateData)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
"""
|
||||||
|
Get list of all tasks.
|
||||||
|
"""
|
||||||
|
task_list = self.task_handler.fetch_all_tasks()
|
||||||
|
return jsonify({
|
||||||
|
'items': [task.to_dict() for task in task_list],
|
||||||
|
'total_count': len(task_list),
|
||||||
|
}), 200
|
||||||
|
|
||||||
|
def post(self):
|
||||||
|
"""
|
||||||
|
Create a new task.
|
||||||
|
"""
|
||||||
|
# Parse request data
|
||||||
|
create_data: TaskCreateData = self.validate_request_data(self.task_create_validator)
|
||||||
|
|
||||||
|
# Create new task
|
||||||
|
new_task = self.task_handler.create_task(create_data)
|
||||||
|
|
||||||
|
# Return new task as JSON
|
||||||
|
return jsonify(new_task.to_dict()), 201
|
||||||
|
|
||||||
|
|
||||||
|
class TaskItemView(TaskBaseView):
|
||||||
|
"""
|
||||||
|
View class for `/api/tasks/<int:task_id>` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Validators
|
||||||
|
task_update_validator = DataclassValidator(TaskUpdateData)
|
||||||
|
|
||||||
|
def get(self, task_id: int):
|
||||||
|
"""
|
||||||
|
Get a single task by ID.
|
||||||
|
"""
|
||||||
|
task = self.task_handler.fetch_task(task_id)
|
||||||
|
return jsonify(task.to_dict()), 200
|
||||||
|
|
||||||
|
def patch(self, task_id: int):
|
||||||
|
"""
|
||||||
|
Update a single task by ID.
|
||||||
|
"""
|
||||||
|
# Parse request data
|
||||||
|
update_data: TaskUpdateData = self.validate_request_data(self.task_update_validator)
|
||||||
|
|
||||||
|
# Fetch task and update
|
||||||
|
task = self.task_handler.fetch_task(task_id)
|
||||||
|
task = self.task_handler.update_task(task, update_data)
|
||||||
|
|
||||||
|
# Return updated task as JSON
|
||||||
|
return jsonify(task.to_dict()), 200
|
||||||
|
|
||||||
|
def delete(self, task_id: int):
|
||||||
|
"""
|
||||||
|
Delete a single task by ID.
|
||||||
|
"""
|
||||||
|
# Fetch task and delete
|
||||||
|
task = self.task_handler.fetch_task(task_id)
|
||||||
|
self.task_handler.delete_task(task)
|
||||||
|
|
||||||
|
return self.empty_response()
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from validataclass.dataclasses import Default, DefaultUnset, ValidataclassMixin, validataclass
|
||||||
|
from validataclass.helpers import OptionalUnset
|
||||||
|
from validataclass.validators import StringValidator
|
||||||
|
|
||||||
|
|
||||||
|
@validataclass
|
||||||
|
class TaskCreateData(ValidataclassMixin):
|
||||||
|
"""
|
||||||
|
Dataclass for "create task" request data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
title: str = StringValidator(min_length=1, max_length=200)
|
||||||
|
description: str = StringValidator(max_length=2000), Default('')
|
||||||
|
|
||||||
|
|
||||||
|
@validataclass
|
||||||
|
class TaskUpdateData(TaskCreateData):
|
||||||
|
"""
|
||||||
|
Dataclass for "update task" request data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
title: OptionalUnset[str] = DefaultUnset
|
||||||
|
description: OptionalUnset[str] = DefaultUnset
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -7,6 +8,7 @@ from flask import Flask
|
||||||
from tofu_api.api import TofuApiBlueprint
|
from tofu_api.api import TofuApiBlueprint
|
||||||
from tofu_api.common.config import Config
|
from tofu_api.common.config import Config
|
||||||
from tofu_api.common.json import JSONProvider
|
from tofu_api.common.json import JSONProvider
|
||||||
|
from tofu_api.common.rest import RestApiErrorHandler
|
||||||
from tofu_api.dependencies import Dependencies
|
from tofu_api.dependencies import Dependencies
|
||||||
|
|
||||||
# Enable deprecation warnings in dev environment
|
# Enable deprecation warnings in dev environment
|
||||||
|
|
@ -42,6 +44,10 @@ class App(Flask):
|
||||||
# Load app configuration from YAML file
|
# Load app configuration from YAML file
|
||||||
self.config.from_yaml(os.getenv('FLASK_CONFIG_FILE', default='config.yml'))
|
self.config.from_yaml(os.getenv('FLASK_CONFIG_FILE', default='config.yml'))
|
||||||
|
|
||||||
|
# Configure logging and error handling
|
||||||
|
self.configure_logging()
|
||||||
|
self.configure_error_handling()
|
||||||
|
|
||||||
# Initialize DI container
|
# Initialize DI container
|
||||||
self.dependencies = Dependencies()
|
self.dependencies = Dependencies()
|
||||||
|
|
||||||
|
|
@ -51,6 +57,24 @@ class App(Flask):
|
||||||
# Register blueprints
|
# Register blueprints
|
||||||
self.register_blueprint(TofuApiBlueprint(self))
|
self.register_blueprint(TofuApiBlueprint(self))
|
||||||
|
|
||||||
|
def configure_logging(self) -> None:
|
||||||
|
"""
|
||||||
|
Configures the logging system.
|
||||||
|
"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if self.debug else logging.INFO,
|
||||||
|
format='{asctime}.{msecs:03.0f} {levelname:>8} [{name}] {message}',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
style='{',
|
||||||
|
)
|
||||||
|
|
||||||
|
def configure_error_handling(self) -> None:
|
||||||
|
"""
|
||||||
|
Registers error handlers to the app.
|
||||||
|
"""
|
||||||
|
error_handler = RestApiErrorHandler(debug_mode=self.debug)
|
||||||
|
error_handler.register_error_handlers(self)
|
||||||
|
|
||||||
def init_database(self) -> None:
|
def init_database(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize database connection and models.
|
Initialize database connection and models.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .base import AppException
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AppException(Exception):
|
||||||
|
"""
|
||||||
|
Base class for application specific exceptions that can also be used as API error responses.
|
||||||
|
"""
|
||||||
|
code: str = 'unspecified_error'
|
||||||
|
status_code: int = 400
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def __init__(self, message: str, *, code: Optional[str] = None, status_code: Optional[int] = None):
|
||||||
|
if code is not None:
|
||||||
|
self.code = code
|
||||||
|
if status_code is not None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
'code': self.code,
|
||||||
|
'message': self.message,
|
||||||
|
}
|
||||||
|
|
@ -1 +1,3 @@
|
||||||
from .base_blueprint import BaseBlueprint
|
from .base_blueprint import BaseBlueprint
|
||||||
|
from .base_method_view import BaseMethodView
|
||||||
|
from .error_handler import RestApiErrorHandler
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,11 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Type, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask.views import View
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tofu_api.app import App
|
from tofu_api.app import App
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseBlueprint',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBlueprint(Blueprint, ABC):
|
class BaseBlueprint(Blueprint, ABC):
|
||||||
"""
|
"""
|
||||||
|
|
@ -62,11 +57,3 @@ class BaseBlueprint(Blueprint, ABC):
|
||||||
Register child blueprints and URL rules.
|
Register child blueprints and URL rules.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_view_func(view_cls: Type[View], *args, **kwargs) -> Callable:
|
|
||||||
"""
|
|
||||||
Helper function to create a view function from a `View` class using `view_cls.as_view()`.
|
|
||||||
All arguments are passed to the constructor of the view class.
|
|
||||||
"""
|
|
||||||
return view_cls.as_view(view_cls.__name__, *args, **kwargs)
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import flask
|
||||||
|
from flask.typing import RouteCallable
|
||||||
|
from flask.views import MethodView
|
||||||
|
from validataclass.validators import Validator
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMethodView(MethodView):
|
||||||
|
"""
|
||||||
|
Base class for REST API views.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def request(self) -> flask.Request:
|
||||||
|
return flask.request
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_view(cls, *args, **kwargs) -> RouteCallable:
|
||||||
|
return super().as_view(cls.__name__, *args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty_response(code: int = 204) -> tuple[str, int]:
|
||||||
|
return '', code
|
||||||
|
|
||||||
|
def validate_request_data(self, validator: Validator) -> Any:
|
||||||
|
"""
|
||||||
|
Parses request data as JSON and validates it using a validataclass validator.
|
||||||
|
"""
|
||||||
|
# TODO error handling: wrong content type; empty body; invalid json; validation errors
|
||||||
|
parsed_json = self.request.json
|
||||||
|
return validator.validate(parsed_json)
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import logging
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from flask import Flask, Response, jsonify
|
||||||
|
from werkzeug.exceptions import HTTPException
|
||||||
|
from werkzeug.http import HTTP_STATUS_CODES
|
||||||
|
|
||||||
|
from tofu_api.common import string_utils
|
||||||
|
from tofu_api.common.exceptions import AppException
|
||||||
|
from .exceptions import InternalServerError
|
||||||
|
|
||||||
|
T_Response = Union[Response, tuple[Response, int]]
|
||||||
|
|
||||||
|
|
||||||
|
class RestApiErrorHandler:
|
||||||
|
"""
|
||||||
|
Error handler class for REST API errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Dependencies
|
||||||
|
logger: logging.Logger
|
||||||
|
|
||||||
|
# Options
|
||||||
|
debug_mode: bool = False
|
||||||
|
|
||||||
|
# Lookup table for HTTP status codes to API error codes
|
||||||
|
_http_to_api_error_codes: dict[int, str]
|
||||||
|
|
||||||
|
def __init__(self, *, debug_mode: bool = False):
|
||||||
|
self.logger = logging.getLogger(type(self).__name__)
|
||||||
|
self.debug_mode = debug_mode
|
||||||
|
|
||||||
|
# Generate lookup table for HTTP status codes
|
||||||
|
self._http_to_api_error_codes = {
|
||||||
|
http_status: string_utils.str_to_snake_case(name)
|
||||||
|
for http_status, name in HTTP_STATUS_CODES.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def register_error_handlers(self, app: Flask) -> None:
|
||||||
|
"""
|
||||||
|
Registers error handlers for different types of exceptions to the app.
|
||||||
|
"""
|
||||||
|
app.register_error_handler(AppException, self.handle_app_exception)
|
||||||
|
app.register_error_handler(HTTPException, self.handle_http_exception)
|
||||||
|
app.register_error_handler(Exception, self.handle_generic_exception)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handle_app_exception(exception: AppException) -> T_Response:
|
||||||
|
"""
|
||||||
|
Handles exceptions of type `AppException` that were not handled by any more specific handler.
|
||||||
|
"""
|
||||||
|
return jsonify(exception.to_dict()), exception.status_code
|
||||||
|
|
||||||
|
def handle_http_exception(self, exception: HTTPException) -> T_Response:
|
||||||
|
"""
|
||||||
|
Handles exceptions of type `HTTPException`, i.e. any werkzeug HTTP exceptions.
|
||||||
|
"""
|
||||||
|
if exception.code >= 500:
|
||||||
|
self.logger.exception('HTTP exception with status code %s: %s', exception.code, type(exception).__name__)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'code': self._http_to_api_error_codes.get(exception.code, 'unknown_http_error'),
|
||||||
|
'message': exception.description,
|
||||||
|
}), exception.code
|
||||||
|
|
||||||
|
def handle_generic_exception(self, exception: Exception) -> T_Response:
|
||||||
|
"""
|
||||||
|
Fallback handler for any exceptions not handled by any other handler.
|
||||||
|
"""
|
||||||
|
self.logger.exception('Uncaught exception: %s', type(exception).__name__)
|
||||||
|
|
||||||
|
wrapped_exception = InternalServerError('There was an uncaught error on the server.', inner_exception=exception)
|
||||||
|
return jsonify(wrapped_exception.to_dict(debug=self.debug_mode)), wrapped_exception.status_code
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
__all__ = [
|
||||||
|
'InternalServerError'
|
||||||
|
]
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from tofu_api.common.exceptions import AppException
|
||||||
|
|
||||||
|
|
||||||
|
class InternalServerError(AppException):
|
||||||
|
"""
|
||||||
|
Wrapper exception for any uncaught exception.
|
||||||
|
"""
|
||||||
|
status_code = 500
|
||||||
|
code = 'internal_server_error'
|
||||||
|
inner_exception: Optional[Exception] = None
|
||||||
|
|
||||||
|
def __init__(self, message: str, *, inner_exception: Optional[Exception] = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.inner_exception = inner_exception
|
||||||
|
|
||||||
|
def to_dict(self, *, debug: bool = False) -> dict:
|
||||||
|
data = super().to_dict()
|
||||||
|
if debug:
|
||||||
|
data['_debug'] = {
|
||||||
|
'exception': str(self.inner_exception),
|
||||||
|
'traceback': traceback.format_exception(self.inner_exception),
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
__all__ = [
|
||||||
|
'SNAKE_CASE_CHARACTERS',
|
||||||
|
'is_snake_case',
|
||||||
|
'str_to_snake_case',
|
||||||
|
]
|
||||||
|
|
||||||
|
import string
|
||||||
|
|
||||||
|
SNAKE_CASE_CHARACTERS = string.ascii_lowercase + string.digits + '_'
|
||||||
|
|
||||||
|
|
||||||
|
def is_snake_case(input_str: str) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the input string only consists of snake case characters (lowercase letters, digits, underscore).
|
||||||
|
"""
|
||||||
|
return all(c in SNAKE_CASE_CHARACTERS for c in input_str)
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_snake_case(input_str: str) -> str:
|
||||||
|
"""
|
||||||
|
Converts any string to a snake case string: Whitespaces are replaced with an underscore, uppercase letters are
|
||||||
|
converted to lowercase, and any non-alphanumeric character is removed.
|
||||||
|
"""
|
||||||
|
# First, lowercase string and replace any consecutive whitespaces with a single underscore
|
||||||
|
almost_snake_case = '_'.join(input_str.lower().split())
|
||||||
|
|
||||||
|
# Now, remove all characters that are neither letters, digits, nor underscores
|
||||||
|
return ''.join(filter(lambda c: c in SNAKE_CASE_CHARACTERS, almost_snake_case))
|
||||||
|
|
@ -1,6 +1,26 @@
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tofu_api.api.tasks import TaskHandler
|
||||||
from tofu_api.common.database import SQLAlchemy
|
from tofu_api.common.database import SQLAlchemy
|
||||||
|
from tofu_api.repositories import TaskRepository
|
||||||
|
|
||||||
|
T_Dep_Callable = TypeVar('T_Dep_Callable')
|
||||||
|
|
||||||
|
|
||||||
|
def cache_dependency(func: T_Dep_Callable) -> T_Dep_Callable:
|
||||||
|
"""
|
||||||
|
Decorator to be used in `Dependencies` to cache dependencies inside the Dependencies instance.
|
||||||
|
"""
|
||||||
|
dep_name = func.__name__
|
||||||
|
|
||||||
|
def wrapped_func(self: 'Dependencies'):
|
||||||
|
if dep_name not in self._dependency_cache:
|
||||||
|
self._dependency_cache[dep_name] = func(self)
|
||||||
|
return self._dependency_cache.get(dep_name)
|
||||||
|
|
||||||
|
return wrapped_func
|
||||||
|
|
||||||
|
|
||||||
class Dependencies:
|
class Dependencies:
|
||||||
|
|
@ -15,10 +35,22 @@ class Dependencies:
|
||||||
|
|
||||||
# Database dependencies
|
# Database dependencies
|
||||||
|
|
||||||
|
@cache_dependency
|
||||||
def get_sqlalchemy(self) -> SQLAlchemy:
|
def get_sqlalchemy(self) -> SQLAlchemy:
|
||||||
if SQLAlchemy not in self._dependency_cache:
|
return SQLAlchemy()
|
||||||
self._dependency_cache[SQLAlchemy] = SQLAlchemy()
|
|
||||||
return self._dependency_cache[SQLAlchemy]
|
|
||||||
|
|
||||||
|
# No caching necessary here
|
||||||
def get_db_session(self) -> Session:
|
def get_db_session(self) -> Session:
|
||||||
return self.get_sqlalchemy().session
|
return self.get_sqlalchemy().session
|
||||||
|
|
||||||
|
# Repository classes
|
||||||
|
|
||||||
|
@cache_dependency
|
||||||
|
def get_task_repository(self) -> TaskRepository:
|
||||||
|
return TaskRepository(session=self.get_db_session())
|
||||||
|
|
||||||
|
# API Handler classes
|
||||||
|
|
||||||
|
@cache_dependency
|
||||||
|
def get_task_handler(self) -> TaskHandler:
|
||||||
|
return TaskHandler(task_repository=self.get_task_repository())
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
from typing import Any, Iterable, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, inspect
|
|
||||||
from sqlalchemy.orm import InstanceState, as_declarative
|
|
||||||
|
|
||||||
from tofu_api.common.database import Col, MetaData
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseModel',
|
'BaseModel',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
from typing import Any, Iterable, Optional, Union
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, inspect
|
||||||
|
from sqlalchemy.orm import InstanceState, as_declarative
|
||||||
|
from validataclass.dataclasses import ValidataclassMixin
|
||||||
|
|
||||||
|
from tofu_api.common.database import Col, MetaData
|
||||||
|
|
||||||
|
|
||||||
@as_declarative(metadata=MetaData())
|
@as_declarative(metadata=MetaData())
|
||||||
class BaseModel:
|
class BaseModel:
|
||||||
|
|
@ -60,3 +61,15 @@ class BaseModel:
|
||||||
return {
|
return {
|
||||||
field: getattr(self, field) for field in included_fields
|
field: getattr(self, field) for field in included_fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def update_from(self, data: Union[dict, ValidataclassMixin]) -> None:
|
||||||
|
"""
|
||||||
|
Updates the object with data from either a dictionary or a validataclass object (requires the ValidataclassMixin).
|
||||||
|
"""
|
||||||
|
if isinstance(data, ValidataclassMixin):
|
||||||
|
data = data.to_dict()
|
||||||
|
|
||||||
|
# TODO: Is it a good idea to just iterate over data and setattr? Or should we check __table__.columns?
|
||||||
|
for key, value in data.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .base_repository import BaseRepository
|
||||||
|
from .task_repository import TaskRepository
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
__all__ = [
|
||||||
|
'BaseRepository',
|
||||||
|
'T_Model',
|
||||||
|
]
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic, Optional, Type, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from tofu_api.models import BaseModel
|
||||||
|
from .exceptions import ObjectNotFoundException
|
||||||
|
|
||||||
|
T_Model = TypeVar('T_Model', bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepository(Generic[T_Model], ABC):
|
||||||
|
"""
|
||||||
|
Base class for repositories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Database session
|
||||||
|
session: Session
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def model_cls(self) -> Type[T_Model]:
|
||||||
|
"""
|
||||||
|
Set this to the model class.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, *, session: Session):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _or_raise(resource: Optional[T_Model], exception_msg: Optional[str] = None) -> T_Model:
|
||||||
|
if resource is None:
|
||||||
|
raise ObjectNotFoundException(exception_msg)
|
||||||
|
return resource
|
||||||
|
|
||||||
|
def fetch_by_id(self, resource_id: int) -> T_Model:
|
||||||
|
"""
|
||||||
|
Fetches a resource by ID.
|
||||||
|
|
||||||
|
Raises an ObjectNotFoundException if no resource with the ID was found.
|
||||||
|
"""
|
||||||
|
resource = self.session.get(self.model_cls, resource_id)
|
||||||
|
return self._or_raise(resource, f'Resource with ID {resource_id} was not found.')
|
||||||
|
|
||||||
|
def fetch_all(self) -> list[T_Model]:
|
||||||
|
"""
|
||||||
|
Fetches all resources of the repository type.
|
||||||
|
"""
|
||||||
|
return self.session.scalars(
|
||||||
|
select(self.model_cls)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
def commit_session(self) -> None:
|
||||||
|
"""
|
||||||
|
Commits the current database session.
|
||||||
|
"""
|
||||||
|
self.commit_session()
|
||||||
|
|
||||||
|
def rollback_session(self) -> None:
|
||||||
|
"""
|
||||||
|
Rolls back the current database session.
|
||||||
|
"""
|
||||||
|
self.rollback_session()
|
||||||
|
|
||||||
|
def save_resource(self, *resources: T_Model, commit: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Saves one or multiple resources to the database by adding them to the session and committing the session.
|
||||||
|
Set `commit` to False to skip committing (the session will still be flushed, though).
|
||||||
|
"""
|
||||||
|
for resource in resources:
|
||||||
|
self.session.add(resource)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def delete_resource(self, *resources: T_Model, commit: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Deletes one or multiple resources from the database.
|
||||||
|
Set `commit` to False to skip committing (the session will still be flushed, though).
|
||||||
|
"""
|
||||||
|
for resource in resources:
|
||||||
|
self.session.delete(resource)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
self.session.commit()
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
from tofu_api.common.exceptions import AppException
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectNotFoundException(AppException):
|
||||||
|
"""
|
||||||
|
Exception raised when a database object was not found, i.e. does not exist or is inaccessible for the user.
|
||||||
|
"""
|
||||||
|
status_code = 404
|
||||||
|
code = 'not_found'
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
from tofu_api.models import Task
|
||||||
|
from .base_repository import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRepository(BaseRepository[Task]):
|
||||||
|
"""
|
||||||
|
Repository for tasks.
|
||||||
|
"""
|
||||||
|
model_cls = Task
|
||||||
Loading…
Reference in New Issue