Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/adcp/decisioning/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
is_task_handoff,
is_workflow_handoff,
)
from adcp.server.idempotency.store import is_wrapped as _is_idem_wrapped

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
Expand Down Expand Up @@ -1080,7 +1081,10 @@ async def _invoke_platform_method(
try:
if asyncio.iscoroutinefunction(method):
if arg_projector is not None:
result = await method(**arg_projector, ctx=ctx)
if _is_idem_wrapped(method):
result = await method(**arg_projector, ctx=ctx, __adcp_params__=params)
else:
result = await method(**arg_projector, ctx=ctx)
elif extra_kwargs:
result = await method(params, ctx, **extra_kwargs)
else:
Expand All @@ -1089,7 +1093,14 @@ async def _invoke_platform_method(
ctx_snapshot = contextvars.copy_context()
loop = asyncio.get_running_loop()
if arg_projector is not None:
projected_kwargs = {**arg_projector, "ctx": ctx}
# _is_idem_wrapped(method) is always False here in practice
# (wrap() returns an async def), but the sentinel is wired in
# for forward-compatibility with a future sync-capable wrapper.
projected_kwargs = (
{**arg_projector, "ctx": ctx, "__adcp_params__": params}
if _is_idem_wrapped(method)
else {**arg_projector, "ctx": ctx}
)
result = await loop.run_in_executor(
executor,
functools.partial(ctx_snapshot.run, method, **projected_kwargs),
Expand Down
53 changes: 50 additions & 3 deletions src/adcp/server/idempotency/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
# only granted by IdempotencyStore.wrap itself.
_WRAPPED_FUNCTIONS: weakref.WeakSet[Callable[..., Any]] = weakref.WeakSet()

# Sentinel for kwarg presence detection — avoids truthiness bugs when the
# context object is falsy (e.g. a ToolContext subclass with __bool__ = False).
_MISSING: object = object()


def is_wrapped(fn: Any) -> bool:
"""Return True if ``fn`` was produced by :meth:`IdempotencyStore.wrap`.
Expand Down Expand Up @@ -145,17 +149,57 @@ def wrap(self, handler: HandlerFn) -> HandlerFn:
@wraps(handler)
async def _wrapped(
handler_self: Any,
params: Any,
params: Any = None,
context: Any = None,
*args: Any,
**kwargs: Any,
) -> Any:
scope_key, idempotency_key, params_dict = self._prepare(params, context)
# Dispatch threads the original wire-shape Pydantic model here when
# calling via arg-projector so we hash the canonical flat request
# (RFC 8785 / AdCP #2315) rather than the projected subset.
canonical_params = kwargs.pop("__adcp_params__", _MISSING)

# Detect arg-projector calling convention: params not provided and
# projected field kwargs present. Save proj_kwargs before mutation
# so we can re-dispatch with the exact same calling convention.
arg_projected = False
proj_kwargs: dict[str, Any] = {}

if params is None and kwargs:
# Use sentinel — not `or` — to avoid dropping a falsy ToolContext.
ctx_val = kwargs.pop("ctx", _MISSING)
if ctx_val is _MISSING:
ctx_val = kwargs.pop("context", _MISSING)
else:
kwargs.pop("context", _MISSING)

if kwargs: # remaining kwargs are projected request fields
arg_projected = True
proj_kwargs = dict(kwargs)
kwargs.clear()
if ctx_val is not _MISSING:
context = ctx_val

# If the method was called directly (adopter unit-test, no dispatch
# sentinel), there is no params model to hash or extract a key from.
# Fall through so direct calls work without idempotency semantics
# rather than raising a confusing TypeError inside _prepare.
if arg_projected and canonical_params is _MISSING:
return await handler(handler_self, **proj_kwargs, ctx=context)

# For hashing, prefer the canonical params object threaded from
# dispatch (full wire-shape model); fall back to positional params
# on the normal (non-arg-projected) call path.
hash_params = canonical_params if canonical_params is not _MISSING else params

scope_key, idempotency_key, params_dict = self._prepare(hash_params, context)
if scope_key is None or idempotency_key is None:
# No key → spec says the server MUST reject with INVALID_REQUEST.
# We let the handler run so validation layers above us (Pydantic,
# FastAPI, etc.) can reject with a typed error; the middleware's
# job is only to dedup when a key IS present.
if arg_projected:
return await handler(handler_self, **proj_kwargs, ctx=context)
return await handler(handler_self, params, context, *args, **kwargs)

payload_hash = self._hash_fn(params_dict)
Expand Down Expand Up @@ -183,7 +227,10 @@ async def _wrapped(
],
)

response = await handler(handler_self, params, context, *args, **kwargs)
if arg_projected:
response = await handler(handler_self, **proj_kwargs, ctx=context)
else:
response = await handler(handler_self, params, context, *args, **kwargs)
# Deep-copy when caching so post-return mutation of the caller's
# copy can't poison future replays. `_clone_response` also deep-
# copies on the hit path, giving independent objects per replay.
Expand Down
158 changes: 158 additions & 0 deletions tests/test_server_idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,164 @@ def test_ttl_maximum_accepted(self) -> None:
assert store.capability() == {"supported": True, "replay_ttl_seconds": 604800}


class TestArgProjectedCallingConvention:
"""Regression tests for issue #559 — @wrap on arg-projected methods.

The framework's arg-projector calls certain methods (e.g. update_media_buy)
via **kwargs rather than a positional params object:

method(media_buy_id="mb_1", patch=<UpdateMediaBuyRequest>, ctx=ctx,
__adcp_params__=original_params)

These tests verify that @wrap handles that calling convention correctly:
no TypeError, dedup actually fires, conflict detection works.
"""

def _make_store(self) -> IdempotencyStore:
return IdempotencyStore(backend=MemoryBackend(), ttl_seconds=86400)

@pytest.mark.asyncio
async def test_arg_projected_call_does_not_raise_type_error(self) -> None:
store = self._make_store()

class SellerPlatform:
def __init__(self) -> None:
self.call_count = 0

@store.wrap
async def update_media_buy(
self, media_buy_id: str, patch: Any, ctx: Any = None
) -> dict[str, Any]:
self.call_count += 1
return {"media_buy_id": media_buy_id, "status": "active"}

seller = SellerPlatform()
key = str(uuid.uuid4())
canonical_params = {"media_buy_id": "mb_1", "patch": {"budget": 500}, "idempotency_key": key}
ctx = ToolContext(caller_identity="principal-a")

result = await seller.update_media_buy(
media_buy_id="mb_1",
patch={"budget": 500},
ctx=ctx,
__adcp_params__=canonical_params,
)
assert result["media_buy_id"] == "mb_1"
assert seller.call_count == 1

@pytest.mark.asyncio
async def test_arg_projected_cache_hit_deduplicates(self) -> None:
store = self._make_store()

class SellerPlatform:
def __init__(self) -> None:
self.call_count = 0

@store.wrap
async def update_media_buy(
self, media_buy_id: str, patch: Any, ctx: Any = None
) -> dict[str, Any]:
self.call_count += 1
return {"media_buy_id": media_buy_id, "rev": self.call_count}

seller = SellerPlatform()
key = str(uuid.uuid4())
canonical_params = {"media_buy_id": "mb_1", "patch": {"budget": 500}, "idempotency_key": key}
ctx = ToolContext(caller_identity="principal-a")

r1 = await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx, __adcp_params__=canonical_params
)
r2 = await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx, __adcp_params__=canonical_params
)
assert seller.call_count == 1 # second call served from cache
assert r1 == r2

@pytest.mark.asyncio
async def test_arg_projected_payload_conflict_raises(self) -> None:
store = self._make_store()

class SellerPlatform:
@store.wrap
async def update_media_buy(
self, media_buy_id: str, patch: Any, ctx: Any = None
) -> dict[str, Any]:
return {"media_buy_id": media_buy_id}

seller = SellerPlatform()
key = str(uuid.uuid4())
ctx = ToolContext(caller_identity="principal-a")

canonical_a = {"media_buy_id": "mb_1", "patch": {"budget": 500}, "idempotency_key": key}
await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx, __adcp_params__=canonical_a
)
canonical_b = {"media_buy_id": "mb_1", "patch": {"budget": 999}, "idempotency_key": key}
with pytest.raises(IdempotencyConflictError):
await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 999}, ctx=ctx, __adcp_params__=canonical_b
)

@pytest.mark.asyncio
async def test_arg_projected_no_key_falls_through(self) -> None:
store = self._make_store()

class SellerPlatform:
def __init__(self) -> None:
self.call_count = 0

@store.wrap
async def update_media_buy(
self, media_buy_id: str, patch: Any, ctx: Any = None
) -> dict[str, Any]:
self.call_count += 1
return {"media_buy_id": media_buy_id}

seller = SellerPlatform()
ctx = ToolContext(caller_identity="principal-a")
canonical_params = {"media_buy_id": "mb_1", "patch": {"budget": 500}} # no idempotency_key

r1 = await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx, __adcp_params__=canonical_params
)
r2 = await seller.update_media_buy(
media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx, __adcp_params__=canonical_params
)
assert seller.call_count == 2 # no dedup without key
assert r1 == r2

@pytest.mark.asyncio
async def test_arg_projected_direct_call_no_sentinel_falls_through(self) -> None:
"""Adopter unit-test calling wrapped method directly (no dispatch sentinel).

Without __adcp_params__, the wrapper cannot extract an idempotency_key
and MUST fall through to the handler rather than raising TypeError.
Both calls must reach the handler (no dedup, since no key was hashed).
"""
store = self._make_store()

class SellerPlatform:
def __init__(self) -> None:
self.call_count = 0

@store.wrap
async def update_media_buy(
self, media_buy_id: str, patch: Any, ctx: Any = None
) -> dict[str, Any]:
self.call_count += 1
return {"media_buy_id": media_buy_id}

seller = SellerPlatform()
ctx = ToolContext(caller_identity="principal-a")

# Direct call without __adcp_params__ — must not raise TypeError
r1 = await seller.update_media_buy(media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx)
r2 = await seller.update_media_buy(media_buy_id="mb_1", patch={"budget": 500}, ctx=ctx)
assert seller.call_count == 2 # no dedup: no sentinel → no key → fall-through
assert r1 == r2


class TestTTLExpiry:
@pytest.mark.asyncio
async def test_cached_response_expires_after_ttl(self) -> None:
Expand Down
Loading