diff --git a/TASK.md b/TASK.md index 501e133..6708bcf 100644 --- a/TASK.md +++ b/TASK.md @@ -102,6 +102,7 @@ ## 2025-11-11 - [x] Pin Reticulum (RNS) to 1.0.2 and LXMF to 0.9.2 in project dependencies. - [x] Centralise payload conversion utilities and refactor EmergencyManagement client and gateway to use them. +- [x] Extract SQLAlchemy controller mixin for async CRUD helpers and refactor emergency controllers/tests. ## 2025-11-12 - [x] Introduce FastAPI integration helpers for LXMF configuration, dependencies, and command routing. diff --git a/examples/EmergencyManagement/Server/controllers_emergency.py b/examples/EmergencyManagement/Server/controllers_emergency.py index 0773f8f..b16c09a 100644 --- a/examples/EmergencyManagement/Server/controllers_emergency.py +++ b/examples/EmergencyManagement/Server/controllers_emergency.py @@ -1,121 +1,38 @@ -from dataclasses import asdict -from typing import Any from typing import Dict from typing import List from typing import Optional -from typing import Type -from typing import TypeVar from reticulum_openapi.controller import Controller from reticulum_openapi.controller import handle_exceptions -from reticulum_openapi.model import BaseModel +from reticulum_openapi.sqlalchemy_controller import SQLAlchemyControllerMixin +from reticulum_openapi.sqlalchemy_controller import SessionFactory from examples.EmergencyManagement.Server import database from examples.EmergencyManagement.Server.models_emergency import EmergencyActionMessage from examples.EmergencyManagement.Server.models_emergency import Event -ModelT = TypeVar("ModelT", bound=BaseModel) +class _BaseDatabaseController(SQLAlchemyControllerMixin, Controller): + """Shared database integration helpers for emergency controllers.""" -# Backwards compatibility shim allowing tests to override the session factory. -async_session = None + def get_default_session_factory(self) -> Optional[SessionFactory]: + """Return the configured async session factory from the database module.""" + return database.async_session -def _require_session_factory(): - """Return the configured async session factory or raise an error.""" - if async_session is not None: - return async_session - if database.async_session is None: - raise RuntimeError("Database session factory is not configured") - return database.async_session - - -def _get_primary_key_column(model: Type[ModelT]): - """Return the SQLAlchemy column representing the model primary key.""" - - orm_model = getattr(model, "__orm_model__", None) - if orm_model is None: - raise RuntimeError(f"{model.__name__} does not define an ORM mapping") - primary_key_columns = list(orm_model.__table__.primary_key.columns) - if len(primary_key_columns) != 1: - raise RuntimeError( - f"{model.__name__} must define exactly one primary key column" - ) - return primary_key_columns[0] - - -def _coerce_identifier(model: Type[ModelT], identifier: Any) -> Any: - """Convert an identifier into the Python type expected by the ORM column.""" - - column = _get_primary_key_column(model) - python_type = getattr(column.type, "python_type", None) - if python_type is None or isinstance(identifier, python_type): - return identifier - try: - return python_type(identifier) - except (TypeError, ValueError) as exc: - raise ValueError( - f"Invalid identifier for {model.__name__}: {identifier!r}" - ) from exc - - -async def _create_instance(model: Type[ModelT], payload: ModelT) -> ModelT: - """Persist ``payload`` using the model helper and return the stored instance.""" - - session_factory = _require_session_factory() - async with session_factory() as session: - return await model.create(session, **asdict(payload)) - - -async def _update_instance(model: Type[ModelT], payload: ModelT) -> Optional[ModelT]: - """Update ``payload`` using the model helper and return the refreshed instance.""" - - identifier_name = _get_primary_key_column(model).name - identifier = getattr(payload, identifier_name) - session_factory = _require_session_factory() - async with session_factory() as session: - return await model.update(session, identifier, **asdict(payload)) - - -async def _retrieve_instance(model: Type[ModelT], identifier: Any) -> Optional[ModelT]: - """Return a stored instance or ``None`` when the identifier is unknown.""" - - resolved_identifier = _coerce_identifier(model, identifier) - session_factory = _require_session_factory() - async with session_factory() as session: - return await model.get(session, resolved_identifier) - - -async def _delete_instance(model: Type[ModelT], identifier: Any) -> bool: - """Delete the record referenced by ``identifier``.""" - - resolved_identifier = _coerce_identifier(model, identifier) - session_factory = _require_session_factory() - async with session_factory() as session: - return await model.delete(session, resolved_identifier) - - -async def _list_instances(model: Type[ModelT]) -> List[ModelT]: - """Return all stored instances for ``model``.""" - - session_factory = _require_session_factory() - async with session_factory() as session: - return await model.list(session) - - -class EmergencyController(Controller): +class EmergencyController(_BaseDatabaseController): @handle_exceptions async def CreateEmergencyActionMessage( self, req: EmergencyActionMessage ) -> EmergencyActionMessage: self.logger.info(f"CreateEAM: {req}") - return await _create_instance(EmergencyActionMessage, req) + return await self._create_instance(EmergencyActionMessage, req) @handle_exceptions async def DeleteEmergencyActionMessage(self, callsign: str) -> Dict[str, str]: self.logger.info(f"DeleteEAM callsign={callsign}") - deleted = await _delete_instance(EmergencyActionMessage, callsign) + deleted = await self._delete_instance(EmergencyActionMessage, callsign) return {"status": "deleted" if deleted else "not_found", "callsign": callsign} @handle_exceptions @@ -123,7 +40,7 @@ async def ListEmergencyActionMessage( self, ) -> List[EmergencyActionMessage]: self.logger.info("ListEAM") - return await _list_instances(EmergencyActionMessage) + return await self._list_instances(EmergencyActionMessage) @handle_exceptions async def PutEmergencyActionMessage( @@ -138,32 +55,32 @@ async def PutEmergencyActionMessage( Optional[EmergencyActionMessage]: Updated dataclass instance or ``None`` if not found. """ self.logger.info(f"PutEAM: {req}") - return await _update_instance(EmergencyActionMessage, req) + return await self._update_instance(EmergencyActionMessage, req) @handle_exceptions async def RetrieveEmergencyActionMessage( self, callsign: str ) -> Optional[EmergencyActionMessage]: self.logger.info(f"RetrieveEAM callsign={callsign}") - return await _retrieve_instance(EmergencyActionMessage, callsign) + return await self._retrieve_instance(EmergencyActionMessage, callsign) -class EventController(Controller): +class EventController(_BaseDatabaseController): @handle_exceptions async def CreateEvent(self, req: Event) -> Event: self.logger.info(f"CreateEvent: {req}") - return await _create_instance(Event, req) + return await self._create_instance(Event, req) @handle_exceptions async def DeleteEvent(self, uid: str) -> Dict[str, str]: self.logger.info(f"DeleteEvent uid={uid}") - deleted = await _delete_instance(Event, uid) + deleted = await self._delete_instance(Event, uid) return {"status": "deleted" if deleted else "not_found", "uid": uid} @handle_exceptions async def ListEvent(self) -> List[Event]: self.logger.info("ListEvent") - return await _list_instances(Event) + return await self._list_instances(Event) @handle_exceptions async def PutEvent(self, req: Event) -> Optional[Event]: @@ -176,9 +93,9 @@ async def PutEvent(self, req: Event) -> Optional[Event]: Optional[Event]: Updated dataclass instance or ``None`` if not found. """ self.logger.info(f"PutEvent: {req}") - return await _update_instance(Event, req) + return await self._update_instance(Event, req) @handle_exceptions async def RetrieveEvent(self, uid: str) -> Optional[Event]: self.logger.info(f"RetrieveEvent uid={uid}") - return await _retrieve_instance(Event, uid) + return await self._retrieve_instance(Event, uid) diff --git a/reticulum_openapi/sqlalchemy_controller.py b/reticulum_openapi/sqlalchemy_controller.py new file mode 100644 index 0000000..59fe4cc --- /dev/null +++ b/reticulum_openapi/sqlalchemy_controller.py @@ -0,0 +1,131 @@ +"""Shared SQLAlchemy controller helpers for async CRUD operations.""" + +from __future__ import annotations + +from contextlib import AbstractAsyncContextManager +from dataclasses import asdict +from typing import Any, Callable, Optional, Type, TypeVar + +from .model import BaseModel + + +ModelT = TypeVar("ModelT", bound=BaseModel) +SessionFactory = Callable[[], AbstractAsyncContextManager[Any]] + + +class SQLAlchemyControllerMixin: + """Provide reusable async CRUD helpers for SQLAlchemy-backed controllers.""" + + session_factory: Optional[SessionFactory] = None + + def __init__( + self, + session_factory: Optional[SessionFactory] = None, + **kwargs: Any, + ) -> None: + """Initialise the mixin with an optional session factory override.""" + + super().__init__(**kwargs) + self._session_factory_override = session_factory + + @classmethod + def configure_session_factory( + cls, + session_factory: Optional[SessionFactory], + ) -> None: + """Set a class-level session factory used by all controller instances.""" + + cls.session_factory = session_factory + + def get_default_session_factory(self) -> Optional[SessionFactory]: + """Return the default session factory for this controller instance.""" + + return None + + def _require_session_factory(self) -> SessionFactory: + """Return the configured session factory or raise an error.""" + + for candidate in ( + getattr(self, "_session_factory_override", None), + getattr(type(self), "session_factory", None), + self.get_default_session_factory(), + ): + if candidate is not None: + return candidate + raise RuntimeError("Database session factory is not configured") + + @staticmethod + def _get_primary_key_column(model: Type[ModelT]): + """Return the SQLAlchemy column representing the model primary key.""" + + orm_model = getattr(model, "__orm_model__", None) + if orm_model is None: + raise RuntimeError(f"{model.__name__} does not define an ORM mapping") + primary_key_columns = list(orm_model.__table__.primary_key.columns) # type: ignore[attr-defined] + if len(primary_key_columns) != 1: + raise RuntimeError( + f"{model.__name__} must define exactly one primary key column" + ) + return primary_key_columns[0] + + @classmethod + def _coerce_identifier(cls, model: Type[ModelT], identifier: Any) -> Any: + """Convert an identifier into the Python type expected by the ORM column.""" + + column = cls._get_primary_key_column(model) + python_type = getattr(column.type, "python_type", None) + if python_type is None or isinstance(identifier, python_type): + return identifier + try: + return python_type(identifier) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Invalid identifier for {model.__name__}: {identifier!r}" + ) from exc + + async def _create_instance(self, model: Type[ModelT], payload: ModelT) -> ModelT: + """Persist ``payload`` using the model helper and return the stored instance.""" + + session_factory = self._require_session_factory() + async with session_factory() as session: + return await model.create(session, **asdict(payload)) + + async def _update_instance( + self, + model: Type[ModelT], + payload: ModelT, + ) -> Optional[ModelT]: + """Update ``payload`` using the model helper and return the refreshed instance.""" + + identifier_name = self._get_primary_key_column(model).name + identifier = getattr(payload, identifier_name) + session_factory = self._require_session_factory() + async with session_factory() as session: + return await model.update(session, identifier, **asdict(payload)) + + async def _retrieve_instance( + self, + model: Type[ModelT], + identifier: Any, + ) -> Optional[ModelT]: + """Return a stored instance or ``None`` when the identifier is unknown.""" + + resolved_identifier = self._coerce_identifier(model, identifier) + session_factory = self._require_session_factory() + async with session_factory() as session: + return await model.get(session, resolved_identifier) + + async def _delete_instance(self, model: Type[ModelT], identifier: Any) -> bool: + """Delete the record referenced by ``identifier``.""" + + resolved_identifier = self._coerce_identifier(model, identifier) + session_factory = self._require_session_factory() + async with session_factory() as session: + return await model.delete(session, resolved_identifier) + + async def _list_instances(self, model: Type[ModelT]) -> list[ModelT]: + """Return all stored instances for ``model``.""" + + session_factory = self._require_session_factory() + async with session_factory() as session: + return await model.list(session) diff --git a/tests/test_example_emergency_management.py b/tests/test_example_emergency_management.py index 5971c88..5125211 100644 --- a/tests/test_example_emergency_management.py +++ b/tests/test_example_emergency_management.py @@ -168,8 +168,9 @@ async def test_event_controller_list_without_session_factory(monkeypatch) -> Non """Missing session factories should be reported via controller error payloads.""" monkeypatch.setattr( - "examples.EmergencyManagement.Server.controllers_emergency.async_session", + "examples.EmergencyManagement.Server.controllers_emergency.EventController.session_factory", None, + raising=False, ) monkeypatch.setattr( database_module, diff --git a/tests/test_sqlalchemy_controller.py b/tests/test_sqlalchemy_controller.py new file mode 100644 index 0000000..493b5d3 --- /dev/null +++ b/tests/test_sqlalchemy_controller.py @@ -0,0 +1,127 @@ +"""Unit tests for :mod:`reticulum_openapi.sqlalchemy_controller`.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +import pytest_asyncio +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import declarative_base + +from reticulum_openapi.controller import Controller +from reticulum_openapi.model import BaseModel +from reticulum_openapi.sqlalchemy_controller import SQLAlchemyControllerMixin + + +Base = declarative_base() + + +class DummyORM(Base): + """SQLAlchemy ORM model used to exercise the controller mixin.""" + + __tablename__ = "dummy_records" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False) + + +@dataclass +class DummyModel(BaseModel): + """Dataclass backed by :class:`DummyORM` for CRUD tests.""" + + id: int + name: str + + +DummyModel.__orm_model__ = DummyORM + + +class DummyController(SQLAlchemyControllerMixin, Controller): + """Minimal controller implementation exposing the mixin helpers.""" + + def __init__(self, session_factory=None) -> None: + super().__init__(session_factory=session_factory) + + +@pytest_asyncio.fixture +async def dummy_session_factory(): + """Create an in-memory SQLite session factory for tests.""" + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + session_factory = async_sessionmaker( + engine, + expire_on_commit=False, + class_=AsyncSession, + ) + yield session_factory + await engine.dispose() + + +@pytest.mark.asyncio +async def test_mixin_requires_session_factory() -> None: + """The mixin raises when no session factory is configured.""" + + controller = DummyController() + + with pytest.raises(RuntimeError): + await controller._list_instances(DummyModel) + + +@pytest.mark.asyncio +async def test_mixin_crud_flow(dummy_session_factory) -> None: + """CRUD helpers persist, retrieve, update, list, and delete records.""" + + controller = DummyController(session_factory=dummy_session_factory) + + created = await controller._create_instance( + DummyModel, DummyModel(id=1, name="Alpha") + ) + assert created.name == "Alpha" + + retrieved = await controller._retrieve_instance(DummyModel, "1") + assert retrieved is not None + assert retrieved.id == 1 + assert retrieved.name == "Alpha" + + with pytest.raises(ValueError): + await controller._retrieve_instance(DummyModel, "not-a-number") + + updated = await controller._update_instance( + DummyModel, + DummyModel(id=1, name="Beta"), + ) + assert updated is not None + assert updated.name == "Beta" + + listing = await controller._list_instances(DummyModel) + assert [item.name for item in listing] == ["Beta"] + + deleted = await controller._delete_instance(DummyModel, 1) + assert deleted is True + + missing = await controller._delete_instance(DummyModel, 1) + assert missing is False + + +@pytest.mark.asyncio +async def test_mixin_class_level_session_factory(dummy_session_factory) -> None: + """Controllers may rely on the class-level session factory configuration.""" + + DummyController.configure_session_factory(dummy_session_factory) + try: + controller = DummyController() + result = await controller._create_instance( + DummyModel, + DummyModel(id=5, name="Gamma"), + ) + assert result.id == 5 + finally: + DummyController.configure_session_factory(None)