From ca174334442cd615ece5bdfd33b0b877469c2071 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 12:47:21 +0530 Subject: [PATCH 1/8] fix(sdk): buffer async generator outputs in control decorator --- .../src/agent_control/control_decorators.py | 250 +++++++++++++++--- sdks/python/tests/test_control_decorators.py | 233 ++++++++++++++++ 2 files changed, 442 insertions(+), 41 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index d1edb06e..5fee7f79 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -31,11 +31,11 @@ async def chat(message: str) -> str: import functools import inspect import time -from collections.abc import Callable -from dataclasses import dataclass, field +from collections.abc import AsyncGenerator, Callable, Mapping +from dataclasses import dataclass, field, fields, is_dataclass from typing import Any, TypeVar -from agent_control_models import Step +from agent_control_models import JSONValue, Step from agent_control import AgentControlClient from agent_control.evaluation import check_evaluation_with_local @@ -210,6 +210,33 @@ def __init__( ) +@dataclass +class _BufferedStreamCapture: + """Buffer streamed chunks for post-check evaluation before replay.""" + + replay_chunks: list[Any] = field(default_factory=list) + normalized_chunks: list[JSONValue] = field(default_factory=list) + + def add(self, chunk: Any) -> None: + self.replay_chunks.append(chunk) + self.normalized_chunks.append(_normalize_json_value(chunk)) + + def output_payload(self) -> JSONValue: + if not self.normalized_chunks: + return "" + + if all(isinstance(chunk, str) for chunk in self.normalized_chunks): + return "".join(chunk for chunk in self.normalized_chunks if isinstance(chunk, str)) + + output: dict[str, JSONValue] = {"chunks": list(self.normalized_chunks)} + text_output = "".join( + chunk for chunk in self.normalized_chunks if isinstance(chunk, str) + ) + if text_output: + output["text"] = text_output + return output + + def _get_current_agent() -> Any | None: """Get the current agent from agent_control module.""" try: @@ -224,6 +251,132 @@ def _get_server_url() -> str: return get_settings().url +def _copy_public_attributes(source: Any, target: Any) -> None: + """Copy public function attributes onto a wrapped callable.""" + for attr in dir(source): + if not attr.startswith("_") and attr not in ("__call__", "__wrapped__"): + try: + setattr(target, attr, getattr(source, attr)) + except (AttributeError, TypeError): + pass + + +def _normalize_json_value(value: Any, *, _seen: set[int] | None = None) -> JSONValue: + """Convert runtime values into JSON-safe data for evaluation payloads.""" + if isinstance(value, (str, int, float, bool)) or value is None: + return value + + if isinstance(value, (bytes, bytearray)): + return bytes(value).decode("utf-8", errors="replace") + + seen = set() if _seen is None else _seen + value_id = id(value) + if value_id in seen: + return str(value) + + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + return _normalize_json_value(model_dump(mode="json"), _seen=seen) + + dict_method = getattr(value, "dict", None) + if callable(dict_method): + return _normalize_json_value(dict_method(), _seen=seen) + + if isinstance(value, Mapping): + seen.add(value_id) + try: + return { + str(key): _normalize_json_value(item, _seen=seen) + for key, item in value.items() + } + finally: + seen.remove(value_id) + + if isinstance(value, (list, tuple, set, frozenset)): + seen.add(value_id) + try: + return [_normalize_json_value(item, _seen=seen) for item in value] + finally: + seen.remove(value_id) + + if is_dataclass(value) and not isinstance(value, type): + seen.add(value_id) + try: + return { + field_info.name: _normalize_json_value( + getattr(value, field_info.name), + _seen=seen, + ) + for field_info in fields(value) + } + finally: + seen.remove(value_id) + + value_dict = getattr(value, "__dict__", None) + if isinstance(value_dict, dict): + seen.add(value_id) + try: + return { + str(key): _normalize_json_value(item, _seen=seen) + for key, item in value_dict.items() + } + finally: + seen.remove(value_id) + + return str(value) + + +def _normalized_output_value(output: Any) -> JSONValue | None: + """Return normalized output for a post-check payload.""" + if output is None: + return None + return _normalize_json_value(output) + + +def _build_control_context( + func: Callable, + args: tuple, + kwargs: dict, + step_name: str | None, +) -> tuple[ControlContext, list[dict[str, Any]] | None] | None: + """Create the shared control context for a protected invocation.""" + agent = _get_current_agent() + if agent is None: + return None + + controls = _get_server_controls() + + existing_trace_id = get_current_trace_id() + if existing_trace_id: + trace_id = existing_trace_id + span_id = _generate_span_id() + else: + trace_id, span_id = get_trace_and_span_ids() + + return ( + ControlContext( + agent_name=agent.agent_name, + server_url=_get_server_url(), + func=func, + args=args, + kwargs=kwargs, + trace_id=trace_id, + span_id=span_id, + start_time=time.perf_counter(), + step_name=step_name, + ), + controls, + ) + + +async def _close_async_generator(stream: AsyncGenerator[Any, None]) -> None: + """Close a buffered async generator without masking the original failure.""" + try: + await stream.aclose() + except Exception: + logger.debug("Failed to close buffered async generator cleanly", exc_info=True) + + async def _evaluate( agent_name: str, step: dict[str, Any], @@ -480,9 +633,7 @@ def _create_evaluation_payload( "type": "tool", "name": determined_name, "input": dict(bound.arguments), - "output": output if isinstance(output, (str, int, float, bool, dict, list)) else ( - None if output is None else str(output) - ), + "output": _normalized_output_value(output), } # This is an LLM step @@ -491,9 +642,7 @@ def _create_evaluation_payload( "type": "llm", "name": determined_name, "input": input_data, - "output": output if isinstance(output, (str, int, float, bool, dict, list)) else ( - None if output is None else str(output) - ), + "output": _normalized_output_value(output), } @@ -687,8 +836,8 @@ async def _execute_with_control( ControlSteerError: If any control triggers with "steer" action RuntimeError: If control evaluation fails unexpectedly """ - agent = _get_current_agent() - if agent is None: + context = _build_control_context(func, args, kwargs, step_name) + if context is None: logger.warning( "No agent initialized. Call agent_control.init() first. " "Running without protection." @@ -697,29 +846,7 @@ async def _execute_with_control( return await func(*args, **kwargs) return func(*args, **kwargs) - # Get cached controls for local evaluation support - controls = _get_server_controls() - - # Get trace context: inherit trace_id if set, always generate new span_id - # This allows multiple @control() calls to share the same trace but have unique spans - existing_trace_id = get_current_trace_id() - if existing_trace_id: - trace_id = existing_trace_id - span_id = _generate_span_id() # New span for this function - else: - trace_id, span_id = get_trace_and_span_ids() # New trace and span - - ctx = ControlContext( - agent_name=agent.agent_name, - server_url=_get_server_url(), - func=func, - args=args, - kwargs=kwargs, - trace_id=trace_id, - span_id=span_id, - start_time=time.perf_counter(), - step_name=step_name, - ) + ctx, controls = context ctx.log_start() try: @@ -768,6 +895,11 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable 3. After function execution: Calls server with stage="post" - Server evaluates all matching "post" controls for the agent + Async generator note: + - Decorated async generators are buffered before replay. + - The post-check runs on the full buffered output. + - Chunks are yielded to the caller only after the post-check passes. + Example: import agent_control @@ -813,19 +945,53 @@ def decorator(func: F) -> F: register(func, policy) + @functools.wraps(func) + async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + context = _build_control_context(func, args, kwargs, step_name) + if context is None: + logger.warning( + "No agent initialized. Call agent_control.init() first. " + "Running without protection." + ) + async for chunk in func(*args, **kwargs): + yield chunk + return + + ctx, controls = context + ctx.log_start() + stream = func(*args, **kwargs) + capture = _BufferedStreamCapture() + + try: + await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + + try: + async for chunk in stream: + capture.add(chunk) + except BaseException: + await _close_async_generator(stream) + raise + + await _run_control_check( + ctx, + "post", + ctx.post_payload(capture.output_payload()), + controls, + ) + + for chunk in capture.replay_chunks: + yield chunk + finally: + ctx.log_end() + @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: return await _execute_with_control( func, args, kwargs, is_async=True, step_name=step_name ) - # Copy over ALL attributes from the original function (important for LangChain tools) - for attr in dir(func): - if not attr.startswith('_') and attr not in ('__call__', '__wrapped__'): - try: - setattr(async_wrapper, attr, getattr(func, attr)) - except (AttributeError, TypeError): - pass + _copy_public_attributes(func, async_gen_wrapper) + _copy_public_attributes(func, async_wrapper) @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: @@ -833,6 +999,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: _execute_with_control(func, args, kwargs, is_async=False, step_name=step_name) ) + if inspect.isasyncgenfunction(func): + return async_gen_wrapper # type: ignore if inspect.iscoroutinefunction(func): return async_wrapper # type: ignore return sync_wrapper # type: ignore diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index d13ec46d..992d804b 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -1,5 +1,7 @@ """Tests for @control decorator.""" +import inspect +from dataclasses import dataclass from unittest.mock import MagicMock, patch import pytest @@ -788,6 +790,237 @@ async def test_func(): assert exc_info.value.steering_context == "Default evaluator message" +# ============================================================================= +# ASYNC GENERATOR TESTS +# ============================================================================= + + +@dataclass +class _StructuredChunk: + value: str + + +class TestAsyncGeneratorControl: + """Tests for buffered async-generator support in @control().""" + + @pytest.mark.asyncio + async def test_preserves_async_generator_identity(self, mock_agent, mock_safe_response): + """Test that decorating an async generator preserves its async-gen type.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + yield "hello" + + assert inspect.isasyncgenfunction(stream) + + @pytest.mark.asyncio + async def test_buffers_stream_until_post_check_passes(self, mock_agent, mock_safe_response): + """Test that async-generator output is buffered before replay.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + collected = [] + async for chunk in stream("test"): + collected.append(chunk) + break + + assert collected == ["chunk1"] + assert call_stages == ["pre", "post"] + + @pytest.mark.asyncio + async def test_post_check_receives_joined_text_output_for_string_chunks( + self, + mock_agent, + mock_safe_response, + ): + """Test that string chunks are joined for the buffered post-check payload.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "hello " + yield "world" + + chunks = [chunk async for chunk in stream("test")] + assert chunks == ["hello ", "world"] + assert post_steps[0]["output"] == "hello world" + + @pytest.mark.asyncio + async def test_structured_chunks_are_normalized_without_lossy_repr( + self, + mock_agent, + mock_safe_response, + ): + """Test that structured chunks are normalized into JSON-safe post payloads.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield {"kind": "json"} + yield _StructuredChunk(value="structured") + yield b"tail" + + chunks = [chunk async for chunk in stream("test")] + assert chunks == [{"kind": "json"}, _StructuredChunk(value="structured"), b"tail"] + assert post_steps[0]["output"] == { + "chunks": [ + {"kind": "json"}, + {"value": "structured"}, + "tail", + ], + "text": "tail", + } + + @pytest.mark.asyncio + async def test_post_check_block_happens_before_any_chunk_is_replayed( + self, + mock_agent, + mock_safe_response, + mock_unsafe_response, + ): + """Test that a post-stage deny blocks the buffered stream before replay starts.""" + call_count = 0 + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_safe_response + return mock_unsafe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + + collected = [] + with pytest.raises(ControlViolationError): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + + @pytest.mark.asyncio + async def test_generator_failure_skips_post_check_and_replays_nothing( + self, + mock_agent, + mock_safe_response, + ): + """Test that a source stream failure propagates before any replay or post-check.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + raise RuntimeError("stream failed") + + collected = [] + with pytest.raises(RuntimeError, match="stream failed"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + + @pytest.mark.asyncio + async def test_async_generator_without_agent_passthrough(self): + """Test that async generators pass through unchanged without an initialized agent.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=None): + + @control() + async def stream(message: str): + yield "a" + yield "b" + + chunks = [chunk async for chunk in stream("hello")] + assert chunks == ["a", "b"] + + # ============================================================================= # EXCEPTION HANDLING TESTS # ============================================================================= From 0ee78fd5d1cb43b16aa8bd4cb754e0eee8d416f9 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 13:18:37 +0530 Subject: [PATCH 2/8] fix(sdk): tighten async generator chunk normalization --- .../src/agent_control/control_decorators.py | 42 +++++++++-------- sdks/python/tests/test_control_decorators.py | 45 +++++++++++++++++++ 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 5fee7f79..06552f60 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -225,13 +225,12 @@ def output_payload(self) -> JSONValue: if not self.normalized_chunks: return "" - if all(isinstance(chunk, str) for chunk in self.normalized_chunks): - return "".join(chunk for chunk in self.normalized_chunks if isinstance(chunk, str)) + text_chunks = [chunk for chunk in self.normalized_chunks if isinstance(chunk, str)] + if len(text_chunks) == len(self.normalized_chunks): + return "".join(text_chunks) output: dict[str, JSONValue] = {"chunks": list(self.normalized_chunks)} - text_output = "".join( - chunk for chunk in self.normalized_chunks if isinstance(chunk, str) - ) + text_output = "".join(text_chunks) if text_output: output["text"] = text_output return output @@ -276,11 +275,19 @@ def _normalize_json_value(value: Any, *, _seen: set[int] | None = None) -> JSONV model_dump = getattr(value, "model_dump", None) if callable(model_dump): - return _normalize_json_value(model_dump(mode="json"), _seen=seen) + seen.add(value_id) + try: + return _normalize_json_value(model_dump(mode="json"), _seen=seen) + finally: + seen.remove(value_id) dict_method = getattr(value, "dict", None) if callable(dict_method): - return _normalize_json_value(dict_method(), _seen=seen) + seen.add(value_id) + try: + return _normalize_json_value(dict_method(), _seen=seen) + finally: + seen.remove(value_id) if isinstance(value, Mapping): seen.add(value_id) @@ -369,14 +376,6 @@ def _build_control_context( ) -async def _close_async_generator(stream: AsyncGenerator[Any, None]) -> None: - """Close a buffered async generator without masking the original failure.""" - try: - await stream.aclose() - except Exception: - logger.debug("Failed to close buffered async generator cleanly", exc_info=True) - - async def _evaluate( agent_name: str, step: dict[str, Any], @@ -946,7 +945,10 @@ def decorator(func: F) -> F: register(func, policy) @functools.wraps(func) - async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + async def async_gen_wrapper( + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Any, None]: context = _build_control_context(func, args, kwargs, step_name) if context is None: logger.warning( @@ -965,12 +967,8 @@ async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: try: await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) - try: - async for chunk in stream: - capture.add(chunk) - except BaseException: - await _close_async_generator(stream) - raise + async for chunk in stream: + capture.add(chunk) await _run_control_check( ctx, diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index 992d804b..dcecedc4 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -800,6 +800,15 @@ class _StructuredChunk: value: str +class _RecursiveModelDumpChunk: + def model_dump(self, mode: str = "json") -> dict[str, object]: + _ = mode + return {"self": self} + + def __str__(self) -> str: + return "recursive-model-dump" + + class TestAsyncGeneratorControl: """Tests for buffered async-generator support in @control().""" @@ -928,6 +937,42 @@ async def stream(message: str): "text": "tail", } + @pytest.mark.asyncio + async def test_model_dump_cycle_is_guarded_during_chunk_normalization( + self, + mock_agent, + mock_safe_response, + ): + """Test that self-referential model_dump payloads do not recurse forever.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _RecursiveModelDumpChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == { + "chunks": [{"self": "recursive-model-dump"}] + } + @pytest.mark.asyncio async def test_post_check_block_happens_before_any_chunk_is_replayed( self, From 55a1a3c1dbb00ab811c3729f910d1f95d511c365 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 14:48:51 +0530 Subject: [PATCH 3/8] fix(sdk): harden buffered async generator serialization --- .../src/agent_control/control_decorators.py | 172 +++++++++----- sdks/python/src/agent_control/settings.py | 12 + sdks/python/tests/test_control_decorators.py | 215 ++++++++++++++++++ 3 files changed, 339 insertions(+), 60 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 06552f60..49dca5e1 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -30,6 +30,8 @@ async def chat(message: str) -> str: import asyncio import functools import inspect +import json +import math import time from collections.abc import AsyncGenerator, Callable, Mapping from dataclasses import dataclass, field, fields, is_dataclass @@ -214,12 +216,31 @@ def __init__( class _BufferedStreamCapture: """Buffer streamed chunks for post-check evaluation before replay.""" + max_chunks: int + max_bytes: int replay_chunks: list[Any] = field(default_factory=list) normalized_chunks: list[JSONValue] = field(default_factory=list) + buffered_bytes: int = 0 def add(self, chunk: Any) -> None: + next_chunk_count = len(self.replay_chunks) + 1 + if next_chunk_count > self.max_chunks: + raise RuntimeError( + "Buffered async generator output exceeded the configured chunk limit. " + "Failing closed." + ) + + normalized_chunk = _normalize_json_value(chunk) + next_buffered_bytes = self.buffered_bytes + _json_value_size(normalized_chunk) + if next_buffered_bytes > self.max_bytes: + raise RuntimeError( + "Buffered async generator output exceeded the configured size limit. " + "Failing closed." + ) + self.replay_chunks.append(chunk) - self.normalized_chunks.append(_normalize_json_value(chunk)) + self.normalized_chunks.append(normalized_chunk) + self.buffered_bytes = next_buffered_bytes def output_payload(self) -> JSONValue: if not self.normalized_chunks: @@ -260,77 +281,102 @@ def _copy_public_attributes(source: Any, target: Any) -> None: pass +def _safe_stringify(value: Any) -> str: + """Stringify a value without letting representation errors escape.""" + try: + return str(value) + except Exception: + return f"" + + +def _json_value_size(value: JSONValue) -> int: + """Estimate normalized JSON payload size in bytes for buffered streams.""" + try: + return len( + json.dumps( + value, + ensure_ascii=False, + sort_keys=True, + allow_nan=False, + ).encode("utf-8") + ) + except (TypeError, ValueError): + return len(_safe_stringify(value).encode("utf-8")) + + def _normalize_json_value(value: Any, *, _seen: set[int] | None = None) -> JSONValue: """Convert runtime values into JSON-safe data for evaluation payloads.""" - if isinstance(value, (str, int, float, bool)) or value is None: + if isinstance(value, str) or value is None: + return value + + if isinstance(value, bool | int): return value + if isinstance(value, float): + return value if math.isfinite(value) else _safe_stringify(value) + if isinstance(value, (bytes, bytearray)): return bytes(value).decode("utf-8", errors="replace") seen = set() if _seen is None else _seen value_id = id(value) if value_id in seen: - return str(value) + return _safe_stringify(value) - model_dump = getattr(value, "model_dump", None) - if callable(model_dump): - seen.add(value_id) - try: - return _normalize_json_value(model_dump(mode="json"), _seen=seen) - finally: - seen.remove(value_id) - - dict_method = getattr(value, "dict", None) - if callable(dict_method): - seen.add(value_id) - try: - return _normalize_json_value(dict_method(), _seen=seen) - finally: - seen.remove(value_id) - - if isinstance(value, Mapping): - seen.add(value_id) - try: - return { - str(key): _normalize_json_value(item, _seen=seen) - for key, item in value.items() - } - finally: - seen.remove(value_id) - - if isinstance(value, (list, tuple, set, frozenset)): - seen.add(value_id) - try: - return [_normalize_json_value(item, _seen=seen) for item in value] - finally: - seen.remove(value_id) - - if is_dataclass(value) and not isinstance(value, type): - seen.add(value_id) - try: - return { - field_info.name: _normalize_json_value( - getattr(value, field_info.name), - _seen=seen, - ) - for field_info in fields(value) - } - finally: - seen.remove(value_id) + try: + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + seen.add(value_id) + try: + return _normalize_json_value(model_dump(mode="json"), _seen=seen) + except Exception: + return _safe_stringify(value) + finally: + seen.remove(value_id) - value_dict = getattr(value, "__dict__", None) - if isinstance(value_dict, dict): - seen.add(value_id) - try: - return { - str(key): _normalize_json_value(item, _seen=seen) - for key, item in value_dict.items() - } - finally: - seen.remove(value_id) + dict_method = getattr(value, "dict", None) + if callable(dict_method): + seen.add(value_id) + try: + return _normalize_json_value(dict_method(), _seen=seen) + except Exception: + return _safe_stringify(value) + finally: + seen.remove(value_id) - return str(value) + if isinstance(value, Mapping): + seen.add(value_id) + try: + return { + str(key): _normalize_json_value(item, _seen=seen) + for key, item in value.items() + } + finally: + seen.remove(value_id) + + if isinstance(value, (list, tuple, set, frozenset)): + seen.add(value_id) + try: + return [_normalize_json_value(item, _seen=seen) for item in value] + finally: + seen.remove(value_id) + + if is_dataclass(value) and not isinstance(value, type): + seen.add(value_id) + try: + return { + field_info.name: _normalize_json_value( + getattr(value, field_info.name), + _seen=seen, + ) + for field_info in fields(value) + } + finally: + seen.remove(value_id) + except Exception: + return _safe_stringify(value) + + return _safe_stringify(value) def _normalized_output_value(output: Any) -> JSONValue | None: @@ -962,7 +1008,11 @@ async def async_gen_wrapper( ctx, controls = context ctx.log_start() stream = func(*args, **kwargs) - capture = _BufferedStreamCapture() + settings = get_settings() + capture = _BufferedStreamCapture( + max_chunks=settings.stream_buffer_max_chunks, + max_bytes=settings.stream_buffer_max_bytes, + ) try: await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) @@ -997,6 +1047,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: _execute_with_control(func, args, kwargs, is_async=False, step_name=step_name) ) + _copy_public_attributes(func, sync_wrapper) + if inspect.isasyncgenfunction(func): return async_gen_wrapper # type: ignore if inspect.iscoroutinefunction(func): diff --git a/sdks/python/src/agent_control/settings.py b/sdks/python/src/agent_control/settings.py index f06398d4..51b50d00 100644 --- a/sdks/python/src/agent_control/settings.py +++ b/sdks/python/src/agent_control/settings.py @@ -55,6 +55,18 @@ class SDKSettings(BaseSettings): description="API key for server authentication", ) + # Buffered async-generator control limits + stream_buffer_max_chunks: int = Field( + default=10000, + ge=1, + description="Maximum streamed chunks buffered by @control() before failing closed", + ) + stream_buffer_max_bytes: int = Field( + default=5_000_000, + ge=1, + description="Maximum normalized bytes buffered by @control() before failing closed", + ) + # Observability (event batching) observability_enabled: bool = Field( default=True, diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index dcecedc4..4eb6fee5 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -7,6 +7,7 @@ import pytest from agent_control.control_decorators import ControlViolationError, ControlSteerError, control +from agent_control.settings import configure_settings, get_settings # ============================================================================= @@ -318,6 +319,76 @@ async def chat(message: str) -> str: assert "output" in captured_step assert "Generated response" in captured_step["output"] + @pytest.mark.asyncio + async def test_post_check_serializer_failure_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that serializer hook failures do not escape the post-check path.""" + captured_step = {} + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + captured_step.update(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def chat(message: str) -> _ExplodingModelDumpChunk: + return _ExplodingModelDumpChunk() + + result = await chat("Hello!") + + assert isinstance(result, _ExplodingModelDumpChunk) + assert captured_step["output"] == "exploding-model-dump" + + @pytest.mark.asyncio + async def test_post_check_non_finite_float_is_stringified( + self, + mock_agent, + mock_safe_response, + ): + """Test that non-finite floats are normalized to JSON-safe strings.""" + captured_step = {} + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + captured_step.update(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def chat(message: str) -> float: + return float("inf") + + result = await chat("Hello!") + + assert result == float("inf") + assert captured_step["output"] == "inf" + # ============================================================================= # INPUT EXTRACTION TESTS @@ -809,6 +880,23 @@ def __str__(self) -> str: return "recursive-model-dump" +class _ExplodingModelDumpChunk: + def model_dump(self, mode: str = "json") -> dict[str, object]: + _ = mode + raise RuntimeError("model_dump exploded") + + def __str__(self) -> str: + return "exploding-model-dump" + + +class _PlainObjectChunk: + def __init__(self) -> None: + self.secret = "top-secret" + + def __str__(self) -> str: + return "plain-object" + + class TestAsyncGeneratorControl: """Tests for buffered async-generator support in @control().""" @@ -973,6 +1061,74 @@ async def stream(message: str): "chunks": [{"self": "recursive-model-dump"}] } + @pytest.mark.asyncio + async def test_chunk_serializer_failure_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that serializer hook failures on streamed chunks fall back to strings.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _ExplodingModelDumpChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "exploding-model-dump" + + @pytest.mark.asyncio + async def test_plain_object_chunk_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that plain objects are stringified instead of expanding __dict__.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _PlainObjectChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "plain-object" + @pytest.mark.asyncio async def test_post_check_block_happens_before_any_chunk_is_replayed( self, @@ -1014,6 +1170,51 @@ async def stream(message: str): assert collected == [] + @pytest.mark.asyncio + async def test_stream_buffer_chunk_limit_fails_closed_before_replay( + self, + mock_agent, + mock_safe_response, + ): + """Test that exceeding the configured chunk cap blocks buffered replay.""" + original_settings = get_settings().model_dump() + limited_settings = {**original_settings, "stream_buffer_max_chunks": 2} + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + configure_settings(**limited_settings) + try: + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + collected = [] + with pytest.raises(RuntimeError, match="chunk limit"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + finally: + configure_settings(**original_settings) + @pytest.mark.asyncio async def test_generator_failure_skips_post_check_and_replays_nothing( self, @@ -1065,6 +1266,20 @@ async def stream(message: str): chunks = [chunk async for chunk in stream("hello")] assert chunks == ["a", "b"] + def test_sync_wrapper_copies_public_attributes(self, mock_agent, mock_safe_response): + """Test that sync wrappers preserve custom public attributes.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + def process(input: str) -> str: + return input.upper() + + process.description = "Uppercases the input" + + wrapped = control()(process) + + assert wrapped.description == "Uppercases the input" + # ============================================================================= # EXCEPTION HANDLING TESTS From c6b946edb3c886aac7c7c59eee205fd4f64503dd Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 15:08:56 +0530 Subject: [PATCH 4/8] fix(sdk): document and guard buffered replay semantics --- .../src/agent_control/control_decorators.py | 32 +++++- sdks/python/tests/test_control_decorators.py | 101 ++++++++++++++---- 2 files changed, 108 insertions(+), 25 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 49dca5e1..9625b7cf 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -220,7 +220,9 @@ class _BufferedStreamCapture: max_bytes: int replay_chunks: list[Any] = field(default_factory=list) normalized_chunks: list[JSONValue] = field(default_factory=list) - buffered_bytes: int = 0 + chunk_bytes: int = 0 + text_bytes: int = 0 + has_non_string_chunk: bool = False def add(self, chunk: Any) -> None: next_chunk_count = len(self.replay_chunks) + 1 @@ -231,8 +233,19 @@ def add(self, chunk: Any) -> None: ) normalized_chunk = _normalize_json_value(chunk) - next_buffered_bytes = self.buffered_bytes + _json_value_size(normalized_chunk) - if next_buffered_bytes > self.max_bytes: + next_chunk_bytes = self.chunk_bytes + _json_value_size(normalized_chunk) + next_text_bytes = self.text_bytes + next_has_non_string_chunk = self.has_non_string_chunk + if isinstance(normalized_chunk, str): + next_text_bytes += _json_value_size(normalized_chunk) + else: + next_has_non_string_chunk = True + + next_payload_bytes = next_chunk_bytes + if next_has_non_string_chunk: + next_payload_bytes += next_text_bytes + + if next_payload_bytes > self.max_bytes: raise RuntimeError( "Buffered async generator output exceeded the configured size limit. " "Failing closed." @@ -240,7 +253,9 @@ def add(self, chunk: Any) -> None: self.replay_chunks.append(chunk) self.normalized_chunks.append(normalized_chunk) - self.buffered_bytes = next_buffered_bytes + self.chunk_bytes = next_chunk_bytes + self.text_bytes = next_text_bytes + self.has_non_string_chunk = next_has_non_string_chunk def output_payload(self) -> JSONValue: if not self.normalized_chunks: @@ -944,6 +959,8 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable - Decorated async generators are buffered before replay. - The post-check runs on the full buffered output. - Chunks are yielded to the caller only after the post-check passes. + - Buffered async generators are pull-only and do not preserve interactive + asend()/athrow() semantics against the source generator. Example: import agent_control @@ -1028,7 +1045,12 @@ async def async_gen_wrapper( ) for chunk in capture.replay_chunks: - yield chunk + sent = yield chunk + if sent is not None: + raise TypeError( + "Buffered @control() async generators do not support " + "asend(non-None)." + ) finally: ctx.log_end() diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index 4eb6fee5..8912f1e8 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -7,7 +7,7 @@ import pytest from agent_control.control_decorators import ControlViolationError, ControlSteerError, control -from agent_control.settings import configure_settings, get_settings +from agent_control.settings import get_settings # ============================================================================= @@ -1175,10 +1175,9 @@ async def test_stream_buffer_chunk_limit_fails_closed_before_replay( self, mock_agent, mock_safe_response, + monkeypatch, ): """Test that exceeding the configured chunk cap blocks buffered replay.""" - original_settings = get_settings().model_dump() - limited_settings = {**original_settings, "stream_buffer_max_chunks": 2} call_stages = [] async def mock_evaluate( @@ -1194,26 +1193,66 @@ async def mock_evaluate( call_stages.append(stage) return mock_safe_response - configure_settings(**limited_settings) - try: - with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ - patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + monkeypatch.setattr(get_settings(), "stream_buffer_max_chunks", 2) - @control() - async def stream(message: str): - yield "chunk1" - yield "chunk2" - yield "chunk3" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" - collected = [] - with pytest.raises(RuntimeError, match="chunk limit"): - async for chunk in stream("test"): - collected.append(chunk) + collected = [] + with pytest.raises(RuntimeError, match="chunk limit"): + async for chunk in stream("test"): + collected.append(chunk) - assert collected == [] - assert call_stages == ["pre"] - finally: - configure_settings(**original_settings) + assert collected == [] + assert call_stages == ["pre"] + + @pytest.mark.asyncio + async def test_stream_buffer_byte_limit_accounts_for_mixed_text_duplication( + self, + mock_agent, + mock_safe_response, + monkeypatch, + ): + """Test that mixed streams count duplicated text toward the byte cap.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + monkeypatch.setattr(get_settings(), "stream_buffer_max_bytes", 15) + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "a" + yield {"kind": "json"} + yield "b" + + collected = [] + with pytest.raises(RuntimeError, match="size limit"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] @pytest.mark.asyncio async def test_generator_failure_skips_post_check_and_replays_nothing( @@ -1266,6 +1305,28 @@ async def stream(message: str): chunks = [chunk async for chunk in stream("hello")] assert chunks == ["a", "b"] + @pytest.mark.asyncio + async def test_asend_non_none_is_rejected_for_buffered_replay( + self, + mock_agent, + mock_safe_response, + ): + """Test that buffered replay rejects interactive sends.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + received = yield "chunk1" + yield f"received={received}" + + agen = stream("hello") + first = await agen.__anext__() + assert first == "chunk1" + + with pytest.raises(TypeError, match="asend\\(non-None\\)"): + await agen.asend("interactive-input") + def test_sync_wrapper_copies_public_attributes(self, mock_agent, mock_safe_response): """Test that sync wrappers preserve custom public attributes.""" with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ From 007d79afd7a3c41e9c71df6c8f9f04cfe1a1c739 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 15:27:22 +0530 Subject: [PATCH 5/8] fix(sdk): tighten buffered payload size accounting --- .../src/agent_control/control_decorators.py | 52 ++++++++++++++----- sdks/python/tests/test_control_decorators.py | 42 +++++++++++++++ 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index 9625b7cf..d1721d04 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -220,8 +220,9 @@ class _BufferedStreamCapture: max_bytes: int replay_chunks: list[Any] = field(default_factory=list) normalized_chunks: list[JSONValue] = field(default_factory=list) - chunk_bytes: int = 0 - text_bytes: int = 0 + chunk_item_bytes: int = 0 + text_content_bytes: int = 0 + string_chunk_count: int = 0 has_non_string_chunk: bool = False def add(self, chunk: Any) -> None: @@ -233,18 +234,23 @@ def add(self, chunk: Any) -> None: ) normalized_chunk = _normalize_json_value(chunk) - next_chunk_bytes = self.chunk_bytes + _json_value_size(normalized_chunk) - next_text_bytes = self.text_bytes + next_chunk_item_bytes = self.chunk_item_bytes + _json_value_size(normalized_chunk) + next_text_content_bytes = self.text_content_bytes + next_string_chunk_count = self.string_chunk_count next_has_non_string_chunk = self.has_non_string_chunk if isinstance(normalized_chunk, str): - next_text_bytes += _json_value_size(normalized_chunk) + next_text_content_bytes += _json_value_size(normalized_chunk) - 2 + next_string_chunk_count += 1 else: next_has_non_string_chunk = True - next_payload_bytes = next_chunk_bytes - if next_has_non_string_chunk: - next_payload_bytes += next_text_bytes - + next_payload_bytes = self._payload_size( + chunk_count=next_chunk_count, + chunk_item_bytes=next_chunk_item_bytes, + text_content_bytes=next_text_content_bytes, + string_chunk_count=next_string_chunk_count, + has_non_string_chunk=next_has_non_string_chunk, + ) if next_payload_bytes > self.max_bytes: raise RuntimeError( "Buffered async generator output exceeded the configured size limit. " @@ -253,10 +259,31 @@ def add(self, chunk: Any) -> None: self.replay_chunks.append(chunk) self.normalized_chunks.append(normalized_chunk) - self.chunk_bytes = next_chunk_bytes - self.text_bytes = next_text_bytes + self.chunk_item_bytes = next_chunk_item_bytes + self.text_content_bytes = next_text_content_bytes + self.string_chunk_count = next_string_chunk_count self.has_non_string_chunk = next_has_non_string_chunk + def _payload_size( + self, + *, + chunk_count: int, + chunk_item_bytes: int, + text_content_bytes: int, + string_chunk_count: int, + has_non_string_chunk: bool, + ) -> int: + chunk_list_size = 2 + chunk_item_bytes + max(0, chunk_count - 1) + joined_text_size = 0 if string_chunk_count == 0 else 2 + text_content_bytes + + if not has_non_string_chunk: + return joined_text_size + + payload_size = _json_value_size("chunks") + chunk_list_size + 3 + if string_chunk_count: + payload_size += _json_value_size("text") + joined_text_size + 2 + return payload_size + def output_payload(self) -> JSONValue: if not self.normalized_chunks: return "" @@ -313,6 +340,7 @@ def _json_value_size(value: JSONValue) -> int: ensure_ascii=False, sort_keys=True, allow_nan=False, + separators=(",", ":"), ).encode("utf-8") ) except (TypeError, ValueError): @@ -350,7 +378,7 @@ def _normalize_json_value(value: Any, *, _seen: set[int] | None = None) -> JSONV seen.remove(value_id) dict_method = getattr(value, "dict", None) - if callable(dict_method): + if callable(dict_method) and hasattr(value, "__fields__"): seen.add(value_id) try: return _normalize_json_value(dict_method(), _seen=seen) diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index 8912f1e8..d1099263 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -897,6 +897,14 @@ def __str__(self) -> str: return "plain-object" +class _DuckDictChunk: + def dict(self) -> dict[str, str]: + return {"secret": "top-secret"} + + def __str__(self) -> str: + return "duck-dict" + + class TestAsyncGeneratorControl: """Tests for buffered async-generator support in @control().""" @@ -1129,6 +1137,40 @@ async def stream(message: str): assert len(chunks) == 1 assert post_steps[0]["output"] == "plain-object" + @pytest.mark.asyncio + async def test_duck_typed_dict_chunk_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that arbitrary objects with dict() do not take the Pydantic v1 path.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _DuckDictChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "duck-dict" + @pytest.mark.asyncio async def test_post_check_block_happens_before_any_chunk_is_replayed( self, From 098860e8466a63fc4a8527eb9a1b2b12f35f6e9e Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 15:46:41 +0530 Subject: [PATCH 6/8] fix(sdk): freeze buffered async generator replay --- .../src/agent_control/control_decorators.py | 52 ++++++++---- sdks/python/tests/test_control_decorators.py | 85 +++++++++++++++++++ 2 files changed, 121 insertions(+), 16 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index d1721d04..b0cfdbfa 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -28,12 +28,14 @@ async def chat(message: str) -> str: """ import asyncio +import copy import functools import inspect import json import math import time from collections.abc import AsyncGenerator, Callable, Mapping +from contextlib import aclosing from dataclasses import dataclass, field, fields, is_dataclass from typing import Any, TypeVar @@ -234,12 +236,13 @@ def add(self, chunk: Any) -> None: ) normalized_chunk = _normalize_json_value(chunk) - next_chunk_item_bytes = self.chunk_item_bytes + _json_value_size(normalized_chunk) + normalized_chunk_bytes = _json_value_size(normalized_chunk) + next_chunk_item_bytes = self.chunk_item_bytes + normalized_chunk_bytes next_text_content_bytes = self.text_content_bytes next_string_chunk_count = self.string_chunk_count next_has_non_string_chunk = self.has_non_string_chunk if isinstance(normalized_chunk, str): - next_text_content_bytes += _json_value_size(normalized_chunk) - 2 + next_text_content_bytes += normalized_chunk_bytes - 2 next_string_chunk_count += 1 else: next_has_non_string_chunk = True @@ -257,7 +260,7 @@ def add(self, chunk: Any) -> None: "Failing closed." ) - self.replay_chunks.append(chunk) + self.replay_chunks.append(_freeze_replay_chunk(chunk, normalized_chunk)) self.normalized_chunks.append(normalized_chunk) self.chunk_item_bytes = next_chunk_item_bytes self.text_content_bytes = next_text_content_bytes @@ -331,6 +334,17 @@ def _safe_stringify(value: Any) -> str: return f"" +def _freeze_replay_chunk(chunk: Any, normalized_chunk: JSONValue) -> Any: + """Freeze a replayed chunk so later mutations cannot bypass the post-check.""" + if isinstance(chunk, (str, bool, int, float, bytes, type(None))): + return chunk + + try: + return copy.deepcopy(chunk) + except Exception: + return copy.deepcopy(normalized_chunk) + + def _json_value_size(value: JSONValue) -> int: """Estimate normalized JSON payload size in bytes for buffered streams.""" try: @@ -1046,13 +1060,14 @@ async def async_gen_wrapper( "No agent initialized. Call agent_control.init() first. " "Running without protection." ) - async for chunk in func(*args, **kwargs): - yield chunk + async with aclosing(func(*args, **kwargs)) as stream: + async for chunk in stream: + yield chunk return ctx, controls = context ctx.log_start() - stream = func(*args, **kwargs) + span_open = True settings = get_settings() capture = _BufferedStreamCapture( max_chunks=settings.stream_buffer_max_chunks, @@ -1060,17 +1075,21 @@ async def async_gen_wrapper( ) try: - await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + async with aclosing(func(*args, **kwargs)) as stream: + await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) - async for chunk in stream: - capture.add(chunk) + async for chunk in stream: + capture.add(chunk) - await _run_control_check( - ctx, - "post", - ctx.post_payload(capture.output_payload()), - controls, - ) + await _run_control_check( + ctx, + "post", + ctx.post_payload(capture.output_payload()), + controls, + ) + + ctx.log_end() + span_open = False for chunk in capture.replay_chunks: sent = yield chunk @@ -1080,7 +1099,8 @@ async def async_gen_wrapper( "asend(non-None)." ) finally: - ctx.log_end() + if span_open: + ctx.log_end() @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index d1099263..b688f98d 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -1334,6 +1334,91 @@ async def stream(message: str): assert collected == [] assert call_stages == ["pre"] + @pytest.mark.asyncio + async def test_mutable_chunk_is_replayed_from_captured_snapshot( + self, + mock_agent, + mock_safe_response, + ): + """Test that later chunk mutation cannot change buffered replay output.""" + original_chunks = [] + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + chunk = {"text": "safe"} + original_chunks.append(chunk) + yield chunk + chunk["text"] = "unsafe" + + chunks = [chunk async for chunk in stream("test")] + + assert chunks == [{"text": "safe"}] + assert chunks[0] is not original_chunks[0] + assert original_chunks[0] == {"text": "unsafe"} + + @pytest.mark.asyncio + async def test_stream_span_ends_before_buffered_replay_begins( + self, + mock_agent, + mock_safe_response, + ): + """Test that control tracing ends before replay is yielded to the caller.""" + end_events = [] + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response), \ + patch( + "agent_control.control_decorators.log_span_end", + side_effect=lambda *args, **kwargs: end_events.append("ended"), + ): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + + agen = stream("test") + first = await anext(agen) + rest = [chunk async for chunk in agen] + + assert first == "chunk1" + assert rest == ["chunk2"] + assert end_events == ["ended"] + + @pytest.mark.asyncio + async def test_stream_is_closed_when_buffering_fails( + self, + mock_agent, + mock_safe_response, + monkeypatch, + ): + """Test that buffering failures still close the source async generator.""" + closed = False + + async def stream(message: str): + nonlocal closed + try: + yield "chunk1" + yield "chunk2" + finally: + closed = True + + monkeypatch.setattr(get_settings(), "stream_buffer_max_chunks", 1) + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + wrapped = control()(stream) + + with pytest.raises(RuntimeError, match="chunk limit"): + async for _ in wrapped("test"): + pass + + assert closed is True + @pytest.mark.asyncio async def test_async_generator_without_agent_passthrough(self): """Test that async generators pass through unchanged without an initialized agent.""" From ac2c03aca8a29ec053ea2af22473239f470905f4 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 17:59:17 +0530 Subject: [PATCH 7/8] fix(sdk): harden async generator replay semantics --- .../src/agent_control/control_decorators.py | 45 ++++++++----- sdks/python/tests/test_control_decorators.py | 63 +++++++++++++++++++ 2 files changed, 92 insertions(+), 16 deletions(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index b0cfdbfa..bb2611d7 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -55,6 +55,7 @@ async def chat(message: str) -> str: logger = get_logger(__name__) F = TypeVar("F", bound=Callable[..., Any]) +_ASYNC_GEN_ASEND_ERROR = "Decorated @control() async generators do not support asend(non-None)." @dataclass @@ -260,7 +261,7 @@ def add(self, chunk: Any) -> None: "Failing closed." ) - self.replay_chunks.append(_freeze_replay_chunk(chunk, normalized_chunk)) + self.replay_chunks.append(_freeze_replay_chunk(chunk)) self.normalized_chunks.append(normalized_chunk) self.chunk_item_bytes = next_chunk_item_bytes self.text_content_bytes = next_text_content_bytes @@ -334,15 +335,24 @@ def _safe_stringify(value: Any) -> str: return f"" -def _freeze_replay_chunk(chunk: Any, normalized_chunk: JSONValue) -> Any: +def _freeze_replay_chunk(chunk: Any) -> Any: """Freeze a replayed chunk so later mutations cannot bypass the post-check.""" if isinstance(chunk, (str, bool, int, float, bytes, type(None))): return chunk try: return copy.deepcopy(chunk) - except Exception: - return copy.deepcopy(normalized_chunk) + except Exception as exc: + raise RuntimeError( + "Buffered async generator produced a chunk that could not be safely " + "snapshotted for replay. Failing closed." + ) from exc + + +def _validate_async_gen_send(value: Any) -> None: + """Reject interactive sends for decorated async generators.""" + if value is not None: + raise TypeError(_ASYNC_GEN_ASEND_ERROR) def _json_value_size(value: JSONValue) -> int: @@ -998,11 +1008,12 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable - Server evaluates all matching "post" controls for the agent Async generator note: - - Decorated async generators are buffered before replay. + - When an agent is initialized, decorated async generators are buffered + before replay. - The post-check runs on the full buffered output. - Chunks are yielded to the caller only after the post-check passes. - - Buffered async generators are pull-only and do not preserve interactive - asend()/athrow() semantics against the source generator. + - Decorated async generators are pull-only and do not preserve + interactive asend()/athrow() semantics against the source generator. Example: import agent_control @@ -1061,8 +1072,14 @@ async def async_gen_wrapper( "Running without protection." ) async with aclosing(func(*args, **kwargs)) as stream: - async for chunk in stream: - yield chunk + while True: + try: + chunk = await anext(stream) + except StopAsyncIteration: + return + + sent = yield chunk + _validate_async_gen_send(sent) return ctx, controls = context @@ -1075,9 +1092,9 @@ async def async_gen_wrapper( ) try: - async with aclosing(func(*args, **kwargs)) as stream: - await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + async with aclosing(func(*args, **kwargs)) as stream: async for chunk in stream: capture.add(chunk) @@ -1093,11 +1110,7 @@ async def async_gen_wrapper( for chunk in capture.replay_chunks: sent = yield chunk - if sent is not None: - raise TypeError( - "Buffered @control() async generators do not support " - "asend(non-None)." - ) + _validate_async_gen_send(sent) finally: if span_open: ctx.log_end() diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index b688f98d..077f2270 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -905,6 +905,15 @@ def __str__(self) -> str: return "duck-dict" +class _UndeepcopyableChunk: + def __deepcopy__(self, memo: dict[int, object]) -> "_UndeepcopyableChunk": + _ = memo + raise RuntimeError("cannot deepcopy") + + def __str__(self) -> str: + return "undeepcopyable" + + class TestAsyncGeneratorControl: """Tests for buffered async-generator support in @control().""" @@ -1419,6 +1428,23 @@ async def stream(message: str): assert closed is True + @pytest.mark.asyncio + async def test_non_none_asend_is_rejected_without_agent(self): + """Test that no-agent async generators reject interactive sends consistently.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=None): + + @control() + async def stream(message: str): + received = yield "chunk1" + yield f"received={received}" + + agen = stream("hello") + first = await agen.__anext__() + assert first == "chunk1" + + with pytest.raises(TypeError, match="asend\\(non-None\\)"): + await agen.asend("interactive-input") + @pytest.mark.asyncio async def test_async_generator_without_agent_passthrough(self): """Test that async generators pass through unchanged without an initialized agent.""" @@ -1454,6 +1480,43 @@ async def stream(message: str): with pytest.raises(TypeError, match="asend\\(non-None\\)"): await agen.asend("interactive-input") + @pytest.mark.asyncio + async def test_non_deepcopyable_chunk_fails_closed_before_replay( + self, + mock_agent, + mock_safe_response, + ): + """Test that replay snapshot failures raise instead of substituting types.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _UndeepcopyableChunk() + + collected = [] + with pytest.raises(RuntimeError, match="safely snapshotted for replay"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + def test_sync_wrapper_copies_public_attributes(self, mock_agent, mock_safe_response): """Test that sync wrappers preserve custom public attributes.""" with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ From dddb2097f0dd4da7ea13f1752be75f79a827cc12 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 30 Mar 2026 19:48:25 +0530 Subject: [PATCH 8/8] docs(sdk): clarify async stream byte limit scope --- sdks/python/src/agent_control/control_decorators.py | 2 ++ sdks/python/src/agent_control/settings.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index bb2611d7..0ebcc4c5 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -1012,6 +1012,8 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable before replay. - The post-check runs on the full buffered output. - Chunks are yielded to the caller only after the post-check passes. + - stream_buffer_max_bytes limits the normalized post-check payload size, + not the in-memory size of replayed chunks. - Decorated async generators are pull-only and do not preserve interactive asend()/athrow() semantics against the source generator. diff --git a/sdks/python/src/agent_control/settings.py b/sdks/python/src/agent_control/settings.py index 51b50d00..bf1fe558 100644 --- a/sdks/python/src/agent_control/settings.py +++ b/sdks/python/src/agent_control/settings.py @@ -64,7 +64,10 @@ class SDKSettings(BaseSettings): stream_buffer_max_bytes: int = Field( default=5_000_000, ge=1, - description="Maximum normalized bytes buffered by @control() before failing closed", + description=( + "Maximum normalized post-check payload bytes buffered by @control() " + "before failing closed; does not bound replay-memory usage" + ), ) # Observability (event batching)