From 37bd6fb9611473be710dd89f1bac97d019954a9f Mon Sep 17 00:00:00 2001 From: Corvo <60719165+brothercorvo@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:32:42 -0300 Subject: [PATCH] Support configurable EmergencyManagement database --- TASK.md | 3 + .../Server/controllers_emergency.py | 22 +-- .../EmergencyManagement/Server/database.py | 126 +++++++++++++++++- .../Server/server_emergency.py | 58 +++++++- tests/test_example_emergency_management.py | 53 ++++---- 5 files changed, 211 insertions(+), 51 deletions(-) diff --git a/TASK.md b/TASK.md index 939f297..70b4eb1 100644 --- a/TASK.md +++ b/TASK.md @@ -82,3 +82,6 @@ - [x] Upgrade esbuild dependency to version 0.25.0 or later to address the development server request vulnerability. - [x] Simplify EmergencyManagement web UI tables with a drawer form triggered by a New button. (2025-09-25) +## 2025-09-26 +- [x] Derive EmergencyManagement database configuration from the server module path and expose runtime overrides. + diff --git a/examples/EmergencyManagement/Server/controllers_emergency.py b/examples/EmergencyManagement/Server/controllers_emergency.py index e37edac..b17338f 100644 --- a/examples/EmergencyManagement/Server/controllers_emergency.py +++ b/examples/EmergencyManagement/Server/controllers_emergency.py @@ -2,7 +2,7 @@ from typing import Optional from reticulum_openapi.controller import Controller from reticulum_openapi.controller import handle_exceptions -from examples.EmergencyManagement.Server.database import async_session +from examples.EmergencyManagement.Server import database from examples.EmergencyManagement.Server.models_emergency import EmergencyActionMessage from examples.EmergencyManagement.Server.models_emergency import Event @@ -11,21 +11,21 @@ class EmergencyController(Controller): @handle_exceptions async def CreateEmergencyActionMessage(self, req: EmergencyActionMessage): self.logger.info(f"CreateEAM: {req}") - async with async_session() as session: + async with database.async_session() as session: await EmergencyActionMessage.create(session, **asdict(req)) return req @handle_exceptions async def DeleteEmergencyActionMessage(self, callsign: str): self.logger.info(f"DeleteEAM callsign={callsign}") - async with async_session() as session: + async with database.async_session() as session: deleted = await EmergencyActionMessage.delete(session, callsign) return {"status": "deleted" if deleted else "not_found", "callsign": callsign} @handle_exceptions async def ListEmergencyActionMessage(self): self.logger.info("ListEAM") - async with async_session() as session: + async with database.async_session() as session: items = await EmergencyActionMessage.list(session) return items @@ -42,7 +42,7 @@ async def PutEmergencyActionMessage( Optional[EmergencyActionMessage]: Updated dataclass instance or ``None`` if not found. """ self.logger.info(f"PutEAM: {req}") - async with async_session() as session: + async with database.async_session() as session: updated = await EmergencyActionMessage.update( session, req.callsign, **asdict(req) ) @@ -51,7 +51,7 @@ async def PutEmergencyActionMessage( @handle_exceptions async def RetrieveEmergencyActionMessage(self, callsign: str): self.logger.info(f"RetrieveEAM callsign={callsign}") - async with async_session() as session: + async with database.async_session() as session: item = await EmergencyActionMessage.get(session, callsign) return item @@ -60,21 +60,21 @@ class EventController(Controller): @handle_exceptions async def CreateEvent(self, req: Event): self.logger.info(f"CreateEvent: {req}") - async with async_session() as session: + async with database.async_session() as session: await Event.create(session, **asdict(req)) return req @handle_exceptions async def DeleteEvent(self, uid: str): self.logger.info(f"DeleteEvent uid={uid}") - async with async_session() as session: + async with database.async_session() as session: deleted = await Event.delete(session, int(uid)) return {"status": "deleted" if deleted else "not_found", "uid": uid} @handle_exceptions async def ListEvent(self): self.logger.info("ListEvent") - async with async_session() as session: + async with database.async_session() as session: events = await Event.list(session) return events @@ -89,13 +89,13 @@ async def PutEvent(self, req: Event) -> Optional[Event]: Optional[Event]: Updated dataclass instance or ``None`` if not found. """ self.logger.info(f"PutEvent: {req}") - async with async_session() as session: + async with database.async_session() as session: updated = await Event.update(session, req.uid, **asdict(req)) return updated @handle_exceptions async def RetrieveEvent(self, uid: str): self.logger.info(f"RetrieveEvent uid={uid}") - async with async_session() as session: + async with database.async_session() as session: event = await Event.get(session, int(uid)) return event diff --git a/examples/EmergencyManagement/Server/database.py b/examples/EmergencyManagement/Server/database.py index 0f639e3..b95af4f 100644 --- a/examples/EmergencyManagement/Server/database.py +++ b/examples/EmergencyManagement/Server/database.py @@ -1,12 +1,128 @@ -from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession +"""Database configuration helpers for the Emergency Management example.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional, Tuple + +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + from .models_emergency import Base -DATABASE_URL = "sqlite+aiosqlite:///emergency.db" -engine = create_async_engine(DATABASE_URL, echo=False) -async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) +DATABASE_ENV_VAR = "EMERGENCY_DATABASE_URL" +_DEFAULT_DATABASE_PATH = Path(__file__).resolve().with_name("emergency.db") +_DEFAULT_DATABASE_URL = f"sqlite+aiosqlite:///{_DEFAULT_DATABASE_PATH}" + +DATABASE_URL = _DEFAULT_DATABASE_URL +engine: Optional[AsyncEngine] = None +async_session: Optional[async_sessionmaker[AsyncSession]] = None + + +def _normalise_database_url(candidate: Optional[str]) -> str: + """Convert ``candidate`` into a SQLAlchemy database URL. + + Args: + candidate (Optional[str]): Potential override provided via the + environment, CLI, or direct helper invocation. + + Returns: + str: The normalised SQLAlchemy database URL. + """ + + if not candidate: + env_value = os.getenv(DATABASE_ENV_VAR) + candidate = env_value if env_value else None + + if not candidate: + return _DEFAULT_DATABASE_URL + + if "://" not in candidate: + db_path = Path(candidate).expanduser().resolve() + return f"sqlite+aiosqlite:///{db_path}" + + return candidate + + +def _create_engine_and_session( + url: str, +) -> Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: + """Create an async engine and session factory for ``url``. + + Args: + url (str): Database URL to connect to. + Returns: + Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: Engine and + session factory pair configured for the provided URL. + """ + + engine = create_async_engine(url, echo=False) + session_factory = async_sessionmaker( + engine, + expire_on_commit=False, + class_=AsyncSession, + ) + return engine, session_factory + + +def configure_database(url: Optional[str] = None) -> str: + """Configure the database engine and session factory. + + Args: + url (Optional[str]): Optional override for the database URL. File paths + are converted into SQLite URLs automatically. When ``None``, the + helper honours :data:`DATABASE_ENV_VAR` or falls back to the default + database file next to this module. + + Returns: + str: The database URL that was applied. + """ + + global DATABASE_URL + global engine + global async_session + + resolved_url = _normalise_database_url(url) + + if ( + resolved_url == DATABASE_URL + and engine is not None + and async_session is not None + ): + return DATABASE_URL + + engine, session_factory = _create_engine_and_session(resolved_url) + DATABASE_URL = resolved_url + async_session = session_factory + return DATABASE_URL + + +async def init_db(url: Optional[str] = None) -> None: + """Initialise the database schema if it does not exist. + + Args: + url (Optional[str]): Optional override passed through to + :func:`configure_database`. + + Returns: + None: The coroutine completes once the schema has been created. + """ + + configure_database(url) + if engine is None: + raise RuntimeError("Database engine is not configured") -async def init_db(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + + +# Initialise the module-level engine and session factory using the default +# configuration or any environment override available during import. +configure_database(None) diff --git a/examples/EmergencyManagement/Server/server_emergency.py b/examples/EmergencyManagement/Server/server_emergency.py index 41daf48..17d78c9 100644 --- a/examples/EmergencyManagement/Server/server_emergency.py +++ b/examples/EmergencyManagement/Server/server_emergency.py @@ -1,11 +1,12 @@ """Run the emergency management server example.""" - +import argparse import asyncio import signal import sys from contextlib import suppress from pathlib import Path +from typing import Optional, Sequence def _ensure_standard_library_on_path() -> None: @@ -69,6 +70,7 @@ def _configure_environment() -> None: EmergencyService = object() +configure_database = None init_db = None @@ -76,19 +78,28 @@ def _ensure_dependencies_loaded() -> None: """Load modules that require adjusted import paths.""" global EmergencyService + global configure_database global init_db - if isinstance(EmergencyService, type) and init_db is not None: + if ( + isinstance(EmergencyService, type) + and init_db is not None + and callable(configure_database) + ): return _configure_environment() - from examples.EmergencyManagement.Server.database import init_db as database_init_db + from examples.EmergencyManagement.Server.database import ( + configure_database as database_configure_database, + init_db as database_init_db, + ) from examples.EmergencyManagement.Server.service_emergency import ( EmergencyService as service_emergency_service, ) init_db = database_init_db + configure_database = database_configure_database EmergencyService = service_emergency_service @@ -123,16 +134,39 @@ def _sync_handler(*_: int, **__: object) -> None: try: - from examples.EmergencyManagement.Server.database import init_db + from examples.EmergencyManagement.Server.database import ( # type: ignore + configure_database, + init_db, + ) from examples.EmergencyManagement.Server.service_emergency import ( EmergencyService, ) except Exception: # pragma: no cover - best effort for optional imports + configure_database = None init_db = None EmergencyService = None -async def main() -> None: +def _resolve_database_override(argv: Optional[Sequence[str]]) -> Optional[str]: + """Parse ``argv`` for optional database overrides.""" + + if argv is None: + return None + + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--database") + parser.add_argument("--database-path") + parser.add_argument("--database-url") + parsed, _ = parser.parse_known_args(list(argv)) + + for candidate in (parsed.database_url, parsed.database_path, parsed.database): + if candidate: + return candidate + + return None + + +async def main(argv: Optional[Sequence[str]] = None) -> None: """Run the emergency management service until interrupted. Returns: @@ -142,10 +176,20 @@ async def main() -> None: _ensure_dependencies_loaded() - if init_db is None or not isinstance(EmergencyService, type): + if ( + init_db is None + or not isinstance(EmergencyService, type) + or not callable(configure_database) + ): raise RuntimeError("Emergency service dependencies failed to load") + + if argv is None: + argv = sys.argv[1:] + _configure_environment() - await init_db() + override = _resolve_database_override(argv) + configured_url = configure_database(override) + await init_db(configured_url) async with EmergencyService() as svc: svc.announce() stop_event = asyncio.Event() diff --git a/tests/test_example_emergency_management.py b/tests/test_example_emergency_management.py index 386d2dc..7473d5c 100644 --- a/tests/test_example_emergency_management.py +++ b/tests/test_example_emergency_management.py @@ -9,13 +9,7 @@ import pytest import pytest_asyncio -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import async_sessionmaker -from sqlalchemy.ext.asyncio import create_async_engine -from examples.EmergencyManagement.Server import ( - controllers_emergency as controllers_module, -) from examples.EmergencyManagement.Server import database as database_module from examples.EmergencyManagement.Server.controllers_emergency import ( EmergencyController, @@ -32,35 +26,36 @@ @pytest_asyncio.fixture -async def emergency_db(monkeypatch, tmp_path): +async def emergency_db(tmp_path): """Provide a temporary database and session factory for the example tests.""" db_path = tmp_path / "emergency_test.db" - engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}") - session_factory = async_sessionmaker( - engine, - expire_on_commit=False, - class_=AsyncSession, - ) + original_url = database_module.DATABASE_URL + original_engine = database_module.engine - # Reason: the controllers capture async_session at import time, so patch the - # module-level references to point at the temporary session factory. - monkeypatch.setattr(database_module, "engine", engine, raising=False) - monkeypatch.setattr( - database_module, "async_session", session_factory, raising=False - ) - monkeypatch.setattr( - controllers_module, "async_session", session_factory, raising=False - ) + configured_url = database_module.configure_database(str(db_path)) + assert configured_url == database_module.DATABASE_URL - async with engine.begin() as conn: + if database_module.engine is None: + raise RuntimeError("Database engine was not initialised") + if database_module.async_session is None: + raise RuntimeError("Session factory was not initialised") + + async with database_module.engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) try: - yield session_factory + yield database_module.async_session finally: - await engine.dispose() + if database_module.engine is not None: + await database_module.engine.dispose() + database_module.configure_database(original_url) + if ( + original_engine is not None + and original_engine is not database_module.engine + ): + await original_engine.dispose() @pytest.mark.asyncio @@ -193,9 +188,7 @@ def test_decode_event_list_fallback_handles_compressed_json() -> None: Event(uid=21, type="Test", point=Point(lat=3.0, lon=4.0)), Event(uid=22, type="Exercise", point=Point(lat=5.0, lon=6.0)), ] - payload = compress_json( - dataclass_to_json_bytes([asdict(item) for item in events]) - ) + payload = compress_json(dataclass_to_json_bytes([asdict(item) for item in events])) decoded = client_module._decode_event_list(payload) @@ -234,6 +227,7 @@ def test_server_script_importable_from_directory(monkeypatch) -> None: assert "EmergencyService" in globals_ns assert "init_db" in globals_ns + assert "configure_database" in globals_ns @pytest.mark.asyncio @@ -408,6 +402,7 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): return None @@ -498,6 +493,7 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): return None @@ -596,6 +592,7 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): return None