diff --git a/src/adcp/decisioning/dispatch.py b/src/adcp/decisioning/dispatch.py index 47507cfe..70afbe16 100644 --- a/src/adcp/decisioning/dispatch.py +++ b/src/adcp/decisioning/dispatch.py @@ -504,6 +504,50 @@ def _walk_ctx_metadata_list(items: list[Any]) -> None: raise ValueError(f"[{index}]: {exc}") from None +def _extract_request_context(params: Any) -> dict[str, Any] | None: + """Pull the buyer-supplied ``context`` extension off the original + request for ``TaskRecord.request_context``. + + The framework hands platform methods a typed Pydantic model in + production; test fixtures occasionally pass a raw dict. + ``model_dump`` failures (rare — Pydantic models with + non-serializable ``extra='allow'`` fields) log + return ``None`` + so downstream tasks/get reads simply omit ``context`` rather than + surfacing a partial / corrupted value. Buyers polling tasks/get + won't see a context echo on those (rare) requests, but their + failure-mode is "missed correlation," not "corrupted wire shape". + + Returns ``None`` when the request had no context field, the + coercion failed, or ``params`` itself was ``None`` (test fixtures + invoking ``_project_handoff`` directly without going through + ``_invoke_platform_method``). The 64KB amplification cap from + :func:`adcp.server.helpers.inject_context` does NOT apply at this + layer — the registry is server-internal storage, not a wire echo + surface; size-bounded enforcement on tasks/get reads should live + on the projection layer if required. + """ + if params is None: + return None + if isinstance(params, dict): + ctx = params.get("context") + return dict(ctx) if isinstance(ctx, dict) else None + if hasattr(params, "model_dump") and callable(params.model_dump): + try: + dumped = params.model_dump(mode="json", exclude_none=False) + except Exception: + logger.warning( + "request_params model_dump failed for %s; tasks/get context " + "echo skipped (correlation IDs lost). Verify the request " + "model serializes cleanly via model_dump.", + type(params).__name__, + exc_info=True, + ) + return None + ctx = dumped.get("context") if isinstance(dumped, dict) else None + return dict(ctx) if isinstance(ctx, dict) else None + return None + + def _internal_error_message(method_name: str, exc: BaseException) -> str: """Build the wire-side ``message`` for an INTERNAL_ERROR wrap. @@ -1227,6 +1271,7 @@ async def _invoke_platform_method( executor=executor, on_complete=on_complete, on_failure=on_failure, + request_params=params, ) if is_workflow_handoff(result): return await _project_workflow_handoff( @@ -1235,6 +1280,7 @@ async def _invoke_platform_method( method_name=method_name, registry=registry, executor=executor, + request_params=params, ) # Sync return path. Fire on_complete with the typed result before @@ -1291,6 +1337,7 @@ async def _project_handoff( executor: ThreadPoolExecutor, on_complete: Callable[[Any], Awaitable[None]] | None = None, on_failure: Callable[[BaseException], Awaitable[None]] | None = None, + request_params: BaseModel | None = None, ) -> dict[str, Any]: """Promote a TaskHandoff to a background task. @@ -1343,6 +1390,15 @@ async def _project_handoff( needs the failure visible via ``tasks/get`` regardless of hook outcomes. + :param request_params: The original request Pydantic model that + triggered the task. Used to echo the request's ``context`` + extension into the registry-stored wire envelope on both + success (``registry.complete``) and failure + (``registry.fail``) paths — closes #563. Mirrors the sync + AdcpError path's context-passthrough (PR #560). When ``None``, + no echo happens (e.g. test fixtures invoking the handoff + helper directly). + The handoff fn is extracted via the type-identity dispatch in :func:`adcp.decisioning.types.is_task_handoff`. Subclassed TaskHandoff instances (deliberate non-feature) silently take the @@ -1350,9 +1406,18 @@ async def _project_handoff( """ fn = handoff._fn + # Extract the buyer's ``context`` extension from the original + # request and lock it onto the TaskRecord at issue-time. The + # registry surfaces it at the top level of ``tasks/get`` reads + # (sibling of ``result`` / ``error`` per + # ``schemas/cache/core/tasks_get_response.json``). Capturing once + # at issue-time means the terminal-state helpers (_fail, + # registry.complete) never need to know about request-side + # context — keeps the wire-shape boundary in one place. task_id = await registry.issue( account_id=ctx.account.id, task_type=method_name, + request_context=_extract_request_context(request_params), ) # Hand off to background. The wire envelope returns immediately; @@ -1364,7 +1429,17 @@ async def _fail(exc: AdcpError) -> None: """Run the framework's on_failure hook (if set) then ``registry.fail``. Hook errors are logged but never block the registry.fail — the buyer needs the failure visible via - tasks/get regardless of hook outcomes.""" + tasks/get regardless of hook outcomes. + + Note: the request's ``context`` extension lands on the + ``tasks/get`` response at the top level (sibling of + ``error``), not inside the ``adcp_error`` envelope — see + :meth:`TaskRecord.to_dict` and the + ``schemas/cache/core/tasks_get_response.json`` + ``TasksGetResponse.context`` field. The context is captured + once at ``registry.issue()`` time below; ``_fail`` doesn't + touch it. + """ if on_failure is not None: try: await on_failure(exc) @@ -1460,6 +1535,12 @@ async def _run() -> None: # the typed Pydantic response. persisted = {"value": str(result)} persisted = strip_credentials_from_wire_result(method_name, persisted) + # ``request.context`` echo lands at the top level of the + # ``tasks/get`` response (sibling of ``result``), not inside + # the typed result body. ``TaskRecord.request_context`` was + # captured at ``registry.issue()`` time and ``to_dict()`` + # surfaces it under the top-level ``context`` key; nothing to + # do here on the result path. await registry.complete(task_id, persisted) # ``asyncio.create_task`` only weak-refs the resulting Task — under @@ -1507,6 +1588,7 @@ async def _project_workflow_handoff( method_name: str, registry: TaskRegistry, executor: ThreadPoolExecutor, + request_params: BaseModel | Any | None = None, ) -> dict[str, Any]: """Project a :class:`WorkflowHandoff` to the wire Submitted envelope. @@ -1538,9 +1620,16 @@ async def _project_workflow_handoff( """ fn = handoff._fn + # Same context-echo capture as :func:`_project_handoff`: the + # request's ``context`` extension lives on the TaskRecord and + # surfaces at the top level of ``tasks/get`` reads (#563). The + # WorkflowHandoff path persists the task and immediately returns + # — the adopter's enqueue fn does not write to the registry — so + # context capture must happen here at issue-time too. task_id = await registry.issue( account_id=ctx.account.id, task_type=method_name, + request_context=_extract_request_context(request_params), ) handoff_ctx = TaskHandoffContext(id=task_id, _registry=registry) diff --git a/src/adcp/decisioning/task_registry.py b/src/adcp/decisioning/task_registry.py index 6241423c..9d3f1af0 100644 --- a/src/adcp/decisioning/task_registry.py +++ b/src/adcp/decisioning/task_registry.py @@ -100,6 +100,15 @@ class TaskRecord: progress: dict[str, Any] | None = None result: dict[str, Any] | None = None error: dict[str, Any] | None = None + request_context: dict[str, Any] | None = None + """Buyer-supplied ``context`` extension from the request that + issued this task. Echoed to ``tasks/get`` responses at the + top-level ``context`` field per + ``schemas/cache/core/tasks_get_response.json`` (sibling of + ``result`` / ``error``). Captured at ``issue()`` time and + immutable afterwards — terminal-state transitions + (``complete`` / ``fail``) MUST NOT touch this field. Closes #563. + """ created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) @@ -109,8 +118,12 @@ def to_dict(self) -> dict[str, Any]: Adopters or middleware reading the dict shape get the exact wire-relevant fields. ``created_at`` / ``updated_at`` are included so admin tooling can build SLA reports. + + ``context`` lands at the top level — sibling of ``result`` + and ``error`` — matching the spec's + ``TasksGetResponse.context`` placement (#563). """ - return { + out: dict[str, Any] = { "task_id": self.task_id, "account_id": self.account_id, "state": self.state, @@ -121,6 +134,9 @@ def to_dict(self) -> dict[str, Any]: "created_at": self.created_at, "updated_at": self.updated_at, } + if self.request_context is not None: + out["context"] = self.request_context + return out @runtime_checkable @@ -187,6 +203,8 @@ async def issue( *, account_id: str, task_type: str, + request_context: dict[str, Any] | None = None, + **_extra: Any, ) -> str: """Allocate a fresh task_id, persist a ``submitted`` row, and return the id. @@ -197,6 +215,27 @@ async def issue( etc.). Persisted on the row and surfaced on ``tasks/get`` reads; NOT included in the synchronous Submitted envelope (per ``schemas/cache/core/protocol-envelope.json``). + :param request_context: Buyer-supplied ``context`` extension + from the request that issued this task. Persisted on the + row and surfaced at the top level of ``tasks/get`` + responses (sibling of ``result`` / ``error``) so buyers + can correlate polled task state with the kick-off + request. ``None`` when the request carried no context + field; the framework supplies it from the original + request params. Adopters writing custom registries SHOULD + store and surface this field; older registry impls that + ignore it are functionally compatible (no echo on + ``tasks/get`` reads, identical to pre-#563 behavior). + :param _extra: Forward-compat slot for kwargs added by future + framework versions. Custom registry impls MUST include + ``**_extra: Any`` on their ``issue()`` signature so the + framework can introduce new optional kwargs without + breaking adopters who haven't yet adopted the new field. + Implementations that don't recognize an extra kwarg + should silently ignore it (the framework only relies on + kwargs the Protocol explicitly declares). Logging the + unrecognized keys at DEBUG level is encouraged so + adopters notice when they've fallen behind. :returns: The framework-allocated task_id (string UUID). """ ... @@ -327,7 +366,18 @@ async def issue( *, account_id: str, task_type: str, + request_context: dict[str, Any] | None = None, + **_extra: Any, ) -> str: + # Forward-compat: log unrecognized kwargs at DEBUG so adopters + # who haven't yet upgraded notice when they're missing new + # framework fields. Don't raise — that would break adopters + # the moment a new version ships. + if _extra: + logger.debug( + "InMemoryTaskRegistry.issue ignoring unrecognized kwargs: %s", + list(_extra.keys()), + ) # Reject empty/unset account_id at issue-time. Without this, # two tenants whose AccountStore returns Account(id="") or the # default Account(id="") share a cache scope class and @@ -349,6 +399,7 @@ async def issue( account_id=account_id, state="submitted", task_type=task_type, + request_context=(dict(request_context) if request_context is not None else None), ) return task_id diff --git a/tests/test_decisioning_dispatch.py b/tests/test_decisioning_dispatch.py index 64c594df..757e1932 100644 --- a/tests/test_decisioning_dispatch.py +++ b/tests/test_decisioning_dispatch.py @@ -957,6 +957,151 @@ async def _handoff_fn(task_ctx): assert "internal bug" not in rec["error"].get("message", "") +@pytest.mark.asyncio +async def test_handoff_request_context_echoes_into_completed_task( + executor: ThreadPoolExecutor, +) -> None: + """Issue #563: when ``request_params`` is supplied with a + ``context`` field, the registry-stored success envelope echoes + that context. Buyer polling ``tasks/get`` on the completed task + sees the same ``context`` they sent on the kick-off request — + symmetric with the sync path's :func:`inject_context` and PR + #560's AdcpError raise path.""" + from pydantic import BaseModel as _Req + + class _ReqWithContext(_Req): + idempotency_key: str + context: dict[str, Any] + + registry = InMemoryTaskRegistry() + ctx = _build_request_context(ToolContext(), Account(id="acct_a"), None) + req = _ReqWithContext(idempotency_key="key-1", context={"correlation_id": "buyer-563"}) + + async def _handoff_fn(task_ctx): + return {"media_buy_id": "mb_1"} + + envelope = await _project_handoff( + TaskHandoff(_handoff_fn), + ctx, + method_name="create_media_buy", + registry=registry, + executor=executor, + request_params=req, + ) + await asyncio.sleep(0.1) + rec = await registry.get(envelope["task_id"], expected_account_id="acct_a") + assert rec is not None + assert rec["state"] == "completed" + # ``context`` lands at the top level — sibling of ``result`` per + # tasks_get_response.json. NOT inside result. + assert rec.get("context") == {"correlation_id": "buyer-563"} + assert "context" not in rec["result"] + + +@pytest.mark.asyncio +async def test_handoff_request_context_echoes_into_failed_task( + executor: ThreadPoolExecutor, +) -> None: + """Same echo on the AdcpError-raised path: registry.fail's wire + envelope carries the request's ``context`` alongside the + ``adcp_error`` shape.""" + from pydantic import BaseModel as _Req + + class _ReqWithContext(_Req): + idempotency_key: str + context: dict[str, Any] + + registry = InMemoryTaskRegistry() + ctx = _build_request_context(ToolContext(), Account(id="acct_a"), None) + req = _ReqWithContext(idempotency_key="key-2", context={"correlation_id": "buyer-fail-563"}) + + async def _handoff_fn(task_ctx): + raise AdcpError("POLICY_VIOLATION", message="rejected", recovery="correctable") + + envelope = await _project_handoff( + TaskHandoff(_handoff_fn), + ctx, + method_name="create_media_buy", + registry=registry, + executor=executor, + request_params=req, + ) + await asyncio.sleep(0.1) + rec = await registry.get(envelope["task_id"], expected_account_id="acct_a") + assert rec is not None + assert rec["state"] == "failed" + assert rec["error"]["code"] == "POLICY_VIOLATION" + # ``context`` lands at the top level of the wire shape — sibling + # of ``error``, not inside it (per tasks_get_response.json). + assert rec.get("context") == {"correlation_id": "buyer-fail-563"} + assert "context" not in rec["error"] + + +@pytest.mark.asyncio +async def test_handoff_unexpected_exception_echoes_context_too( + executor: ThreadPoolExecutor, +) -> None: + """Non-AdcpError exception → wrapped INTERNAL_ERROR → still + echoes context. The wrap path was the salesagent gap (#562 + follow-up territory) and the same fix applies on the bg path.""" + from pydantic import BaseModel as _Req + + class _ReqWithContext(_Req): + idempotency_key: str + context: dict[str, Any] + + registry = InMemoryTaskRegistry() + ctx = _build_request_context(ToolContext(), Account(id="acct_a"), None) + req = _ReqWithContext(idempotency_key="key-3", context={"correlation_id": "buyer-internal-563"}) + + async def _handoff_fn(task_ctx): + raise RuntimeError("bug") + + envelope = await _project_handoff( + TaskHandoff(_handoff_fn), + ctx, + method_name="create_media_buy", + registry=registry, + executor=executor, + request_params=req, + ) + await asyncio.sleep(0.1) + rec = await registry.get(envelope["task_id"], expected_account_id="acct_a") + assert rec is not None + assert rec["error"]["code"] == "INTERNAL_ERROR" + # Top-level context echo, not nested inside error. + assert rec.get("context") == {"correlation_id": "buyer-internal-563"} + assert "context" not in rec["error"] + + +@pytest.mark.asyncio +async def test_handoff_no_request_params_no_context_synthesised( + executor: ThreadPoolExecutor, +) -> None: + """When ``request_params`` is None (test fixtures, custom dispatch), + no context echo happens — the registry stores the wire envelope + as-is.""" + registry = InMemoryTaskRegistry() + ctx = _build_request_context(ToolContext(), Account(id="acct_a"), None) + + async def _handoff_fn(task_ctx): + return {"media_buy_id": "mb_no_params"} + + envelope = await _project_handoff( + TaskHandoff(_handoff_fn), + ctx, + method_name="create_media_buy", + registry=registry, + executor=executor, + ) + await asyncio.sleep(0.1) + rec = await registry.get(envelope["task_id"], expected_account_id="acct_a") + assert rec is not None + # No request_params → no context echo at any level of the wire shape. + assert "context" not in rec + assert "context" not in rec["result"] + + @pytest.mark.asyncio async def test_handoff_sync_fn_runs_on_executor( executor: ThreadPoolExecutor,