diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index d1edb06e..0ebcc4c5 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -28,14 +28,18 @@ async def chat(message: str) -> str: """ import asyncio +import copy import functools import inspect +import json +import math import time -from collections.abc import Callable -from dataclasses import dataclass, field +from collections.abc import AsyncGenerator, Callable, Mapping +from contextlib import aclosing +from dataclasses import dataclass, field, fields, is_dataclass from typing import Any, TypeVar -from agent_control_models import Step +from agent_control_models import JSONValue, Step from agent_control import AgentControlClient from agent_control.evaluation import check_evaluation_with_local @@ -51,6 +55,7 @@ async def chat(message: str) -> str: logger = get_logger(__name__) F = TypeVar("F", bound=Callable[..., Any]) +_ASYNC_GEN_ASEND_ERROR = "Decorated @control() async generators do not support asend(non-None)." @dataclass @@ -210,6 +215,94 @@ def __init__( ) +@dataclass +class _BufferedStreamCapture: + """Buffer streamed chunks for post-check evaluation before replay.""" + + max_chunks: int + max_bytes: int + replay_chunks: list[Any] = field(default_factory=list) + normalized_chunks: list[JSONValue] = field(default_factory=list) + chunk_item_bytes: int = 0 + text_content_bytes: int = 0 + string_chunk_count: int = 0 + has_non_string_chunk: bool = False + + def add(self, chunk: Any) -> None: + next_chunk_count = len(self.replay_chunks) + 1 + if next_chunk_count > self.max_chunks: + raise RuntimeError( + "Buffered async generator output exceeded the configured chunk limit. " + "Failing closed." 
+ ) + + normalized_chunk = _normalize_json_value(chunk) + normalized_chunk_bytes = _json_value_size(normalized_chunk) + next_chunk_item_bytes = self.chunk_item_bytes + normalized_chunk_bytes + next_text_content_bytes = self.text_content_bytes + next_string_chunk_count = self.string_chunk_count + next_has_non_string_chunk = self.has_non_string_chunk + if isinstance(normalized_chunk, str): + next_text_content_bytes += normalized_chunk_bytes - 2 + next_string_chunk_count += 1 + else: + next_has_non_string_chunk = True + + next_payload_bytes = self._payload_size( + chunk_count=next_chunk_count, + chunk_item_bytes=next_chunk_item_bytes, + text_content_bytes=next_text_content_bytes, + string_chunk_count=next_string_chunk_count, + has_non_string_chunk=next_has_non_string_chunk, + ) + if next_payload_bytes > self.max_bytes: + raise RuntimeError( + "Buffered async generator output exceeded the configured size limit. " + "Failing closed." + ) + + self.replay_chunks.append(_freeze_replay_chunk(chunk)) + self.normalized_chunks.append(normalized_chunk) + self.chunk_item_bytes = next_chunk_item_bytes + self.text_content_bytes = next_text_content_bytes + self.string_chunk_count = next_string_chunk_count + self.has_non_string_chunk = next_has_non_string_chunk + + def _payload_size( + self, + *, + chunk_count: int, + chunk_item_bytes: int, + text_content_bytes: int, + string_chunk_count: int, + has_non_string_chunk: bool, + ) -> int: + chunk_list_size = 2 + chunk_item_bytes + max(0, chunk_count - 1) + joined_text_size = 0 if string_chunk_count == 0 else 2 + text_content_bytes + + if not has_non_string_chunk: + return joined_text_size + + payload_size = _json_value_size("chunks") + chunk_list_size + 3 + if string_chunk_count: + payload_size += _json_value_size("text") + joined_text_size + 2 + return payload_size + + def output_payload(self) -> JSONValue: + if not self.normalized_chunks: + return "" + + text_chunks = [chunk for chunk in self.normalized_chunks if isinstance(chunk, 
str)] + if len(text_chunks) == len(self.normalized_chunks): + return "".join(text_chunks) + + output: dict[str, JSONValue] = {"chunks": list(self.normalized_chunks)} + text_output = "".join(text_chunks) + if text_output: + output["text"] = text_output + return output + + def _get_current_agent() -> Any | None: """Get the current agent from agent_control module.""" try: @@ -224,6 +317,178 @@ def _get_server_url() -> str: return get_settings().url +def _copy_public_attributes(source: Any, target: Any) -> None: + """Copy public function attributes onto a wrapped callable.""" + for attr in dir(source): + if not attr.startswith("_") and attr not in ("__call__", "__wrapped__"): + try: + setattr(target, attr, getattr(source, attr)) + except (AttributeError, TypeError): + pass + + +def _safe_stringify(value: Any) -> str: + """Stringify a value without letting representation errors escape.""" + try: + return str(value) + except Exception: + return f"<unserializable {type(value).__name__}>" + + +def _freeze_replay_chunk(chunk: Any) -> Any: + """Freeze a replayed chunk so later mutations cannot bypass the post-check.""" + if isinstance(chunk, (str, bool, int, float, bytes, type(None))): + return chunk + + try: + return copy.deepcopy(chunk) + except Exception as exc: + raise RuntimeError( + "Buffered async generator produced a chunk that could not be safely " + "snapshotted for replay. Failing closed." 
+ ) from exc + + +def _validate_async_gen_send(value: Any) -> None: + """Reject interactive sends for decorated async generators.""" + if value is not None: + raise TypeError(_ASYNC_GEN_ASEND_ERROR) + + +def _json_value_size(value: JSONValue) -> int: + """Estimate normalized JSON payload size in bytes for buffered streams.""" + try: + return len( + json.dumps( + value, + ensure_ascii=False, + sort_keys=True, + allow_nan=False, + separators=(",", ":"), + ).encode("utf-8") + ) + except (TypeError, ValueError): + return len(_safe_stringify(value).encode("utf-8")) + + +def _normalize_json_value(value: Any, *, _seen: set[int] | None = None) -> JSONValue: + """Convert runtime values into JSON-safe data for evaluation payloads.""" + if isinstance(value, str) or value is None: + return value + + if isinstance(value, bool | int): + return value + + if isinstance(value, float): + return value if math.isfinite(value) else _safe_stringify(value) + + if isinstance(value, (bytes, bytearray)): + return bytes(value).decode("utf-8", errors="replace") + + seen = set() if _seen is None else _seen + value_id = id(value) + if value_id in seen: + return _safe_stringify(value) + + try: + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + seen.add(value_id) + try: + return _normalize_json_value(model_dump(mode="json"), _seen=seen) + except Exception: + return _safe_stringify(value) + finally: + seen.remove(value_id) + + dict_method = getattr(value, "dict", None) + if callable(dict_method) and hasattr(value, "__fields__"): + seen.add(value_id) + try: + return _normalize_json_value(dict_method(), _seen=seen) + except Exception: + return _safe_stringify(value) + finally: + seen.remove(value_id) + + if isinstance(value, Mapping): + seen.add(value_id) + try: + return { + str(key): _normalize_json_value(item, _seen=seen) + for key, item in value.items() + } + finally: + seen.remove(value_id) + + if isinstance(value, (list, tuple, set, frozenset)): + seen.add(value_id) 
+ try: + return [_normalize_json_value(item, _seen=seen) for item in value] + finally: + seen.remove(value_id) + + if is_dataclass(value) and not isinstance(value, type): + seen.add(value_id) + try: + return { + field_info.name: _normalize_json_value( + getattr(value, field_info.name), + _seen=seen, + ) + for field_info in fields(value) + } + finally: + seen.remove(value_id) + except Exception: + return _safe_stringify(value) + + return _safe_stringify(value) + + +def _normalized_output_value(output: Any) -> JSONValue | None: + """Return normalized output for a post-check payload.""" + if output is None: + return None + return _normalize_json_value(output) + + +def _build_control_context( + func: Callable, + args: tuple, + kwargs: dict, + step_name: str | None, +) -> tuple[ControlContext, list[dict[str, Any]] | None] | None: + """Create the shared control context for a protected invocation.""" + agent = _get_current_agent() + if agent is None: + return None + + controls = _get_server_controls() + + existing_trace_id = get_current_trace_id() + if existing_trace_id: + trace_id = existing_trace_id + span_id = _generate_span_id() + else: + trace_id, span_id = get_trace_and_span_ids() + + return ( + ControlContext( + agent_name=agent.agent_name, + server_url=_get_server_url(), + func=func, + args=args, + kwargs=kwargs, + trace_id=trace_id, + span_id=span_id, + start_time=time.perf_counter(), + step_name=step_name, + ), + controls, + ) + + async def _evaluate( agent_name: str, step: dict[str, Any], @@ -480,9 +745,7 @@ def _create_evaluation_payload( "type": "tool", "name": determined_name, "input": dict(bound.arguments), - "output": output if isinstance(output, (str, int, float, bool, dict, list)) else ( - None if output is None else str(output) - ), + "output": _normalized_output_value(output), } # This is an LLM step @@ -491,9 +754,7 @@ def _create_evaluation_payload( "type": "llm", "name": determined_name, "input": input_data, - "output": output if isinstance(output, 
(str, int, float, bool, dict, list)) else ( - None if output is None else str(output) - ), + "output": _normalized_output_value(output), } @@ -687,8 +948,8 @@ async def _execute_with_control( ControlSteerError: If any control triggers with "steer" action RuntimeError: If control evaluation fails unexpectedly """ - agent = _get_current_agent() - if agent is None: + context = _build_control_context(func, args, kwargs, step_name) + if context is None: logger.warning( "No agent initialized. Call agent_control.init() first. " "Running without protection." @@ -697,29 +958,7 @@ async def _execute_with_control( return await func(*args, **kwargs) return func(*args, **kwargs) - # Get cached controls for local evaluation support - controls = _get_server_controls() - - # Get trace context: inherit trace_id if set, always generate new span_id - # This allows multiple @control() calls to share the same trace but have unique spans - existing_trace_id = get_current_trace_id() - if existing_trace_id: - trace_id = existing_trace_id - span_id = _generate_span_id() # New span for this function - else: - trace_id, span_id = get_trace_and_span_ids() # New trace and span - - ctx = ControlContext( - agent_name=agent.agent_name, - server_url=_get_server_url(), - func=func, - args=args, - kwargs=kwargs, - trace_id=trace_id, - span_id=span_id, - start_time=time.perf_counter(), - step_name=step_name, - ) + ctx, controls = context ctx.log_start() try: @@ -768,6 +1007,16 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable 3. After function execution: Calls server with stage="post" - Server evaluates all matching "post" controls for the agent + Async generator note: + - When an agent is initialized, decorated async generators are buffered + before replay. + - The post-check runs on the full buffered output. + - Chunks are yielded to the caller only after the post-check passes. 
+ - stream_buffer_max_bytes limits the normalized post-check payload size, + not the in-memory size of replayed chunks. + - Decorated async generators are pull-only and do not preserve + interactive asend()/athrow() semantics against the source generator. + Example: import agent_control @@ -813,19 +1062,69 @@ def decorator(func: F) -> F: register(func, policy) + @functools.wraps(func) + async def async_gen_wrapper( + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Any, None]: + context = _build_control_context(func, args, kwargs, step_name) + if context is None: + logger.warning( + "No agent initialized. Call agent_control.init() first. " + "Running without protection." + ) + async with aclosing(func(*args, **kwargs)) as stream: + while True: + try: + chunk = await anext(stream) + except StopAsyncIteration: + return + + sent = yield chunk + _validate_async_gen_send(sent) + return + + ctx, controls = context + ctx.log_start() + span_open = True + settings = get_settings() + capture = _BufferedStreamCapture( + max_chunks=settings.stream_buffer_max_chunks, + max_bytes=settings.stream_buffer_max_bytes, + ) + + try: + await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) + + async with aclosing(func(*args, **kwargs)) as stream: + async for chunk in stream: + capture.add(chunk) + + await _run_control_check( + ctx, + "post", + ctx.post_payload(capture.output_payload()), + controls, + ) + + ctx.log_end() + span_open = False + + for chunk in capture.replay_chunks: + sent = yield chunk + _validate_async_gen_send(sent) + finally: + if span_open: + ctx.log_end() + @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: return await _execute_with_control( func, args, kwargs, is_async=True, step_name=step_name ) - # Copy over ALL attributes from the original function (important for LangChain tools) - for attr in dir(func): - if not attr.startswith('_') and attr not in ('__call__', '__wrapped__'): - try: - setattr(async_wrapper, attr, 
getattr(func, attr)) - except (AttributeError, TypeError): - pass + _copy_public_attributes(func, async_gen_wrapper) + _copy_public_attributes(func, async_wrapper) @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: @@ -833,6 +1132,10 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: _execute_with_control(func, args, kwargs, is_async=False, step_name=step_name) ) + _copy_public_attributes(func, sync_wrapper) + + if inspect.isasyncgenfunction(func): + return async_gen_wrapper # type: ignore if inspect.iscoroutinefunction(func): return async_wrapper # type: ignore return sync_wrapper # type: ignore diff --git a/sdks/python/src/agent_control/settings.py b/sdks/python/src/agent_control/settings.py index f06398d4..bf1fe558 100644 --- a/sdks/python/src/agent_control/settings.py +++ b/sdks/python/src/agent_control/settings.py @@ -55,6 +55,21 @@ class SDKSettings(BaseSettings): description="API key for server authentication", ) + # Buffered async-generator control limits + stream_buffer_max_chunks: int = Field( + default=10000, + ge=1, + description="Maximum streamed chunks buffered by @control() before failing closed", + ) + stream_buffer_max_bytes: int = Field( + default=5_000_000, + ge=1, + description=( + "Maximum normalized post-check payload bytes buffered by @control() " + "before failing closed; does not bound replay-memory usage" + ), + ) + # Observability (event batching) observability_enabled: bool = Field( default=True, diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index d13ec46d..077f2270 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -1,10 +1,13 @@ """Tests for @control decorator.""" +import inspect +from dataclasses import dataclass from unittest.mock import MagicMock, patch import pytest from agent_control.control_decorators import ControlViolationError, ControlSteerError, control +from 
agent_control.settings import get_settings # ============================================================================= @@ -316,6 +319,76 @@ async def chat(message: str) -> str: assert "output" in captured_step assert "Generated response" in captured_step["output"] + @pytest.mark.asyncio + async def test_post_check_serializer_failure_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that serializer hook failures do not escape the post-check path.""" + captured_step = {} + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + captured_step.update(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def chat(message: str) -> _ExplodingModelDumpChunk: + return _ExplodingModelDumpChunk() + + result = await chat("Hello!") + + assert isinstance(result, _ExplodingModelDumpChunk) + assert captured_step["output"] == "exploding-model-dump" + + @pytest.mark.asyncio + async def test_post_check_non_finite_float_is_stringified( + self, + mock_agent, + mock_safe_response, + ): + """Test that non-finite floats are normalized to JSON-safe strings.""" + captured_step = {} + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + captured_step.update(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def chat(message: str) -> float: + return float("inf") + + result = await chat("Hello!") + + assert result == float("inf") + assert 
captured_step["output"] == "inf" + # ============================================================================= # INPUT EXTRACTION TESTS @@ -788,6 +861,677 @@ async def test_func(): assert exc_info.value.steering_context == "Default evaluator message" +# ============================================================================= +# ASYNC GENERATOR TESTS +# ============================================================================= + + +@dataclass +class _StructuredChunk: + value: str + + +class _RecursiveModelDumpChunk: + def model_dump(self, mode: str = "json") -> dict[str, object]: + _ = mode + return {"self": self} + + def __str__(self) -> str: + return "recursive-model-dump" + + +class _ExplodingModelDumpChunk: + def model_dump(self, mode: str = "json") -> dict[str, object]: + _ = mode + raise RuntimeError("model_dump exploded") + + def __str__(self) -> str: + return "exploding-model-dump" + + +class _PlainObjectChunk: + def __init__(self) -> None: + self.secret = "top-secret" + + def __str__(self) -> str: + return "plain-object" + + +class _DuckDictChunk: + def dict(self) -> dict[str, str]: + return {"secret": "top-secret"} + + def __str__(self) -> str: + return "duck-dict" + + +class _UndeepcopyableChunk: + def __deepcopy__(self, memo: dict[int, object]) -> "_UndeepcopyableChunk": + _ = memo + raise RuntimeError("cannot deepcopy") + + def __str__(self) -> str: + return "undeepcopyable" + + +class TestAsyncGeneratorControl: + """Tests for buffered async-generator support in @control().""" + + @pytest.mark.asyncio + async def test_preserves_async_generator_identity(self, mock_agent, mock_safe_response): + """Test that decorating an async generator preserves its async-gen type.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + yield "hello" + + assert 
inspect.isasyncgenfunction(stream) + + @pytest.mark.asyncio + async def test_buffers_stream_until_post_check_passes(self, mock_agent, mock_safe_response): + """Test that async-generator output is buffered before replay.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + collected = [] + async for chunk in stream("test"): + collected.append(chunk) + break + + assert collected == ["chunk1"] + assert call_stages == ["pre", "post"] + + @pytest.mark.asyncio + async def test_post_check_receives_joined_text_output_for_string_chunks( + self, + mock_agent, + mock_safe_response, + ): + """Test that string chunks are joined for the buffered post-check payload.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "hello " + yield "world" + + chunks = [chunk async for chunk in stream("test")] + assert chunks == ["hello ", "world"] + assert post_steps[0]["output"] == "hello world" + + @pytest.mark.asyncio + async def test_structured_chunks_are_normalized_without_lossy_repr( + self, + mock_agent, + mock_safe_response, + ): + """Test that structured chunks are normalized into JSON-safe post 
payloads.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield {"kind": "json"} + yield _StructuredChunk(value="structured") + yield b"tail" + + chunks = [chunk async for chunk in stream("test")] + assert chunks == [{"kind": "json"}, _StructuredChunk(value="structured"), b"tail"] + assert post_steps[0]["output"] == { + "chunks": [ + {"kind": "json"}, + {"value": "structured"}, + "tail", + ], + "text": "tail", + } + + @pytest.mark.asyncio + async def test_model_dump_cycle_is_guarded_during_chunk_normalization( + self, + mock_agent, + mock_safe_response, + ): + """Test that self-referential model_dump payloads do not recurse forever.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _RecursiveModelDumpChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == { + "chunks": [{"self": "recursive-model-dump"}] + } + + @pytest.mark.asyncio + async def test_chunk_serializer_failure_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that serializer hook failures on streamed chunks fall back to strings.""" + post_steps = [] 
+ + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _ExplodingModelDumpChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "exploding-model-dump" + + @pytest.mark.asyncio + async def test_plain_object_chunk_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that plain objects are stringified instead of expanding __dict__.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _PlainObjectChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "plain-object" + + @pytest.mark.asyncio + async def test_duck_typed_dict_chunk_falls_back_to_string( + self, + mock_agent, + mock_safe_response, + ): + """Test that arbitrary objects with dict() do not take the Pydantic v1 path.""" + post_steps = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + if stage == "post": + post_steps.append(step) + return mock_safe_response + + with 
patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _DuckDictChunk() + + chunks = [chunk async for chunk in stream("test")] + assert len(chunks) == 1 + assert post_steps[0]["output"] == "duck-dict" + + @pytest.mark.asyncio + async def test_post_check_block_happens_before_any_chunk_is_replayed( + self, + mock_agent, + mock_safe_response, + mock_unsafe_response, + ): + """Test that a post-stage deny blocks the buffered stream before replay starts.""" + call_count = 0 + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_safe_response + return mock_unsafe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + + collected = [] + with pytest.raises(ControlViolationError): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + + @pytest.mark.asyncio + async def test_stream_buffer_chunk_limit_fails_closed_before_replay( + self, + mock_agent, + mock_safe_response, + monkeypatch, + ): + """Test that exceeding the configured chunk cap blocks buffered replay.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + monkeypatch.setattr(get_settings(), "stream_buffer_max_chunks", 2) + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + 
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + yield "chunk3" + + collected = [] + with pytest.raises(RuntimeError, match="chunk limit"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + + @pytest.mark.asyncio + async def test_stream_buffer_byte_limit_accounts_for_mixed_text_duplication( + self, + mock_agent, + mock_safe_response, + monkeypatch, + ): + """Test that mixed streams count duplicated text toward the byte cap.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + monkeypatch.setattr(get_settings(), "stream_buffer_max_bytes", 15) + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "a" + yield {"kind": "json"} + yield "b" + + collected = [] + with pytest.raises(RuntimeError, match="size limit"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + + @pytest.mark.asyncio + async def test_generator_failure_skips_post_check_and_replays_nothing( + self, + mock_agent, + mock_safe_response, + ): + """Test that a source stream failure propagates before any replay or post-check.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + 
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield "chunk1" + raise RuntimeError("stream failed") + + collected = [] + with pytest.raises(RuntimeError, match="stream failed"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + + @pytest.mark.asyncio + async def test_mutable_chunk_is_replayed_from_captured_snapshot( + self, + mock_agent, + mock_safe_response, + ): + """Test that later chunk mutation cannot change buffered replay output.""" + original_chunks = [] + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + chunk = {"text": "safe"} + original_chunks.append(chunk) + yield chunk + chunk["text"] = "unsafe" + + chunks = [chunk async for chunk in stream("test")] + + assert chunks == [{"text": "safe"}] + assert chunks[0] is not original_chunks[0] + assert original_chunks[0] == {"text": "unsafe"} + + @pytest.mark.asyncio + async def test_stream_span_ends_before_buffered_replay_begins( + self, + mock_agent, + mock_safe_response, + ): + """Test that control tracing ends before replay is yielded to the caller.""" + end_events = [] + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response), \ + patch( + "agent_control.control_decorators.log_span_end", + side_effect=lambda *args, **kwargs: end_events.append("ended"), + ): + + @control() + async def stream(message: str): + yield "chunk1" + yield "chunk2" + + agen = stream("test") + first = await anext(agen) + rest = [chunk async for chunk in agen] + + assert first == "chunk1" + assert rest == ["chunk2"] + assert end_events == ["ended"] + + @pytest.mark.asyncio + 
async def test_stream_is_closed_when_buffering_fails( + self, + mock_agent, + mock_safe_response, + monkeypatch, + ): + """Test that buffering failures still close the source async generator.""" + closed = False + + async def stream(message: str): + nonlocal closed + try: + yield "chunk1" + yield "chunk2" + finally: + closed = True + + monkeypatch.setattr(get_settings(), "stream_buffer_max_chunks", 1) + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + wrapped = control()(stream) + + with pytest.raises(RuntimeError, match="chunk limit"): + async for _ in wrapped("test"): + pass + + assert closed is True + + @pytest.mark.asyncio + async def test_non_none_asend_is_rejected_without_agent(self): + """Test that no-agent async generators reject interactive sends consistently.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=None): + + @control() + async def stream(message: str): + received = yield "chunk1" + yield f"received={received}" + + agen = stream("hello") + first = await agen.__anext__() + assert first == "chunk1" + + with pytest.raises(TypeError, match="asend\\(non-None\\)"): + await agen.asend("interactive-input") + + @pytest.mark.asyncio + async def test_async_generator_without_agent_passthrough(self): + """Test that async generators pass through unchanged without an initialized agent.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=None): + + @control() + async def stream(message: str): + yield "a" + yield "b" + + chunks = [chunk async for chunk in stream("hello")] + assert chunks == ["a", "b"] + + @pytest.mark.asyncio + async def test_asend_non_none_is_rejected_for_buffered_replay( + self, + mock_agent, + mock_safe_response, + ): + """Test that buffered replay rejects interactive sends.""" + with 
patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + @control() + async def stream(message: str): + received = yield "chunk1" + yield f"received={received}" + + agen = stream("hello") + first = await agen.__anext__() + assert first == "chunk1" + + with pytest.raises(TypeError, match="asend\\(non-None\\)"): + await agen.asend("interactive-input") + + @pytest.mark.asyncio + async def test_non_deepcopyable_chunk_fails_closed_before_replay( + self, + mock_agent, + mock_safe_response, + ): + """Test that replay snapshot failures raise instead of substituting types.""" + call_stages = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + call_stages.append(stage) + return mock_safe_response + + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate): + + @control() + async def stream(message: str): + yield _UndeepcopyableChunk() + + collected = [] + with pytest.raises(RuntimeError, match="safely snapshotted for replay"): + async for chunk in stream("test"): + collected.append(chunk) + + assert collected == [] + assert call_stages == ["pre"] + + def test_sync_wrapper_copies_public_attributes(self, mock_agent, mock_safe_response): + """Test that sync wrappers preserve custom public attributes.""" + with patch("agent_control.control_decorators._get_current_agent", return_value=mock_agent), \ + patch("agent_control.control_decorators._evaluate", return_value=mock_safe_response): + + def process(input: str) -> str: + return input.upper() + + process.description = "Uppercases the input" + + wrapped = control()(process) + + assert wrapped.description == "Uppercases the input" + + # 
============================================================================= # EXCEPTION HANDLING TESTS # =============================================================================