Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion src/adcp/decisioning/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,50 @@ def _walk_ctx_metadata_list(items: list[Any]) -> None:
raise ValueError(f"[{index}]: {exc}") from None


def _extract_request_context(params: Any) -> dict[str, Any] | None:
"""Pull the buyer-supplied ``context`` extension off the original
request for ``TaskRecord.request_context``.

The framework hands platform methods a typed Pydantic model in
production; test fixtures occasionally pass a raw dict.
``model_dump`` failures (rare — Pydantic models with
non-serializable ``extra='allow'`` fields) log + return ``None``
so downstream tasks/get reads simply omit ``context`` rather than
surfacing a partial / corrupted value. Buyers polling tasks/get
won't see a context echo on those (rare) requests, but their
failure-mode is "missed correlation," not "corrupted wire shape".

Returns ``None`` when the request had no context field, the
coercion failed, or ``params`` itself was ``None`` (test fixtures
invoking ``_project_handoff`` directly without going through
``_invoke_platform_method``). The 64KB amplification cap from
:func:`adcp.server.helpers.inject_context` does NOT apply at this
layer — the registry is server-internal storage, not a wire echo
surface; size-bounded enforcement on tasks/get reads should live
on the projection layer if required.
"""
if params is None:
return None
if isinstance(params, dict):
ctx = params.get("context")
return dict(ctx) if isinstance(ctx, dict) else None
if hasattr(params, "model_dump") and callable(params.model_dump):
try:
dumped = params.model_dump(mode="json", exclude_none=False)
except Exception:
logger.warning(
"request_params model_dump failed for %s; tasks/get context "
"echo skipped (correlation IDs lost). Verify the request "
"model serializes cleanly via model_dump.",
type(params).__name__,
exc_info=True,
)
return None
ctx = dumped.get("context") if isinstance(dumped, dict) else None
return dict(ctx) if isinstance(ctx, dict) else None
return None


def _internal_error_message(method_name: str, exc: BaseException) -> str:
"""Build the wire-side ``message`` for an INTERNAL_ERROR wrap.

Expand Down Expand Up @@ -1227,6 +1271,7 @@ async def _invoke_platform_method(
executor=executor,
on_complete=on_complete,
on_failure=on_failure,
request_params=params,
)
if is_workflow_handoff(result):
return await _project_workflow_handoff(
Expand All @@ -1235,6 +1280,7 @@ async def _invoke_platform_method(
method_name=method_name,
registry=registry,
executor=executor,
request_params=params,
)

# Sync return path. Fire on_complete with the typed result before
Expand Down Expand Up @@ -1291,6 +1337,7 @@ async def _project_handoff(
executor: ThreadPoolExecutor,
on_complete: Callable[[Any], Awaitable[None]] | None = None,
on_failure: Callable[[BaseException], Awaitable[None]] | None = None,
request_params: BaseModel | None = None,
) -> dict[str, Any]:
"""Promote a TaskHandoff to a background task.

Expand Down Expand Up @@ -1343,16 +1390,34 @@ async def _project_handoff(
needs the failure visible via ``tasks/get`` regardless of
hook outcomes.

:param request_params: The original request Pydantic model that
triggered the task. Used to echo the request's ``context``
extension into the registry-stored wire envelope on both
success (``registry.complete``) and failure
(``registry.fail``) paths — closes #563. Mirrors the sync
AdcpError path's context-passthrough (PR #560). When ``None``,
no echo happens (e.g. test fixtures invoking the handoff
helper directly).

The handoff fn is extracted via the type-identity dispatch in
:func:`adcp.decisioning.types.is_task_handoff`. Subclassed
TaskHandoff instances (deliberate non-feature) silently take the
sync-return path before reaching this function.
"""
fn = handoff._fn

# Extract the buyer's ``context`` extension from the original
# request and lock it onto the TaskRecord at issue-time. The
# registry surfaces it at the top level of ``tasks/get`` reads
# (sibling of ``result`` / ``error`` per
# ``schemas/cache/core/tasks_get_response.json``). Capturing once
# at issue-time means the terminal-state helpers (_fail,
# registry.complete) never need to know about request-side
# context — keeps the wire-shape boundary in one place.
task_id = await registry.issue(
account_id=ctx.account.id,
task_type=method_name,
request_context=_extract_request_context(request_params),
)

# Hand off to background. The wire envelope returns immediately;
Expand All @@ -1364,7 +1429,17 @@ async def _fail(exc: AdcpError) -> None:
"""Run the framework's on_failure hook (if set) then
``registry.fail``. Hook errors are logged but never block the
registry.fail — the buyer needs the failure visible via
tasks/get regardless of hook outcomes."""
tasks/get regardless of hook outcomes.

Note: the request's ``context`` extension lands on the
``tasks/get`` response at the top level (sibling of
``error``), not inside the ``adcp_error`` envelope — see
:meth:`TaskRecord.to_dict` and the
``schemas/cache/core/tasks_get_response.json``
``TasksGetResponse.context`` field. The context is captured
once at ``registry.issue()`` time below; ``_fail`` doesn't
touch it.
"""
if on_failure is not None:
try:
await on_failure(exc)
Expand Down Expand Up @@ -1460,6 +1535,12 @@ async def _run() -> None:
# the typed Pydantic response.
persisted = {"value": str(result)}
persisted = strip_credentials_from_wire_result(method_name, persisted)
# ``request.context`` echo lands at the top level of the
# ``tasks/get`` response (sibling of ``result``), not inside
# the typed result body. ``TaskRecord.request_context`` was
# captured at ``registry.issue()`` time and ``to_dict()``
# surfaces it under the top-level ``context`` key; nothing to
# do here on the result path.
await registry.complete(task_id, persisted)

# ``asyncio.create_task`` only weak-refs the resulting Task — under
Expand Down Expand Up @@ -1507,6 +1588,7 @@ async def _project_workflow_handoff(
method_name: str,
registry: TaskRegistry,
executor: ThreadPoolExecutor,
request_params: BaseModel | Any | None = None,
) -> dict[str, Any]:
"""Project a :class:`WorkflowHandoff` to the wire Submitted envelope.

Expand Down Expand Up @@ -1538,9 +1620,16 @@ async def _project_workflow_handoff(
"""
fn = handoff._fn

# Same context-echo capture as :func:`_project_handoff`: the
# request's ``context`` extension lives on the TaskRecord and
# surfaces at the top level of ``tasks/get`` reads (#563). The
# WorkflowHandoff path persists the task and immediately returns
# — the adopter's enqueue fn does not write to the registry — so
# context capture must happen here at issue-time too.
task_id = await registry.issue(
account_id=ctx.account.id,
task_type=method_name,
request_context=_extract_request_context(request_params),
)
handoff_ctx = TaskHandoffContext(id=task_id, _registry=registry)

Expand Down
53 changes: 52 additions & 1 deletion src/adcp/decisioning/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class TaskRecord:
progress: dict[str, Any] | None = None
result: dict[str, Any] | None = None
error: dict[str, Any] | None = None
request_context: dict[str, Any] | None = None
"""Buyer-supplied ``context`` extension from the request that
issued this task. Echoed to ``tasks/get`` responses at the
top-level ``context`` field per
``schemas/cache/core/tasks_get_response.json`` (sibling of
``result`` / ``error``). Captured at ``issue()`` time and
immutable afterwards — terminal-state transitions
(``complete`` / ``fail``) MUST NOT touch this field. Closes #563.
"""
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)

Expand All @@ -109,8 +118,12 @@ def to_dict(self) -> dict[str, Any]:
Adopters or middleware reading the dict shape get the exact
wire-relevant fields. ``created_at`` / ``updated_at`` are
included so admin tooling can build SLA reports.

``context`` lands at the top level — sibling of ``result``
and ``error`` — matching the spec's
``TasksGetResponse.context`` placement (#563).
"""
return {
out: dict[str, Any] = {
"task_id": self.task_id,
"account_id": self.account_id,
"state": self.state,
Expand All @@ -121,6 +134,9 @@ def to_dict(self) -> dict[str, Any]:
"created_at": self.created_at,
"updated_at": self.updated_at,
}
if self.request_context is not None:
out["context"] = self.request_context
return out


@runtime_checkable
Expand Down Expand Up @@ -187,6 +203,8 @@ async def issue(
*,
account_id: str,
task_type: str,
request_context: dict[str, Any] | None = None,
**_extra: Any,
) -> str:
"""Allocate a fresh task_id, persist a ``submitted`` row, and
return the id.
Expand All @@ -197,6 +215,27 @@ async def issue(
etc.). Persisted on the row and surfaced on ``tasks/get``
reads; NOT included in the synchronous Submitted envelope
(per ``schemas/cache/core/protocol-envelope.json``).
:param request_context: Buyer-supplied ``context`` extension
from the request that issued this task. Persisted on the
row and surfaced at the top level of ``tasks/get``
responses (sibling of ``result`` / ``error``) so buyers
can correlate polled task state with the kick-off
request. ``None`` when the request carried no context
field; the framework supplies it from the original
request params. Adopters writing custom registries SHOULD
store and surface this field; older registry impls that
ignore it are functionally compatible (no echo on
``tasks/get`` reads, identical to pre-#563 behavior).
:param _extra: Forward-compat slot for kwargs added by future
framework versions. Custom registry impls MUST include
``**_extra: Any`` on their ``issue()`` signature so the
framework can introduce new optional kwargs without
breaking adopters who haven't yet adopted the new field.
Implementations that don't recognize an extra kwarg
should silently ignore it (the framework only relies on
kwargs the Protocol explicitly declares). Logging the
unrecognized keys at DEBUG level is encouraged so
adopters notice when they've fallen behind.
:returns: The framework-allocated task_id (string UUID).
"""
...
Expand Down Expand Up @@ -327,7 +366,18 @@ async def issue(
*,
account_id: str,
task_type: str,
request_context: dict[str, Any] | None = None,
**_extra: Any,
) -> str:
# Forward-compat: log unrecognized kwargs at DEBUG so adopters
# who haven't yet upgraded notice when they're missing new
# framework fields. Don't raise — that would break adopters
# the moment a new version ships.
if _extra:
logger.debug(
"InMemoryTaskRegistry.issue ignoring unrecognized kwargs: %s",
list(_extra.keys()),
)
# Reject empty/unset account_id at issue-time. Without this,
# two tenants whose AccountStore returns Account(id="") or the
# default Account(id="<unset>") share a cache scope class and
Expand All @@ -349,6 +399,7 @@ async def issue(
account_id=account_id,
state="submitted",
task_type=task_type,
request_context=(dict(request_context) if request_context is not None else None),
)
return task_id

Expand Down
Loading
Loading