# PostgresTaskRegistry sits behind the optional [pg] extra. The name is always
# exported: if importing the real implementation fails, a placeholder class
# takes its place and defers the failure to construction time, where the
# ImportError can carry an actionable install hint. This mirrors what
# adcp.signing does for PgReplayStore (per the original author's note).
try:
    from adcp.decisioning.pg import PostgresTaskRegistry  # noqa: F401
except ImportError:  # pragma: no cover — exercised by the [pg] extra tests
    from typing import ClassVar as _ClassVar

    class PostgresTaskRegistry:  # type: ignore[no-redef]
        """Placeholder exported when the ``adcp[pg]`` extra is missing.

        The class can be imported and inspected, but constructing it always
        raises :class:`ImportError` explaining how to install the extra.
        """

        # Same class-level marker the real registry exposes.
        is_durable: _ClassVar[bool] = True

        def __init__(self, *args: object, **kwargs: object) -> None:
            # Arguments are accepted and ignored so any call shape reaches
            # the install hint instead of a TypeError.
            message = " ".join(
                (
                    "PostgresTaskRegistry requires psycopg3 and psycopg-pool.",
                    "Install the 'pg' extra: `pip install 'adcp[pg]'`",
                    "(Poetry: `poetry add 'adcp[pg]'`).",
                )
            )
            raise ImportError(message)
-- AdCP decisioning task registry — durable HITL task state.
--
-- Run this once per deployment. The table is read and written by
-- PostgresTaskRegistry; see src/adcp/decisioning/pg/task_registry.py for
-- the query shapes the Python code executes.
--
-- COLLATE "C" on the identifier columns pins byte-wise comparison and
-- ordering for task_id / account_id regardless of the database's default
-- locale, so two identifiers match only when their bytes match.
--
-- Alternatively, call PostgresTaskRegistry.create_schema() from
-- application code — it runs the equivalent DDL idempotently on boot.

CREATE TABLE IF NOT EXISTS decisioning_tasks (
    task_id     TEXT COLLATE "C" NOT NULL PRIMARY KEY,
    account_id  TEXT COLLATE "C" NOT NULL,
    state       TEXT NOT NULL DEFAULT 'submitted',
    task_type   TEXT NOT NULL,
    progress    JSONB,
    result      JSONB,
    error       JSONB,
    -- Unix epoch seconds (float), matches TaskRecord.created_at/updated_at
    -- so Python round-trips the value without lossy TIMESTAMPTZ conversion.
    created_at  DOUBLE PRECISION NOT NULL,
    updated_at  DOUBLE PRECISION NOT NULL
);

-- Secondary index on account_id.
-- The registry's get() query filters on task_id (plus an optional
-- account_id check); task_id is the primary key, so that lookup is already
-- served by the primary-key index and does not depend on this one.
-- NOTE(review): none of the query shapes in task_registry.py filter on
-- account_id alone — confirm which account-scoped queries (listing,
-- cleanup, operator reports) this index is meant to serve.
CREATE INDEX IF NOT EXISTS decisioning_tasks_account_idx
    ON decisioning_tasks (account_id);
from __future__ import annotations

import json
import re
import time
import uuid
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

# Probe for the optional dependency once at import time. The module stays
# importable without it; the constructor raises with an install hint instead.
try:
    from psycopg_pool import AsyncConnectionPool as _AsyncConnectionPool  # noqa: F401

    PG_AVAILABLE = True
except ImportError:
    PG_AVAILABLE = False

_INSTALL_HINT = (
    "PostgresTaskRegistry requires psycopg3 and psycopg-pool. "
    "Install the 'pg' extra: `pip install 'adcp[pg]'` "
    "(Poetry: `poetry add 'adcp[pg]'`)."
)

_DEFAULT_TABLE = "decisioning_tasks"

# ASCII-only identifier guard. The table name is interpolated into SQL text
# (identifiers cannot be bound as parameters), so this whitelist is the only
# protection against SQL injection or Unicode homoglyph substitution.
_SAFE_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]{0,62}$")


class PostgresTaskRegistry:
    """PostgreSQL-backed :class:`~adcp.decisioning.TaskRegistry`.

    Durable counterpart to :class:`~adcp.decisioning.InMemoryTaskRegistry`.
    ``is_durable`` is a class-level ``True`` so the production-mode gate in
    :func:`adcp.decisioning.serve.create_adcp_server_from_platform` accepts
    it without ``ADCP_DECISIONING_ALLOW_INMEMORY_TASKS=1``.

    Parameters
    ----------
    pool:
        An :class:`psycopg_pool.AsyncConnectionPool` owned by the caller.
        The registry never opens or closes it; every operation borrows a
        connection for the duration of its own queries only.
    _table:
        Private override of the table name, used by the test-suite to give
        each test an isolated table. Must match ``[a-z_][a-z0-9_]{0,62}``.

    Raises
    ------
    ImportError
        If ``psycopg_pool`` is not installed.
    ValueError
        If ``_table`` is not a safe ASCII identifier.

    Concurrency
    -----------
    Terminal transitions (:meth:`complete`, :meth:`fail`) run
    ``UPDATE ... WHERE state NOT IN ('completed', 'failed') RETURNING task_id``.
    When no row comes back, a follow-up SELECT decides between "unknown
    task", "idempotent retry" and "conflicting terminal write".
    :meth:`update_progress` uses the same predicate, so a late progress write
    cannot move a terminal task back to ``working``.
    """

    is_durable: ClassVar[bool] = True

    def __init__(self, *, pool: AsyncConnectionPool, _table: str = _DEFAULT_TABLE) -> None:
        if not PG_AVAILABLE:
            raise ImportError(_INSTALL_HINT)
        if not _SAFE_IDENTIFIER_RE.fullmatch(_table):
            raise ValueError(
                f"_table must match [a-z_][a-z0-9_]* (ASCII only), got {_table!r}"
            )
        self._pool = pool
        self._table = _table
        self._prepare_statements()

    def _prepare_statements(self) -> None:
        """Format every SQL statement once for ``self._table``.

        Done at construction so the per-call path never builds SQL text.
        ``self._table`` has already been whitelisted by
        ``_SAFE_IDENTIFIER_RE``, which is what makes the interpolation safe.
        """
        table = self._table
        not_terminal = "state NOT IN ('completed', 'failed')"

        self._sql_insert = (  # noqa: S608 — table name is whitelisted
            f"INSERT INTO {table}"
            " (task_id, account_id, state, task_type, created_at, updated_at)"
            " VALUES (%s, %s, 'submitted', %s, %s, %s)"
        )
        self._sql_update_progress = (  # noqa: S608
            f"UPDATE {table}"
            " SET state = CASE state WHEN 'submitted' THEN 'working' ELSE state END,"
            " progress = %s::jsonb, updated_at = %s"
            f" WHERE task_id = %s AND {not_terminal}"
        )
        self._sql_complete = (  # noqa: S608
            f"UPDATE {table}"
            " SET state = 'completed', result = %s::jsonb, updated_at = %s"
            f" WHERE task_id = %s AND {not_terminal}"
            " RETURNING task_id"
        )
        self._sql_fail = (  # noqa: S608
            f"UPDATE {table}"
            " SET state = 'failed', error = %s::jsonb, updated_at = %s"
            f" WHERE task_id = %s AND {not_terminal}"
            " RETURNING task_id"
        )
        # The explicit ::text cast matters: psycopg 3 binds parameters
        # server-side, and a bare "%s IS NULL" gives PostgreSQL no context to
        # infer the placeholder's type ("could not determine data type of
        # parameter").
        self._sql_get = (  # noqa: S608
            "SELECT task_id, account_id, state, task_type,"
            " progress, result, error, created_at, updated_at"
            f" FROM {table}"
            " WHERE task_id = %s AND (%s::text IS NULL OR account_id = %s)"
        )
        self._sql_get_state_result = (  # noqa: S608
            f"SELECT state, result FROM {table} WHERE task_id = %s"
        )
        self._sql_get_state_error = (  # noqa: S608
            f"SELECT state, error FROM {table} WHERE task_id = %s"
        )
        self._sql_discard = f"DELETE FROM {table} WHERE task_id = %s"  # noqa: S608
        # One statement per execute() call, so the DDL does not rely on
        # multi-statement query strings being accepted.
        self._sql_ddl_statements = (  # noqa: S608
            f"CREATE TABLE IF NOT EXISTS {table} ("
            ' task_id TEXT COLLATE "C" NOT NULL PRIMARY KEY,'
            ' account_id TEXT COLLATE "C" NOT NULL,'
            " state TEXT NOT NULL DEFAULT 'submitted',"
            " task_type TEXT NOT NULL,"
            " progress JSONB,"
            " result JSONB,"
            " error JSONB,"
            " created_at DOUBLE PRECISION NOT NULL,"
            " updated_at DOUBLE PRECISION NOT NULL"
            ")",
            f"CREATE INDEX IF NOT EXISTS {table}_account_idx ON {table} (account_id)",
        )
        # Joined form retained for anything that still reads the old attribute.
        self._sql_ddl = ";".join(self._sql_ddl_statements) + ";"

    @staticmethod
    def _payloads_match(stored: Any, candidate: Any) -> bool:
        """Return True when ``candidate`` would be stored as ``stored``.

        ``stored`` is the JSON-decoded column value; ``candidate`` is the
        caller's Python payload. The candidate is passed through the same
        JSON encoding used on write so that shapes JSON cannot represent
        (tuples, non-string keys) compare the way they were persisted
        instead of turning an identical retry into a false conflict.
        """
        return stored == json.loads(json.dumps(candidate))

    # -- schema bootstrap -----------------------------------------------

    async def create_schema(self) -> None:
        """Create the task table and its ``account_id`` index if missing.

        Honors the ``_table`` name the registry was constructed with and is
        idempotent (``IF NOT EXISTS``), so it is safe on every boot. The
        equivalent raw DDL ships as
        ``adcp/decisioning/pg/decisioning_tasks.sql`` for adopters who use a
        migration tool instead.
        """
        async with self._pool.connection() as conn:
            for statement in self._sql_ddl_statements:
                await conn.execute(statement)

    # -- TaskRegistry Protocol ------------------------------------------

    async def issue(
        self,
        *,
        account_id: str,
        task_type: str,
    ) -> str:
        """Allocate a task id, persist a ``submitted`` row, return the id.

        Raises
        ------
        ValueError
            If ``account_id`` is empty or whitespace-only. Such ids would
            let several tenants share one ``account_id`` slot and defeat
            the account predicate in :meth:`get`.
        """
        # NOTE(review): the message speaks of a "non-default" id, which
        # suggests a sentinel value was also rejected originally; the visible
        # code only compared against the empty string (already covered by the
        # checks below). Confirm against InMemoryTaskRegistry.issue.
        if not account_id or not account_id.strip():
            raise ValueError(
                f"account_id must be a non-empty, non-default string; "
                f"got {account_id!r}. AccountStore.resolve must always "
                "return Account(id=) so cross-tenant cache "
                "scoping works correctly."
            )
        task_id = f"task_{uuid.uuid4().hex[:16]}"
        now = time.time()
        async with self._pool.connection() as conn:
            await conn.execute(self._sql_insert, (task_id, account_id, task_type, now, now))
        return task_id

    async def update_progress(
        self,
        task_id: str,
        progress: dict[str, Any],
    ) -> None:
        """Store a progress payload and move ``submitted`` → ``working``.

        A no-op (no raise) when the task is unknown or already terminal: the
        UPDATE simply matches zero rows. The terminal check is part of the
        UPDATE predicate, so a concurrent terminal write cannot be
        overwritten by a late progress event. The two zero-row cases are
        deliberately not distinguished — that would cost an extra SELECT.
        """
        async with self._pool.connection() as conn:
            await conn.execute(
                self._sql_update_progress,
                (json.dumps(progress), time.time(), task_id),
            )

    async def _finish(
        self,
        *,
        task_id: str,
        payload: dict[str, Any],
        update_sql: str,
        lookup_sql: str,
        terminal_state: str,
        payload_name: str,
    ) -> None:
        """Shared terminal-transition logic for :meth:`complete` / :meth:`fail`.

        Tries the guarded UPDATE first. If it matched no row, reads the
        current ``(state, payload)`` on the same connection and classifies
        the outcome: unknown task, idempotent retry, or conflict.
        """
        async with self._pool.connection() as conn:
            cur = await conn.execute(update_sql, (json.dumps(payload), time.time(), task_id))
            if await cur.fetchone() is not None:
                return  # transitioned by this call
            lookup = await conn.execute(lookup_sql, (task_id,))
            row = await lookup.fetchone()

        # Classification happens after the connection is released; nothing
        # was modified on this path.
        if row is None:
            raise ValueError(f"Task {task_id!r} not found")
        state, stored = row
        if state != terminal_state:
            raise ValueError(f"Task {task_id!r} already in terminal state {state!r}")
        if self._payloads_match(stored, payload):
            return  # idempotent retry
        raise ValueError(
            f"Task {task_id!r} already {terminal_state} with a different {payload_name}"
        )

    async def complete(
        self,
        task_id: str,
        result: dict[str, Any],
    ) -> None:
        """Mark the task ``completed`` with ``result`` as its terminal artifact.

        Idempotent when repeated with an equal ``result``.

        Raises
        ------
        ValueError
            If the task is unknown, already failed, or already completed
            with a different result.
        """
        await self._finish(
            task_id=task_id,
            payload=result,
            update_sql=self._sql_complete,
            lookup_sql=self._sql_get_state_result,
            terminal_state="completed",
            payload_name="result",
        )

    async def fail(
        self,
        task_id: str,
        error: dict[str, Any],
    ) -> None:
        """Mark the task ``failed`` with ``error`` as its terminal payload.

        Idempotent when repeated with an equal ``error``.

        Raises
        ------
        ValueError
            If the task is unknown, already completed, or already failed
            with a different error.
        """
        await self._finish(
            task_id=task_id,
            payload=error,
            update_sql=self._sql_fail,
            lookup_sql=self._sql_get_state_error,
            terminal_state="failed",
            payload_name="error",
        )

    async def get(
        self,
        task_id: str,
        *,
        expected_account_id: str | None = None,
    ) -> dict[str, Any] | None:
        """Return the task record as a dict, or ``None`` if there is no match.

        When ``expected_account_id`` is given it is part of the SQL
        predicate, so a row owned by another account is never fetched;
        ``None`` disables the account check.
        """
        async with self._pool.connection() as conn:
            cur = await conn.execute(
                self._sql_get, (task_id, expected_account_id, expected_account_id)
            )
            row = await cur.fetchone()
        if row is None:
            return None
        # Keys follow the column order of the SELECT in _sql_get.
        columns = (
            "task_id",
            "account_id",
            "state",
            "task_type",
            "progress",
            "result",
            "error",
            "created_at",
            "updated_at",
        )
        return dict(zip(columns, row))

    async def discard(self, task_id: str) -> None:
        """Delete the task row — the rollback path.

        Discarding an unknown ``task_id`` deletes zero rows and does not
        raise.
        """
        async with self._pool.connection() as conn:
            await conn.execute(self._sql_discard, (task_id,))


__all__ = ["PG_AVAILABLE", "PostgresTaskRegistry"]
"""Conformance tests for :class:`adcp.decisioning.pg.PostgresTaskRegistry`.

Requires a real PostgreSQL instance. To run locally::

    docker run --rm -d -p 5432:5432 -e POSTGRES_PASSWORD=pg postgres:16
    export ADCP_PG_TEST_URL=postgresql://postgres:pg@localhost:5432/postgres
    pytest tests/conformance/decisioning/test_pg_task_registry.py -v

The entire module skips when ``ADCP_PG_TEST_URL`` is unset, so the
default test matrix stays green without a database dependency.

Each test runs against a freshly created table named ``test_dtasks_`` plus a
random hex suffix, so parallel runs and crash-then-retry scenarios don't
collide.

These tests mirror the behavioral guarantees of
``tests/test_decisioning_task_registry.py`` (InMemoryTaskRegistry) and
``tests/test_decisioning_task_registry_cross_tenant.py`` (security) against
a real Postgres engine to catch SQL-level divergence.
"""

from __future__ import annotations

import asyncio
import os
import secrets
from collections.abc import AsyncIterator

import pytest

# Skip the whole module when the optional driver packages are missing.
psycopg = pytest.importorskip("psycopg")
psycopg_pool = pytest.importorskip("psycopg_pool")

# Skip the whole module when no database URL has been provided.
TEST_URL = os.environ.get("ADCP_PG_TEST_URL")
if not TEST_URL:
    pytest.skip(
        "ADCP_PG_TEST_URL not set — skipping PostgresTaskRegistry conformance tests",
        allow_module_level=True,
    )

# Imported only after the skip guards above, hence the E402 suppression.
from adcp.decisioning.pg import PostgresTaskRegistry  # noqa: E402

# -- fixtures ---------------------------------------------------------------


@pytest.fixture()
async def registry() -> AsyncIterator[PostgresTaskRegistry]:
    """Yield a registry bound to its own pool and uniquely named table.

    Uses the private ``_table`` parameter so every fixture invocation
    creates its own table and drops it again on teardown, including when the
    test body raises.
    """
    # 12 random hex characters keep the name unique and within the
    # registry's identifier whitelist.
    table = f"test_dtasks_{secrets.token_hex(6)}"
    async with psycopg_pool.AsyncConnectionPool(
        TEST_URL,
        min_size=2,
        max_size=8,
        open=False,
    ) as pool:
        # NOTE(review): entering the pool context presumably opens the pool
        # already, which would make this explicit open() redundant — confirm
        # against the psycopg_pool documentation before removing.
        await pool.open()
        reg = PostgresTaskRegistry(pool=pool, _table=table)
        await reg.create_schema()
        try:
            yield reg
        finally:
            async with pool.connection() as conn:
                # Table name is generated above, never user-supplied.
                await conn.execute(f"DROP TABLE IF EXISTS {table}")  # noqa: S608


# -- Protocol happy-path ---------------------------------------------------


@pytest.mark.asyncio
async def test_issue_returns_unique_task_ids(registry: PostgresTaskRegistry) -> None:
    """Five sequential issue() calls return five distinct ids."""
    ids = [await registry.issue(account_id="acct1", task_type="create_media_buy")
           for _ in range(5)]
    assert len(set(ids)) == 5


@pytest.mark.asyncio
async def test_issue_then_get_returns_submitted(registry: PostgresTaskRegistry) -> None:
    """A freshly issued task reads back as ``submitted`` with its metadata."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "submitted"
    assert record["task_type"] == "create_media_buy"
    assert record["account_id"] == "acct1"


@pytest.mark.asyncio
async def test_update_progress_transitions_submitted_to_working(
    registry: PostgresTaskRegistry,
) -> None:
    """The first progress write moves the task to ``working`` and stores the payload."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    await registry.update_progress(task_id, {"message": "Reviewing"})
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "working"
    assert record["progress"] == {"message": "Reviewing"}


@pytest.mark.asyncio
async def test_update_progress_noop_on_unknown_task(registry: PostgresTaskRegistry) -> None:
    """update_progress() on an unknown id completes without raising."""
    # Must not raise; the dispatch wrapper relies on this being silent.
    await registry.update_progress("task_unknown", {"x": 1})


@pytest.mark.asyncio
async def test_complete_transitions_to_completed(registry: PostgresTaskRegistry) -> None:
    """complete() stores the result and sets state to ``completed``."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    result = {"media_buy_id": "mb_1", "status": "active"}
    await registry.complete(task_id, result)
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "completed"
    assert record["result"] == result


@pytest.mark.asyncio
async def test_complete_idempotent_on_equal_result(registry: PostgresTaskRegistry) -> None:
    """Repeating complete() with an equal result is accepted."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    result = {"media_buy_id": "mb_1"}
    await registry.complete(task_id, result)
    await registry.complete(task_id, result)  # must not raise


@pytest.mark.asyncio
async def test_complete_raises_on_different_result(registry: PostgresTaskRegistry) -> None:
    """Re-completing with a different result raises ValueError."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    await registry.complete(task_id, {"media_buy_id": "mb_1"})
    with pytest.raises(ValueError, match="different result"):
        await registry.complete(task_id, {"media_buy_id": "mb_2"})


@pytest.mark.asyncio
async def test_fail_transitions_to_failed(registry: PostgresTaskRegistry) -> None:
    """fail() stores the error and sets state to ``failed``."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    error = {"code": "BUDGET_TOO_LOW", "message": "Budget below minimum"}
    await registry.fail(task_id, error)
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "failed"
    assert record["error"] == error


@pytest.mark.asyncio
async def test_fail_idempotent_on_equal_error(registry: PostgresTaskRegistry) -> None:
    """Repeating fail() with an equal error is accepted."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    error = {"code": "BUDGET_TOO_LOW"}
    await registry.fail(task_id, error)
    await registry.fail(task_id, error)  # must not raise


@pytest.mark.asyncio
async def test_fail_raises_on_different_error(registry: PostgresTaskRegistry) -> None:
    """Re-failing with a different error raises ValueError."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    await registry.fail(task_id, {"code": "BUDGET_TOO_LOW"})
    with pytest.raises(ValueError, match="different error"):
        await registry.fail(task_id, {"code": "RATE_LIMITED"})


@pytest.mark.asyncio
async def test_update_progress_noop_on_completed_task(registry: PostgresTaskRegistry) -> None:
    """A progress write after completion leaves the task ``completed``."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    await registry.complete(task_id, {"media_buy_id": "mb_1"})
    await registry.update_progress(task_id, {"message": "late straggler"})
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "completed"  # must not revert to working


@pytest.mark.asyncio
async def test_discard_removes_submitted_task(registry: PostgresTaskRegistry) -> None:
    """After discard(), get() no longer finds the task."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    await registry.discard(task_id)
    assert await registry.get(task_id) is None


@pytest.mark.asyncio
async def test_discard_unknown_task_is_noop(registry: PostgresTaskRegistry) -> None:
    """discard() on an unknown id completes without raising."""
    await registry.discard("task_does_not_exist")  # must not raise


@pytest.mark.asyncio
async def test_complete_unknown_task_raises(registry: PostgresTaskRegistry) -> None:
    """complete() on an unknown task_id must raise ValueError (matches InMemory)."""
    with pytest.raises(ValueError, match="not found"):
        await registry.complete("task_unknown", {"media_buy_id": "mb_1"})


@pytest.mark.asyncio
async def test_fail_unknown_task_raises(registry: PostgresTaskRegistry) -> None:
    """fail() on an unknown task_id must raise ValueError (matches InMemory)."""
    with pytest.raises(ValueError, match="not found"):
        await registry.fail("task_unknown", {"code": "RATE_LIMITED"})


# -- Cross-tenant security -------------------------------------------------


@pytest.mark.asyncio
async def test_get_cross_tenant_probe_returns_none(registry: PostgresTaskRegistry) -> None:
    """get(expected_account_id=wrong) must return None — SQL-level enforcement."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    result = await registry.get(task_id, expected_account_id="acct2")
    assert result is None


@pytest.mark.asyncio
async def test_get_no_account_filter_returns_record(registry: PostgresTaskRegistry) -> None:
    """Passing expected_account_id=None disables the account check."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    record = await registry.get(task_id, expected_account_id=None)
    assert record is not None
    assert record["account_id"] == "acct1"


@pytest.mark.asyncio
async def test_get_correct_account_returns_record(registry: PostgresTaskRegistry) -> None:
    """A matching expected_account_id returns the record."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    record = await registry.get(task_id, expected_account_id="acct1")
    assert record is not None


@pytest.mark.asyncio
async def test_get_unknown_task_returns_none(registry: PostgresTaskRegistry) -> None:
    """get() on an unknown id returns None."""
    assert await registry.get("task_unknown") is None


# -- Concurrency -----------------------------------------------------------


@pytest.mark.asyncio
async def test_concurrent_issue_yields_unique_ids(registry: PostgresTaskRegistry) -> None:
    """Twenty issue() calls gathered concurrently return twenty distinct ids."""
    ids = await asyncio.gather(
        *[registry.issue(account_id="acct1", task_type="create_media_buy")
          for _ in range(20)]
    )
    assert len(set(ids)) == 20


@pytest.mark.asyncio
async def test_concurrent_complete_idempotent(registry: PostgresTaskRegistry) -> None:
    """Two workers racing complete() with the same result must not error."""
    task_id = await registry.issue(account_id="acct1", task_type="create_media_buy")
    result = {"media_buy_id": "mb_1"}
    await asyncio.gather(
        registry.complete(task_id, result),
        registry.complete(task_id, result),
    )
    record = await registry.get(task_id)
    assert record is not None
    assert record["state"] == "completed"


# -- Schema helpers --------------------------------------------------------


@pytest.mark.asyncio
async def test_create_schema_idempotent(registry: PostgresTaskRegistry) -> None:
    """create_schema() called twice must not error."""
    await registry.create_schema()  # second call; fixture already called it once
"""Unit tests for adcp.decisioning.pg.PostgresTaskRegistry.

These tests run without a PostgreSQL instance — they cover:

* Stub import path (``from adcp.decisioning import PostgresTaskRegistry``)
* ``is_durable = True`` marker (class-level, not instance-level)
* ``ImportError`` raised on instantiation when ``[pg]`` extra is absent
* account_id validation in ``issue()`` (same guard as InMemoryTaskRegistry)
* Protocol structural matching when psycopg IS installed

Real-database behavioral tests live in
``tests/conformance/decisioning/test_pg_task_registry.py`` — they skip
unless ``ADCP_PG_TEST_URL`` is set.
"""

from __future__ import annotations

import pytest

# -- Stub always importable -----------------------------------------------


def test_postgres_task_registry_importable_from_decisioning() -> None:
    """PostgresTaskRegistry is always importable from adcp.decisioning,
    even without the [pg] extra (stub class replaces the real one)."""
    from adcp.decisioning import PostgresTaskRegistry  # noqa: F401

    assert PostgresTaskRegistry is not None


def test_postgres_task_registry_is_durable_stub() -> None:
    """The stub class advertises is_durable=True so type-checking passes."""
    from adcp.decisioning import PostgresTaskRegistry

    # Holds for whichever class was exported (real or stub); both define it.
    assert getattr(PostgresTaskRegistry, "is_durable", None) is True


def test_postgres_task_registry_stub_raises_import_error_without_pg() -> None:
    """When psycopg_pool is not installed, instantiation raises ImportError."""
    import importlib.util

    # Only meaningful when the optional dependency is really absent.
    if importlib.util.find_spec("psycopg_pool") is not None:
        pytest.skip("psycopg_pool is installed — stub not in effect")

    from adcp.decisioning import PostgresTaskRegistry

    # The pattern matches the literal "adcp[pg]" in the install hint.
    with pytest.raises(ImportError, match="adcp\\[pg\\]"):
        PostgresTaskRegistry(pool=None)  # type: ignore[arg-type]


# -- Tests requiring psycopg (structural, no real DB) ---------------------


# Everything below this point is skipped when psycopg_pool is unavailable.
psycopg_pool = pytest.importorskip(
    "psycopg_pool",
    reason="psycopg_pool not installed — skipping structural pg tests",
)


def test_postgres_task_registry_satisfies_protocol() -> None:
    """PostgresTaskRegistry structurally matches the TaskRegistry Protocol
    when the [pg] extra is installed."""
    from unittest.mock import MagicMock

    from adcp.decisioning import PostgresTaskRegistry
    from adcp.decisioning.task_registry import TaskRegistry

    # The constructor only stores the pool, so a MagicMock is sufficient.
    mock_pool = MagicMock()
    registry = PostgresTaskRegistry(pool=mock_pool)
    assert isinstance(registry, TaskRegistry)


def test_postgres_task_registry_is_durable_class_var() -> None:
    """is_durable must be a class-level bool, not an instance attribute.

    serve.py checks ``type(registry).is_durable`` (via hasattr(type(...)))
    so an instance-level attribute would pass the hasattr check but fail
    mypy's ClassVar constraint and Protocol matching.
    """
    from adcp.decisioning.pg import PostgresTaskRegistry

    assert PostgresTaskRegistry.is_durable is True
    # Verify it's on the class, not only on instances.
    assert "is_durable" in PostgresTaskRegistry.__dict__


@pytest.mark.asyncio
async def test_issue_rejects_empty_account_id() -> None:
    """issue() must reject blank / sentinel account_ids (cross-tenant guard)."""
    from unittest.mock import MagicMock

    from adcp.decisioning.pg import PostgresTaskRegistry

    # Validation runs before any pool access, so the mock pool is never used.
    registry = PostgresTaskRegistry(pool=MagicMock())

    with pytest.raises(ValueError, match="account_id must be a non-empty"):
        await registry.issue(account_id="", task_type="create_media_buy")

    with pytest.raises(ValueError, match="account_id must be a non-empty"):
        await registry.issue(account_id="   ", task_type="create_media_buy")

    # NOTE(review): this repeats the first case exactly. The docstring
    # mentions a "sentinel" account_id, so a distinct sentinel value was
    # presumably intended here — confirm and restore it.
    with pytest.raises(ValueError, match="account_id must be a non-empty"):
        await registry.issue(account_id="", task_type="create_media_buy")