diff --git a/TASK.md b/TASK.md index 6708bcf..0c4e7da 100644 --- a/TASK.md +++ b/TASK.md @@ -103,6 +103,7 @@ - [x] Pin Reticulum (RNS) to 1.0.2 and LXMF to 0.9.2 in project dependencies. - [x] Centralise payload conversion utilities and refactor EmergencyManagement client and gateway to use them. - [x] Extract SQLAlchemy controller mixin for async CRUD helpers and refactor emergency controllers/tests. +- [x] Share async database helpers across services and migrate EmergencyManagement to use them. ## 2025-11-12 - [x] Introduce FastAPI integration helpers for LXMF configuration, dependencies, and command routing. diff --git a/examples/EmergencyManagement/Server/database.py b/examples/EmergencyManagement/Server/database.py index fd3c405..fdc4366 100644 --- a/examples/EmergencyManagement/Server/database.py +++ b/examples/EmergencyManagement/Server/database.py @@ -2,20 +2,20 @@ from __future__ import annotations -import os import json from pathlib import Path -from typing import Any, Optional, Tuple +from typing import Any, Optional from sqlalchemy import inspect from sqlalchemy import literal from sqlalchemy import select -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - AsyncSession, - async_sessionmaker, - create_async_engine, -) +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_sessionmaker + +from reticulum_openapi.database import create_async_engine_and_session +from reticulum_openapi.database import initialise_database +from reticulum_openapi.database import normalise_database_url from .models_emergency import Base from .models_emergency import EventDetailORM, EventORM, EventPointORM @@ -119,59 +119,6 @@ def _backfill_event_components(connection) -> None: existing_point.add(uid) -def _apply_schema_upgrades(connection) -> None: - """Run schema upgrade routines for legacy deployments.""" - - _backfill_event_components(connection) - - -def _normalise_database_url(candidate: Optional[str]) -> str: - """Convert ``candidate`` into a SQLAlchemy database URL. - - Args: - candidate (Optional[str]): Potential override provided via the - environment, CLI, or direct helper invocation. - - Returns: - str: The normalised SQLAlchemy database URL. - """ - - if not candidate: - env_value = os.getenv(DATABASE_ENV_VAR) - candidate = env_value if env_value else None - - if not candidate: - return _DEFAULT_DATABASE_URL - - if "://" not in candidate: - db_path = Path(candidate).expanduser().resolve() - return f"sqlite+aiosqlite:///{db_path}" - - return candidate - - -def _create_engine_and_session( - url: str, -) -> Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: - """Create an async engine and session factory for ``url``. - - Args: - url (str): Database URL to connect to. - - Returns: - Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: Engine and - session factory pair configured for the provided URL. - """ - - engine = create_async_engine(url, echo=False) - session_factory = async_sessionmaker( - engine, - expire_on_commit=False, - class_=AsyncSession, - ) - return engine, session_factory - - def configure_database(url: Optional[str] = None) -> str: """Configure the database engine and session factory. @@ -189,7 +136,11 @@ def configure_database(url: Optional[str] = None) -> str: global engine global async_session - resolved_url = _normalise_database_url(url) + resolved_url = normalise_database_url( + url, + default_url=_DEFAULT_DATABASE_URL, + env_var=DATABASE_ENV_VAR, + ) if ( resolved_url == DATABASE_URL @@ -198,8 +149,9 @@ def configure_database(url: Optional[str] = None) -> str: ): return DATABASE_URL - engine, session_factory = _create_engine_and_session(resolved_url) + created_engine, session_factory = create_async_engine_and_session(resolved_url) DATABASE_URL = resolved_url + engine = created_engine async_session = session_factory return DATABASE_URL @@ -219,9 +171,11 @@ async def init_db(url: Optional[str] = None) -> None: if engine is None: raise RuntimeError("Database engine is not configured") - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - await conn.run_sync(_apply_schema_upgrades) + await initialise_database( + engine, + metadata=Base.metadata, + upgrade_hooks=(_backfill_event_components,), + ) # Initialise the module-level engine and session factory using the default diff --git a/reticulum_openapi/database.py b/reticulum_openapi/database.py new file mode 100644 index 0000000..c4eb22f --- /dev/null +++ b/reticulum_openapi/database.py @@ -0,0 +1,123 @@ +"""Async database configuration helpers for Reticulum OpenAPI projects.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +from dotenv import load_dotenv +from sqlalchemy import MetaData +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +load_dotenv() + + +def normalise_database_url( + candidate: Optional[str], + *, + default_url: str, + env_var: Optional[str] = None, +) -> str: + """Convert ``candidate`` into an async SQLAlchemy database URL. + + Args: + candidate (Optional[str]): Potential override provided via configuration + files, CLI arguments, or direct helper invocation. + default_url (str): URL to return when no overrides are provided. + env_var (Optional[str]): Environment variable used as a secondary + override before falling back to :data:`default_url`. + + Returns: + str: The normalised SQLAlchemy database URL suitable for async engines. + """ + + if not candidate and env_var: + env_value = os.getenv(env_var) + candidate = env_value if env_value else None + + if not candidate: + candidate = default_url + + if "://" not in candidate: + db_path = Path(candidate).expanduser().resolve() + return f"sqlite+aiosqlite:///{db_path}" + + return candidate + + +def create_async_engine_and_session( + url: str, + *, + echo: bool = False, + engine_kwargs: Optional[Dict[str, Any]] = None, + session_kwargs: Optional[Dict[str, Any]] = None, +) -> Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: + """Create an async engine and session factory for ``url``. + + Args: + url (str): Database URL to connect to. + echo (bool): When ``True`` SQLAlchemy will log SQL statements. Defaults + to ``False``. + engine_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments + forwarded to :func:`sqlalchemy.ext.asyncio.create_async_engine`. + session_kwargs (Optional[Dict[str, Any]]): Keyword arguments applied to + :func:`sqlalchemy.ext.asyncio.async_sessionmaker`. + + Returns: + Tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: Configured engine + and session factory pair. + """ + + engine_options = engine_kwargs.copy() if engine_kwargs else {} + engine = create_async_engine(url, echo=echo, **engine_options) + + session_options = {"expire_on_commit": False, "class_": AsyncSession} + if session_kwargs: + session_options.update(session_kwargs) + + session_factory = async_sessionmaker(engine, **session_options) + return engine, session_factory + + +def _run_upgrade_hooks( + connection: Connection, + hooks: Sequence[Callable[[Connection], None]], +) -> None: + """Execute upgrade hooks against a synchronous SQLAlchemy connection.""" + + for hook in hooks: + hook(connection) + + +async def initialise_database( + engine: AsyncEngine, + *, + metadata: MetaData, + upgrade_hooks: Optional[Sequence[Callable[[Connection], None]]] = None, +) -> None: + """Initialise database schema and run upgrade hooks. + + Args: + engine (AsyncEngine): Engine used to issue schema commands. + metadata (MetaData): Metadata describing the schema to create. + upgrade_hooks (Optional[Sequence[Callable[[Connection], None]]]): + Iterable of callables executed after :meth:`MetaData.create_all`. + Each hook receives a synchronous :class:`sqlalchemy.engine.Connection`. + + Returns: + None: Completes once the schema exists and hooks have been executed. + """ + + hooks: Sequence[Callable[[Connection], None]] = upgrade_hooks if upgrade_hooks else () + + async with engine.begin() as connection: + await connection.run_sync(metadata.create_all) + if hooks: + await connection.run_sync(_run_upgrade_hooks, hooks) diff --git a/tests/test_database_module.py b/tests/test_database_module.py new file mode 100644 index 0000000..c4952d4 --- /dev/null +++ b/tests/test_database_module.py @@ -0,0 +1,77 @@ +"""Integration tests for the shared database helpers.""" + +from __future__ import annotations + +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from reticulum_openapi.database import create_async_engine_and_session +from reticulum_openapi.database import initialise_database +from reticulum_openapi.database import normalise_database_url + + +def test_normalise_database_url_prefers_candidate_path(tmp_path, monkeypatch) -> None: + """Explicit candidate paths should override environment defaults.""" + + monkeypatch.setenv("RETICULUM_TEST_DB", "sqlite+aiosqlite:///env.db") + db_path = tmp_path / "custom.sqlite" + + result = normalise_database_url( + str(db_path), + default_url="sqlite+aiosqlite:///default.db", + env_var="RETICULUM_TEST_DB", + ) + + assert result.endswith("custom.sqlite") + assert result.startswith("sqlite+aiosqlite:///") + + +def test_normalise_database_url_uses_environment(monkeypatch, tmp_path) -> None: + """Environment variables should override configured defaults when present.""" + + env_path = tmp_path / "env.sqlite" + monkeypatch.setenv("RETICULUM_TEST_DB", str(env_path)) + + result = normalise_database_url( + None, + default_url="sqlite+aiosqlite:///default.db", + env_var="RETICULUM_TEST_DB", + ) + + assert result.endswith("env.sqlite") + assert result.startswith("sqlite+aiosqlite:///") + + +@pytest.mark.asyncio +async def test_initialise_database_runs_upgrade_hook(tmp_path) -> None: + """Upgrade hooks should run after schema creation for new databases.""" + + db_path = tmp_path / "integration.sqlite" + url = f"sqlite+aiosqlite:///{db_path}" + + metadata = MetaData() + Table("items", metadata, Column("id", Integer, primary_key=True)) + + hook_invocations = [] + + def upgrade(connection) -> None: + hook_invocations.append( + connection.execute(text("SELECT COUNT(*) FROM items")).scalar_one() + ) + + engine, session_factory = create_async_engine_and_session(url) + + try: + await initialise_database( + engine, + metadata=metadata, + upgrade_hooks=(upgrade,), + ) + + async with session_factory() as session: + result = await session.execute(text("SELECT COUNT(*) FROM items")) + assert result.scalar_one() == 0 + finally: + await engine.dispose() + + assert hook_invocations == [0]