From 58eb627341fdb037753b05db794f32bc3454db09 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Mon, 23 Mar 2026 13:01:44 -0700 Subject: [PATCH 1/9] Add provider agnostic tracing --- sdks/python/src/agent_control/__init__.py | 14 +++++ .../src/agent_control/telemetry/__init__.py | 27 +++++++++ .../src/agent_control/telemetry/event_sink.py | 33 +++++++++++ .../agent_control/telemetry/trace_context.py | 53 +++++++++++++++++ sdks/python/src/agent_control/tracing.py | 22 ++++++- sdks/python/tests/test_event_sink.py | 59 +++++++++++++++++++ sdks/python/tests/test_trace_context.py | 48 +++++++++++++++ sdks/python/tests/test_tracing.py | 46 +++++++++++++++ 8 files changed, 300 insertions(+), 2 deletions(-) create mode 100644 sdks/python/src/agent_control/telemetry/__init__.py create mode 100644 sdks/python/src/agent_control/telemetry/event_sink.py create mode 100644 sdks/python/src/agent_control/telemetry/trace_context.py create mode 100644 sdks/python/tests/test_event_sink.py create mode 100644 sdks/python/tests/test_trace_context.py diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 33658fb4..364e76bc 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -107,6 +107,14 @@ async def handle_input(user_message: str) -> str: is_otel_available, with_trace, ) +from .telemetry import ( + clear_control_event_sink, + clear_trace_context_provider, + emit_control_events, + get_trace_context_from_provider, + set_control_event_sink, + set_trace_context_provider, +) from .validation import ensure_agent_name # Module logger @@ -1305,6 +1313,12 @@ async def main(): "get_current_span_id", "with_trace", "is_otel_available", + "set_trace_context_provider", + "get_trace_context_from_provider", + "clear_trace_context_provider", + "set_control_event_sink", + "emit_control_events", + "clear_control_event_sink", # Observability "init_observability", "add_event", diff --git 
a/sdks/python/src/agent_control/telemetry/__init__.py b/sdks/python/src/agent_control/telemetry/__init__.py new file mode 100644 index 00000000..8933553d --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/__init__.py @@ -0,0 +1,27 @@ +"""Telemetry interfaces for provider-agnostic tracing and event emission.""" + +from .event_sink import ( + ControlEventSink, + clear_control_event_sink, + emit_control_events, + set_control_event_sink, +) +from .trace_context import ( + TraceContext, + TraceContextProvider, + clear_trace_context_provider, + get_trace_context_from_provider, + set_trace_context_provider, +) + +__all__ = [ + "ControlEventSink", + "TraceContext", + "TraceContextProvider", + "clear_control_event_sink", + "clear_trace_context_provider", + "emit_control_events", + "get_trace_context_from_provider", + "set_control_event_sink", + "set_trace_context_provider", +] diff --git a/sdks/python/src/agent_control/telemetry/event_sink.py b/sdks/python/src/agent_control/telemetry/event_sink.py new file mode 100644 index 00000000..b36e9c13 --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/event_sink.py @@ -0,0 +1,33 @@ +"""Provider-agnostic sink for merged control execution events.""" + +from collections.abc import Callable + +from agent_control_models import ControlExecutionEvent + +ControlEventSink = Callable[[list[ControlExecutionEvent]], None] + +_control_event_sink: ControlEventSink | None = None + + +def set_control_event_sink(sink: ControlEventSink | None) -> None: + """Register a sink for merged control execution events.""" + global _control_event_sink + _control_event_sink = sink + + +def emit_control_events(events: list[ControlExecutionEvent]) -> None: + """Emit merged control execution events to the registered sink.""" + if not events or _control_event_sink is None: + return + + try: + _control_event_sink(events) + except Exception: + # Sink failures should not break control evaluation. 
+ pass + + +def clear_control_event_sink() -> None: + """Clear the registered control event sink.""" + global _control_event_sink + _control_event_sink = None diff --git a/sdks/python/src/agent_control/telemetry/trace_context.py b/sdks/python/src/agent_control/telemetry/trace_context.py new file mode 100644 index 00000000..82c4326e --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/trace_context.py @@ -0,0 +1,53 @@ +"""Provider-agnostic trace context interface for external tracing systems.""" + +from collections.abc import Callable +from typing import TypedDict + + +class TraceContext(TypedDict): + """Resolved trace context for a control evaluation.""" + + trace_id: str + span_id: str + + +TraceContextProvider = Callable[[], TraceContext | None] + +_trace_context_provider: TraceContextProvider | None = None + + +def set_trace_context_provider(provider: TraceContextProvider | None) -> None: + """Register a provider that returns the current trace context.""" + global _trace_context_provider + _trace_context_provider = provider + + +def get_trace_context_from_provider() -> TraceContext | None: + """Return trace context from the registered provider, if any.""" + if _trace_context_provider is None: + return None + + try: + trace_context = _trace_context_provider() + except Exception: + # Provider failures should not break control evaluation. 
+ return None + + if trace_context is None: + return None + + trace_id = trace_context.get("trace_id") + span_id = trace_context.get("span_id") + if not isinstance(trace_id, str) or not isinstance(span_id, str): + return None + + return { + "trace_id": trace_id, + "span_id": span_id, + } + + +def clear_trace_context_provider() -> None: + """Clear the registered trace context provider.""" + global _trace_context_provider + _trace_context_provider = None diff --git a/sdks/python/src/agent_control/tracing.py b/sdks/python/src/agent_control/tracing.py index 473b5633..47696b15 100644 --- a/sdks/python/src/agent_control/tracing.py +++ b/sdks/python/src/agent_control/tracing.py @@ -31,6 +31,8 @@ from contextlib import contextmanager from contextvars import ContextVar, Token +from .telemetry.trace_context import get_trace_context_from_provider + # Context variables for trace/span propagation _trace_id_var: ContextVar[str | None] = ContextVar("trace_id", default=None) _span_id_var: ContextVar[str | None] = ContextVar("span_id", default=None) @@ -94,8 +96,9 @@ def get_trace_and_span_ids() -> tuple[str, str]: Priority: 1. Context variable (set by with_trace or explicitly) - 2. OpenTelemetry context (if OTEL is installed and active) - 3. Generate new OTEL-compatible IDs + 2. External provider + 3. OpenTelemetry context (if OTEL is installed and active) + 4. 
Generate new OTEL-compatible IDs Returns: Tuple of (trace_id, span_id) - both are hex strings @@ -114,6 +117,11 @@ def get_trace_and_span_ids() -> tuple[str, str]: if trace_id is not None and span_id is not None: return trace_id, span_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["trace_id"], trace_context["span_id"] + # Try OpenTelemetry context otel_trace_id, otel_span_id = _get_otel_ids() @@ -136,6 +144,11 @@ def get_current_trace_id() -> str | None: if trace_id is not None: return trace_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["trace_id"] + # Try OpenTelemetry otel_trace_id, _ = _get_otel_ids() return otel_trace_id @@ -153,6 +166,11 @@ def get_current_span_id() -> str | None: if span_id is not None: return span_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["span_id"] + # Try OpenTelemetry _, otel_span_id = _get_otel_ids() return otel_span_id diff --git a/sdks/python/tests/test_event_sink.py b/sdks/python/tests/test_event_sink.py new file mode 100644 index 00000000..8013f4d6 --- /dev/null +++ b/sdks/python/tests/test_event_sink.py @@ -0,0 +1,59 @@ +"""Tests for the telemetry merged control event sink interface.""" + +from datetime import UTC, datetime + +from agent_control.telemetry.event_sink import ( + clear_control_event_sink, + emit_control_events, + set_control_event_sink, +) +from agent_control_models import ControlExecutionEvent + + +def _event() -> ControlExecutionEvent: + return ControlExecutionEvent( + control_execution_id="ce-1", + trace_id="a" * 32, + span_id="b" * 16, + agent_name="test-agent", + control_id=1, + control_name="pii_check", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=0.95, + timestamp=datetime.now(UTC), + metadata={}, + ) + + +def teardown_function() 
-> None: + clear_control_event_sink() + + +def test_emit_control_events_calls_registered_sink() -> None: + seen: list[list[ControlExecutionEvent]] = [] + + def _sink(events: list[ControlExecutionEvent]) -> None: + seen.append(events) + + event = _event() + set_control_event_sink(_sink) + + emit_control_events([event]) + + assert seen == [[event]] + + +def test_emit_control_events_noops_without_sink() -> None: + emit_control_events([_event()]) + + +def test_emit_control_events_swallows_sink_failures() -> None: + def _sink(_events: list[ControlExecutionEvent]) -> None: + raise RuntimeError("boom") + + set_control_event_sink(_sink) + + emit_control_events([_event()]) diff --git a/sdks/python/tests/test_trace_context.py b/sdks/python/tests/test_trace_context.py new file mode 100644 index 00000000..9df234c6 --- /dev/null +++ b/sdks/python/tests/test_trace_context.py @@ -0,0 +1,48 @@ +"""Tests for the telemetry trace context provider interface.""" + +from agent_control.telemetry.trace_context import ( + clear_trace_context_provider, + get_trace_context_from_provider, + set_trace_context_provider, +) + + +def teardown_function() -> None: + clear_trace_context_provider() + + +def test_get_trace_context_from_provider_returns_registered_context() -> None: + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + ) + + assert get_trace_context_from_provider() == { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + + +def test_get_trace_context_from_provider_returns_none_when_unset() -> None: + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_swallows_provider_failures() -> None: + def _raising_provider(): + raise RuntimeError("boom") + + set_trace_context_provider(_raising_provider) + + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_returns_none_for_invalid_shape() -> None: + set_trace_context_provider( # type: ignore[arg-type] + lambda: { + 
"trace_id": "a" * 32, + } + ) + + assert get_trace_context_from_provider() is None diff --git a/sdks/python/tests/test_tracing.py b/sdks/python/tests/test_tracing.py index 175cb7c4..97397b8d 100644 --- a/sdks/python/tests/test_tracing.py +++ b/sdks/python/tests/test_tracing.py @@ -2,6 +2,7 @@ import pytest +from agent_control.telemetry.trace_context import clear_trace_context_provider, set_trace_context_provider from agent_control.tracing import ( _generate_span_id, _generate_trace_id, @@ -17,6 +18,10 @@ ) +def teardown_function() -> None: + clear_trace_context_provider() + + class TestIdGeneration: """Tests for trace and span ID generation.""" @@ -132,6 +137,30 @@ def test_get_current_ids_without_context(self): assert trace_id is None or isinstance(trace_id, str) assert span_id is None or isinstance(span_id, str) + def test_get_current_trace_id_uses_provider(self): + """Test that get_current_trace_id uses external provider before OTEL fallback.""" + expected_trace = "a" * 32 + set_trace_context_provider( + lambda: { + "trace_id": expected_trace, + "span_id": "b" * 16, + } + ) + + assert get_current_trace_id() == expected_trace + + def test_get_current_span_id_uses_provider(self): + """Test that get_current_span_id uses external provider before OTEL fallback.""" + expected_span = "b" * 16 + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": expected_span, + } + ) + + assert get_current_span_id() == expected_span + class TestWithTraceContextManager: """Tests for the with_trace context manager.""" @@ -237,6 +266,23 @@ def test_get_trace_and_span_ids_uses_context(self): assert trace_id == expected_trace assert span_id == expected_span + def test_get_trace_and_span_ids_uses_provider_before_otel(self): + """Test that an external provider is checked before OTEL fallback.""" + expected_trace = "c" * 32 + expected_span = "d" * 16 + + set_trace_context_provider( + lambda: { + "trace_id": expected_trace, + "span_id": expected_span, + } + ) + + 
trace_id, span_id = get_trace_and_span_ids() + + assert trace_id == expected_trace + assert span_id == expected_span + class TestOtelAvailability: """Tests for OpenTelemetry availability detection.""" From 3d39706d53742bb7185ce751d953db4a8be0fa76 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Mon, 23 Mar 2026 13:38:26 -0700 Subject: [PATCH 2/9] fix linting --- sdks/python/src/agent_control/__init__.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 364e76bc..c03a3f66 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -78,11 +78,7 @@ async def handle_input(user_message: str) -> str: from ._control_registry import ( clear as clear_step_registry, ) - -# Import client and operations modules from .client import AgentControlClient - -# Import control decorator from .control_decorators import ControlSteerError, ControlViolationError, control from .evaluation import check_evaluation_with_local, evaluate_controls from .observability import ( @@ -98,15 +94,6 @@ async def handle_input(user_message: str) -> str: shutdown_observability, sync_shutdown_observability, ) - -# Import tracing and observability -from .tracing import ( - get_current_span_id, - get_current_trace_id, - get_trace_and_span_ids, - is_otel_available, - with_trace, -) from .telemetry import ( clear_control_event_sink, clear_trace_context_provider, @@ -115,6 +102,13 @@ async def handle_input(user_message: str) -> str: set_control_event_sink, set_trace_context_provider, ) +from .tracing import ( + get_current_span_id, + get_current_trace_id, + get_trace_and_span_ids, + is_otel_available, + with_trace, +) from .validation import ensure_agent_name # Module logger From c3241c1c954b09e0e25f412c7396c5d3fdf31c6e Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Mon, 23 Mar 2026 13:50:40 -0700 Subject: [PATCH 3/9] add test 
--- sdks/python/tests/test_trace_context.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdks/python/tests/test_trace_context.py b/sdks/python/tests/test_trace_context.py index 9df234c6..e1305711 100644 --- a/sdks/python/tests/test_trace_context.py +++ b/sdks/python/tests/test_trace_context.py @@ -29,6 +29,12 @@ def test_get_trace_context_from_provider_returns_none_when_unset() -> None: assert get_trace_context_from_provider() is None +def test_get_trace_context_from_provider_returns_none_when_provider_returns_none() -> None: + set_trace_context_provider(lambda: None) + + assert get_trace_context_from_provider() is None + + def test_get_trace_context_from_provider_swallows_provider_failures() -> None: def _raising_provider(): raise RuntimeError("boom") From 55e57b594520889957d2d179dd76f05c375ac369 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Thu, 26 Mar 2026 14:42:22 -0700 Subject: [PATCH 4/9] draft --- models/src/agent_control_models/evaluation.py | 5 + sdks/python/src/agent_control/__init__.py | 2 + sdks/python/src/agent_control/evaluation.py | 115 ++++--- .../src/agent_control/telemetry/__init__.py | 2 + .../src/agent_control/telemetry/event_sink.py | 5 + .../tests/test_observability_updates.py | 303 +++++++++++++----- .../endpoints/evaluation.py | 68 ++-- .../tests/test_evaluation_error_handling.py | 61 +++- 8 files changed, 400 insertions(+), 161 deletions(-) diff --git a/models/src/agent_control_models/evaluation.py b/models/src/agent_control_models/evaluation.py index 07ab4810..8b0dd9c2 100644 --- a/models/src/agent_control_models/evaluation.py +++ b/models/src/agent_control_models/evaluation.py @@ -6,6 +6,7 @@ from .agent import AGENT_NAME_MIN_LENGTH, AGENT_NAME_PATTERN, Step, normalize_agent_name from .base import BaseModel from .controls import ControlMatch +from .observability import ControlExecutionEvent class EvaluationRequest(BaseModel): @@ -127,6 +128,10 @@ class EvaluationResponse(BaseModel): default=None, description="List of 
controls that were evaluated but did not match (if any)", ) + events: list[ControlExecutionEvent] | None = Field( + default=None, + description="Control execution events produced during evaluation (if any)", + ) class EvaluationResult(EvaluationResponse): diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index c03a3f66..d57ca0d9 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -99,6 +99,7 @@ async def handle_input(user_message: str) -> str: clear_trace_context_provider, emit_control_events, get_trace_context_from_provider, + has_control_event_sink, set_control_event_sink, set_trace_context_provider, ) @@ -1311,6 +1312,7 @@ async def main(): "get_trace_context_from_provider", "clear_trace_context_provider", "set_control_event_sink", + "has_control_event_sink", "emit_control_events", "clear_control_event_sink", # Observability diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 55f5efc1..70695b92 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -20,6 +20,7 @@ from ._state import state from .client import AgentControlClient from .observability import add_event, get_logger, is_observability_enabled +from .telemetry import emit_control_events, has_control_event_sink from .validation import ensure_agent_name _logger = get_logger(__name__) @@ -53,15 +54,15 @@ def _map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: return "tool_call" if step_type == "tool" else "llm_call" -def _emit_local_events( +def _build_local_events( local_result: "EvaluationResponse", request: "EvaluationRequest", local_controls: list["_ControlAdapter"], trace_id: str | None, span_id: str | None, agent_name: str | None, -) -> None: - """Emit observability events for locally-evaluated controls. 
+) -> list[ControlExecutionEvent]: + """Build observability events for locally-evaluated controls. Mirrors the server's _emit_observability_events() so that SDK-evaluated controls are visible in the observability pipeline. @@ -69,11 +70,9 @@ def _emit_local_events( When trace_id/span_id are missing, fallback all-zero IDs are used so events are still recorded (but clearly marked as uncorrelated). - Only runs when observability is enabled. + Returns the list of built events without delivering them. Fallback IDs are + applied when trace context is missing, so those events are uncorrelated. """ - if not is_observability_enabled(): - return - global _trace_warning_logged # noqa: PLW0603 if not trace_id or not span_id: if not _trace_warning_logged: @@ -90,8 +89,9 @@ def _emit_local_events( control_lookup = {c.id: c for c in local_controls} now = datetime.now(UTC) resolved_agent_name = agent_name or request.agent_name + events: list[ControlExecutionEvent] = [] - def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: + def _append_matches(matches: list[ControlMatch] | None, matched: bool) -> None: if not matches: return for match in matches: @@ -104,7 +104,7 @@ def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: ctrl.control ) event_metadata.update(identity_metadata) - add_event( + events.append( ControlExecutionEvent( control_execution_id=match.control_execution_id, trace_id=trace_id, @@ -125,9 +125,19 @@ def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: ) ) - _emit_matches(local_result.matches, matched=True) - _emit_matches(local_result.errors, matched=False) - _emit_matches(local_result.non_matches, matched=False) + _append_matches(local_result.matches, matched=True) + _append_matches(local_result.errors, matched=False) + _append_matches(local_result.non_matches, matched=False) + return events + + +def _deliver_oss_events(events: list[ControlExecutionEvent]) -> None: + """Send events through 
the existing OSS SDK observability path.""" + if not is_observability_enabled(): + return + + for event in events: + add_event(event) async def check_evaluation( @@ -236,6 +246,10 @@ def _merge_results( if local_result.non_matches or server_result.non_matches: non_matches = (local_result.non_matches or []) + (server_result.non_matches or []) + events: list[ControlExecutionEvent] | None = None + if local_result.events or server_result.events: + events = (local_result.events or []) + (server_result.events or []) + reason = None if local_result.reason and server_result.reason: reason = f"{local_result.reason}; {server_result.reason}" @@ -251,6 +265,7 @@ def _merge_results( matches=matches if matches else None, errors=errors if errors else None, non_matches=non_matches if non_matches else None, + events=events if events else None, ) @@ -366,16 +381,14 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if not parse_errors: return result combined_errors = (result.errors or []) + parse_errors - return EvaluationResult( - is_safe=result.is_safe, - confidence=result.confidence, - reason=result.reason, - matches=result.matches, - errors=combined_errors, - non_matches=result.non_matches, + return result.model_copy( + update={ + "errors": combined_errors, + } ) local_result: EvaluationResponse | None = None + merged_emission_enabled = has_control_event_sink() applicable_local_controls = _get_applicable_controls( local_controls, request, @@ -385,7 +398,7 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: engine = ControlEngine(applicable_local_controls, context="sdk") local_result = await engine.process(request) - _emit_local_events( + local_events = _build_local_events( local_result, request, applicable_local_controls, @@ -393,18 +406,16 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: span_id, agent_name=event_agent_name, ) + local_result = local_result.model_copy(update={"events": local_events or None}) + + 
if not merged_emission_enabled: + _deliver_oss_events(local_events) if not local_result.is_safe: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) - ) + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if merged_emission_enabled and result.events: + emit_control_events(result.events) + return result if _has_applicable_prefiltered_server_controls(server_control_payloads, request): request_payload = request.model_dump(mode="json", exclude_none=True) @@ -413,6 +424,8 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: headers["X-Trace-Id"] = trace_id if span_id: headers["X-Span-Id"] = span_id + if merged_emission_enabled: + headers["X-Agent-Control-Merge-Events"] = "true" response = await client.http_client.post( "/api/v1/evaluation", @@ -423,32 +436,28 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: server_result = EvaluationResponse.model_validate(response.json()) if local_result is not None: - return _with_parse_errors(_merge_results(local_result, server_result)) - - return _with_parse_errors( - EvaluationResult( - is_safe=server_result.is_safe, - confidence=server_result.confidence, - reason=server_result.reason, - matches=server_result.matches, - errors=server_result.errors, - non_matches=server_result.non_matches, - ) - ) + result = _with_parse_errors(_merge_results(local_result, server_result)) + if merged_emission_enabled and result.events: + emit_control_events(result.events) + return result - if local_result is not None: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) 
+ result = _with_parse_errors( + EvaluationResult.model_validate(server_result.model_dump()) ) + if merged_emission_enabled and result.events: + emit_control_events(result.events) + return result - return _with_parse_errors(EvaluationResult(is_safe=True, confidence=1.0)) + if local_result is not None: + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if merged_emission_enabled and result.events: + emit_control_events(result.events) + return result + + result = _with_parse_errors(EvaluationResult(is_safe=True, confidence=1.0)) + if merged_emission_enabled and result.events: + emit_control_events(result.events) + return result async def evaluate_controls( diff --git a/sdks/python/src/agent_control/telemetry/__init__.py b/sdks/python/src/agent_control/telemetry/__init__.py index 8933553d..6e40b8a2 100644 --- a/sdks/python/src/agent_control/telemetry/__init__.py +++ b/sdks/python/src/agent_control/telemetry/__init__.py @@ -4,6 +4,7 @@ ControlEventSink, clear_control_event_sink, emit_control_events, + has_control_event_sink, set_control_event_sink, ) from .trace_context import ( @@ -22,6 +23,7 @@ "clear_trace_context_provider", "emit_control_events", "get_trace_context_from_provider", + "has_control_event_sink", "set_control_event_sink", "set_trace_context_provider", ] diff --git a/sdks/python/src/agent_control/telemetry/event_sink.py b/sdks/python/src/agent_control/telemetry/event_sink.py index b36e9c13..19062604 100644 --- a/sdks/python/src/agent_control/telemetry/event_sink.py +++ b/sdks/python/src/agent_control/telemetry/event_sink.py @@ -27,6 +27,11 @@ def emit_control_events(events: list[ControlExecutionEvent]) -> None: pass +def has_control_event_sink() -> bool: + """Return whether a merged control event sink is currently registered.""" + return _control_event_sink is not None + + def clear_control_event_sink() -> None: """Clear the registered control event sink.""" global _control_event_sink diff --git 
a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index cdaaa6ce..b3e654b8 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -6,7 +6,8 @@ from agent_control import evaluation from agent_control.evaluation import ( _ControlAdapter, - _emit_local_events, + _build_local_events, + _deliver_oss_events, _map_applies_to, _merge_results, ) @@ -103,14 +104,49 @@ def test_still_combines_matches_and_errors(self): assert len(result.matches) == 2 assert len(result.errors) == 1 + def test_combines_events(self): + from agent_control_models import ControlExecutionEvent + + ev1 = ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=1, + control_name="ctrl-1", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=1.0, + ) + ev2 = ControlExecutionEvent( + trace_id="c" * 32, + span_id="d" * 16, + agent_name="agent-000000000001", + control_id=2, + control_name="ctrl-2", + check_stage="pre", + applies_to="llm_call", + action="deny", + matched=True, + confidence=1.0, + ) + + local = self._make_response(events=[ev1]) + server = self._make_response(events=[ev2]) + + result = _merge_results(local, server) + assert result.events is not None + assert [event.control_id for event in result.events] == [1, 2] + # ============================================================================= -# _emit_local_events tests +# local event build/delivery tests # ============================================================================= class TestEmitLocalEvents: - """Tests for _emit_local_events helper.""" + """Tests for local event build/delivery helpers.""" def _make_control_adapter(self, id, name, evaluator_name="regex", selector_path="input"): """Create a _ControlAdapter for testing.""" @@ -153,75 +189,97 @@ def _make_request(self, step_type="llm"): stage="pre", ) - def 
test_emits_events_when_observability_enabled(self): - """Should call add_event for each match/error/non_match.""" - from agent_control.evaluation import _emit_local_events - + def test_builds_events(self): + """Should build one event per match/error/non_match.""" ctrl = self._make_control_adapter(1, "ctrl-1") match = self._make_match(1, "ctrl-1") non_match = self._make_match(2, "ctrl-2", matched=False) response = self._make_response(matches=[match], non_matches=[non_match]) request = self._make_request() + events = _build_local_events( + response, + request, + [ctrl, self._make_control_adapter(2, "ctrl-2")], + "trace123", + "span456", + "test-agent", + ) + assert len(events) == 2 + event = events[0] + assert event.trace_id == "trace123" + assert event.span_id == "span456" + assert event.agent_name == "test-agent" + assert event.matched is True + assert event.evaluator_name == "regex" + assert event.selector_path == "input" + + def test_delivers_events_when_observability_enabled(self): + """Should call add_event for each built event when OSS delivery is enabled.""" + from agent_control_models import ControlExecutionEvent + + built_events = [ + ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=1, + control_name="ctrl-1", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=1.0, + ) + ] + with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, - [ctrl, self._make_control_adapter(2, "ctrl-2")], - "trace123", "span456", "test-agent", - ) - assert mock_add.call_count == 2 - # Verify event fields for the match - event = mock_add.call_args_list[0][0][0] - assert event.trace_id == "trace123" - assert event.span_id == "span456" - assert event.agent_name == "test-agent" - assert event.matched is True - assert event.evaluator_name == "regex" - 
assert event.selector_path == "input" - - def test_skips_when_observability_disabled(self): + _deliver_oss_events(built_events) + mock_add.assert_called_once_with(built_events[0]) + + def test_skips_delivery_when_observability_disabled(self): """Should not call add_event when observability is disabled.""" - from agent_control.evaluation import _emit_local_events + from agent_control_models import ControlExecutionEvent - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request() + built_events = [ + ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=1, + control_name="ctrl-1", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=1.0, + ) + ] with patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) + _deliver_oss_events(built_events) mock_add.assert_not_called() def test_maps_tool_step_to_tool_call(self): """Should set applies_to='tool_call' for tool steps.""" - from agent_control.evaluation import _emit_local_events - ctrl = self._make_control_adapter(1, "ctrl-1") match = self._make_match(1, "ctrl-1") response = self._make_response(matches=[match]) request = self._make_request(step_type="tool") - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) - event = mock_add.call_args_list[0][0][0] - assert event.applies_to == "tool_call" + built_events = _build_local_events( + response, request, [ctrl], "trace123", "span456", "test-agent" + ) + assert built_events[0].applies_to == "tool_call" def 
test_uses_fallback_ids_when_trace_context_missing(self): - """Should emit events with all-zero fallback IDs when trace context is absent.""" + """Should build events with all-zero fallback IDs when trace context is absent.""" import agent_control.evaluation as eval_mod from agent_control.evaluation import ( _FALLBACK_SPAN_ID, _FALLBACK_TRACE_ID, - _emit_local_events, ) ctrl = self._make_control_adapter(1, "ctrl-1") @@ -232,15 +290,12 @@ def test_uses_fallback_ids_when_trace_context_missing(self): # Reset the once-only warning flag so the warning fires in this test eval_mod._trace_warning_logged = False - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add, \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events( - response, request, [ctrl], - None, None, "test-agent", + with patch("agent_control.evaluation._logger") as mock_logger: + built_events = _build_local_events( + response, request, [ctrl], None, None, "test-agent" ) - assert mock_add.call_count == 1 - event = mock_add.call_args_list[0][0][0] + assert len(built_events) == 1 + event = built_events[0] assert event.trace_id == _FALLBACK_TRACE_ID assert event.span_id == _FALLBACK_SPAN_ID assert event.trace_id == "0" * 32 @@ -277,9 +332,7 @@ def test_composite_control_emits_representative_leaf_metadata(self): request = self._make_request() # When: emitting local observability events - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( + built_events = _build_local_events( response, request, [ctrl], @@ -287,7 +340,7 @@ def test_composite_control_emits_representative_leaf_metadata(self): "span456", "test-agent", ) - event = mock_add.call_args_list[0][0][0] + event = built_events[0] # Then: the first leaf becomes the event identity and full context is preserved assert 
event.evaluator_name == "regex" @@ -301,8 +354,6 @@ def test_composite_control_emits_representative_leaf_metadata(self): def test_fallback_warning_logged_only_once(self): """The missing-trace-context warning should fire only on the first call.""" import agent_control.evaluation as eval_mod - from agent_control.evaluation import _emit_local_events - ctrl = self._make_control_adapter(1, "ctrl-1") match = self._make_match(1, "ctrl-1") response = self._make_response(matches=[match]) @@ -310,11 +361,9 @@ def test_fallback_warning_logged_only_once(self): eval_mod._trace_warning_logged = False - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event"), \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") + with patch("agent_control.evaluation._logger") as mock_logger: + _build_local_events(response, request, [ctrl], None, None, "agent-test-a1") + _build_local_events(response, request, [ctrl], None, None, "agent-test-a1") assert mock_logger.warning.call_count == 1 @@ -373,7 +422,7 @@ async def test_emits_events_when_trace_context_provided(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: + patch("agent_control.evaluation._deliver_oss_events") as mock_deliver: result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -385,16 +434,15 @@ async def test_emits_events_when_trace_context_provided(self): event_agent_name="test-agent", ) - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][2] is not None # local_controls - assert call_args[0][3] == "abc123" # trace_id - assert 
call_args[0][4] == "def456" # span_id - assert call_args.kwargs["agent_name"] == "test-agent" + mock_deliver.assert_called_once() # Also verify non_matches propagated assert result.non_matches is not None assert len(result.non_matches) == 1 + assert result.events is not None + assert len(result.events) == 1 + assert result.events[0].trace_id == "abc123" + assert result.events[0].span_id == "def456" @pytest.mark.asyncio async def test_emits_events_without_trace_context(self): @@ -427,8 +475,8 @@ async def test_emits_events_without_trace_context(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: - await evaluation.check_evaluation_with_local( + patch("agent_control.evaluation._deliver_oss_events") as mock_deliver: + result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", step=step, @@ -436,10 +484,10 @@ async def test_emits_events_without_trace_context(self): controls=controls, # No trace_id/span_id ) - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][3] is None # trace_id passed as None - assert call_args[0][4] is None # span_id passed as None + mock_deliver.assert_called_once() + assert result.events is not None + assert result.events[0].trace_id == "0" * 32 + assert result.events[0].span_id == "0" * 16 @pytest.mark.asyncio async def test_forwards_trace_headers_to_server(self): @@ -492,6 +540,109 @@ async def test_forwards_trace_headers_to_server(self): assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" assert headers["X-Span-Id"] == "eeee5555ffff6666" + @pytest.mark.asyncio + async def test_merged_event_sink_emits_once_after_merge(self): + """When a sink is registered, local/server events should merge and emit once.""" + from agent_control_models import ( + ControlExecutionEvent, + 
ControlMatch, + EvaluationResponse, + EvaluatorResult, + Step, + ) + + local_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + matches=[ + ControlMatch( + control_id=1, + control_name="local-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.8), + ) + ], + ) + server_event = ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=2, + control_name="server-ctrl", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=0.4, + ) + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": None, + "errors": None, + "non_matches": None, + "events": [server_event.model_dump(mode="json")], + } + + controls = [ + { + "id": 1, + "name": "local-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "sdk", + }, + }, + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + }, + ] + + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=local_response) + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + 
patch("agent_control.evaluation.add_event") as mock_add: + result = await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + ) + + mock_add.assert_not_called() + mock_emit.assert_called_once() + merged_events = mock_emit.call_args.args[0] + assert len(merged_events) == 2 + assert {event.control_id for event in merged_events} == {1, 2} + headers = client.http_client.post.call_args.kwargs["headers"] + assert headers["X-Agent-Control-Merge-Events"] == "true" + assert result.events is not None + assert len(result.events) == 2 + # ============================================================================= # control_decorators non_matches dict conversion diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index c92ea315..128c9e08 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -156,6 +156,7 @@ async def evaluate( db: AsyncSession = Depends(get_async_db), x_trace_id: str | None = Header(default=None, alias="X-Trace-Id"), x_span_id: str | None = Header(default=None, alias="X-Span-Id"), + x_merge_events: str | None = Header(default=None, alias="X-Agent-Control-Merge-Events"), ) -> EvaluationResponse: """Analyze content for safety and control violations. 
@@ -238,30 +239,35 @@ async def evaluate( # Calculate total execution time total_duration_ms = (time.perf_counter() - start_time) * 1000 - # Emit observability events if enabled - if observability_settings.enabled: + merge_events_requested = (x_merge_events or "").lower() == "true" + response_events = _build_observability_events( + response=raw_response, + request=request, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + applies_to=applies_to, + control_lookup=control_lookup, + total_duration_ms=total_duration_ms, + ) + + # OSS keeps server-side ingestion as the default. Enterprise merged mode + # returns events to the SDK and skips this server-side delivery step. + if observability_settings.enabled and not merge_events_requested: # Get ingestor from app.state (None if not initialized) try: ingestor = get_event_ingestor(req) except RuntimeError: ingestor = None + await _ingest_observability_events(response_events, ingestor) - await _emit_observability_events( - response=raw_response, - request=request, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - applies_to=applies_to, - control_lookup=control_lookup, - total_duration_ms=total_duration_ms, - ingestor=ingestor, - ) - - return _sanitize_evaluation_response(raw_response) + sanitized = _sanitize_evaluation_response(raw_response) + if response_events: + sanitized = sanitized.model_copy(update={"events": response_events}) + return sanitized -async def _emit_observability_events( +def _build_observability_events( response: EvaluationResponse, request: EvaluationRequest, trace_id: str, @@ -270,9 +276,8 @@ async def _emit_observability_events( applies_to: Literal["llm_call", "tool_call"], control_lookup: dict, total_duration_ms: float, - ingestor: EventIngestor | None, -) -> None: - """Create and enqueue observability events for all evaluated controls. + ) -> list[ControlExecutionEvent]: + """Create observability events for all evaluated controls. 
Uses control_execution_id from the engine response to ensure correlation between SDK logs and server observability events. @@ -379,11 +384,20 @@ async def _emit_observability_events( ) ) - # Ingest events - if events and ingestor: - result = await ingestor.ingest(events) - if result.dropped > 0: - _logger.warning( - f"Dropped {result.dropped} observability events, " - f"processed {result.processed}" - ) + return events + + +async def _ingest_observability_events( + events: list[ControlExecutionEvent], + ingestor: EventIngestor | None, +) -> None: + """Ingest server-side observability events when OSS batching is active.""" + if not events or ingestor is None: + return + + result = await ingestor.ingest(events) + if result.dropped > 0: + _logger.warning( + f"Dropped {result.dropped} observability events, " + f"processed {result.processed}" + ) diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index 942dca66..a5abdb00 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -3,7 +3,13 @@ import uuid from unittest.mock import AsyncMock, MagicMock -from agent_control_models import ControlMatch, EvaluationRequest, EvaluatorResult, Step +from agent_control_models import ( + ControlExecutionEvent, + ControlMatch, + EvaluationRequest, + EvaluatorResult, + Step, +) from fastapi.testclient import TestClient from agent_control_server.endpoints.evaluation import ( @@ -198,8 +204,10 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani lambda _config: mock_evaluator, ) - emit_mock = AsyncMock() - monkeypatch.setattr(evaluation_module, "_emit_observability_events", emit_mock) + build_mock = MagicMock(return_value=[]) + ingest_mock = AsyncMock() + monkeypatch.setattr(evaluation_module, "_build_observability_events", build_mock) + monkeypatch.setattr(evaluation_module, "_ingest_observability_events", ingest_mock) 
monkeypatch.setattr(evaluation_module.observability_settings, "enabled", True) # When: sending an evaluation request @@ -220,8 +228,8 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani assert data["errors"][0]["result"]["error"] == SAFE_EVALUATOR_ERROR # And: observability receives the raw engine response with unsanitized diagnostics - emit_mock.assert_awaited_once() - raw_response = emit_mock.await_args.kwargs["response"] + build_mock.assert_called_once() + raw_response = build_mock.call_args.kwargs["response"] assert raw_response.errors is not None raw_error = raw_response.errors[0] assert raw_error.control_name == control_name @@ -229,6 +237,7 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani raw_trace = raw_error.result.metadata["condition_trace"] assert raw_trace["error"] == "RuntimeError: Simulated evaluator crash" assert raw_trace["message"] == "Evaluation failed: RuntimeError: Simulated evaluator crash" + ingest_mock.assert_awaited_once() def test_sanitize_control_match_redacts_nested_condition_trace_errors() -> None: @@ -372,3 +381,45 @@ async def ingest(self, events): # type: ignore[no-untyped-def] del app.state.event_ingestor else: app.state.event_ingestor = previous_ingestor + + +def test_evaluation_returns_events_and_skips_ingest_for_merge_mode( + client: TestClient, monkeypatch +) -> None: + """Merged-event mode should return events without ingesting them server-side.""" + agent_name, _ = create_and_assign_policy(client) + + import agent_control_server.endpoints.evaluation as evaluation_module + + event = ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name=agent_name, + control_id=1, + control_name="test-control", + check_stage="pre", + applies_to="llm_call", + action="deny", + matched=True, + confidence=0.9, + ) + build_mock = MagicMock(return_value=[event]) + ingest_mock = AsyncMock() + monkeypatch.setattr(evaluation_module, "_build_observability_events", 
build_mock) + monkeypatch.setattr(evaluation_module, "_ingest_observability_events", ingest_mock) + monkeypatch.setattr(evaluation_module.observability_settings, "enabled", True) + + payload = Step(type="llm", name="test-step", input="x", output=None) + req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") + resp = client.post( + "/api/v1/evaluation", + json=req.model_dump(mode="json"), + headers={"X-Agent-Control-Merge-Events": "true"}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["events"] is not None + assert len(body["events"]) == 1 + assert body["events"][0]["control_execution_id"] == event.control_execution_id + ingest_mock.assert_not_awaited() From 61cd78875e531a0de12298295c92bdf7c328e641 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Thu, 26 Mar 2026 14:52:12 -0700 Subject: [PATCH 5/9] address comments --- sdks/python/src/agent_control/evaluation.py | 20 ++++-- .../agent_control/telemetry/trace_context.py | 2 + .../tests/test_observability_updates.py | 72 ++++++++++++++++++- sdks/python/tests/test_trace_context.py | 11 +++ 4 files changed, 96 insertions(+), 9 deletions(-) diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 55f5efc1..d76af177 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -20,6 +20,7 @@ from ._state import state from .client import AgentControlClient from .observability import add_event, get_logger, is_observability_enabled +from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name _logger = get_logger(__name__) @@ -291,6 +292,13 @@ async def check_evaluation_with_local( httpx.HTTPError: If server request fails """ normalized_name = ensure_agent_name(agent_name) + resolved_trace_id = trace_id + resolved_span_id = span_id + if trace_id is None or span_id is None: + current_trace_id, current_span_id = get_trace_and_span_ids() + 
resolved_trace_id = trace_id or current_trace_id + resolved_span_id = span_id or current_span_id + # Partition controls by local flag local_controls: list[_ControlAdapter] = [] parse_errors: list[ControlMatch] = [] @@ -389,8 +397,8 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: local_result, request, applicable_local_controls, - trace_id, - span_id, + resolved_trace_id, + resolved_span_id, agent_name=event_agent_name, ) @@ -409,10 +417,10 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if _has_applicable_prefiltered_server_controls(server_control_payloads, request): request_payload = request.model_dump(mode="json", exclude_none=True) headers: dict[str, str] = {} - if trace_id: - headers["X-Trace-Id"] = trace_id - if span_id: - headers["X-Span-Id"] = span_id + if resolved_trace_id: + headers["X-Trace-Id"] = resolved_trace_id + if resolved_span_id: + headers["X-Span-Id"] = resolved_span_id response = await client.http_client.post( "/api/v1/evaluation", diff --git a/sdks/python/src/agent_control/telemetry/trace_context.py b/sdks/python/src/agent_control/telemetry/trace_context.py index 82c4326e..a871fb29 100644 --- a/sdks/python/src/agent_control/telemetry/trace_context.py +++ b/sdks/python/src/agent_control/telemetry/trace_context.py @@ -40,6 +40,8 @@ def get_trace_context_from_provider() -> TraceContext | None: span_id = trace_context.get("span_id") if not isinstance(trace_id, str) or not isinstance(span_id, str): return None + if not trace_id or not span_id: + return None return { "trace_id": trace_id, diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index cdaaa6ce..bb11a5ae 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -10,6 +10,10 @@ _map_applies_to, _merge_results, ) +from agent_control.telemetry.trace_context import ( + clear_trace_context_provider, + set_trace_context_provider, 
+) from agent_control_models import ControlDefinition # ============================================================================= @@ -326,6 +330,9 @@ def test_fallback_warning_logged_only_once(self): class TestCheckEvaluationWithLocal: """Tests for check_evaluation_with_local event emission and non_matches.""" + def teardown_method(self) -> None: + clear_trace_context_provider() + @pytest.mark.asyncio async def test_emits_events_when_trace_context_provided(self): """Should emit observability events when trace_id and span_id are passed.""" @@ -398,7 +405,7 @@ async def test_emits_events_when_trace_context_provided(self): @pytest.mark.asyncio async def test_emits_events_without_trace_context(self): - """Should still emit events when trace_id/span_id not provided (fallback IDs).""" + """Should resolve trace context from the provider when IDs are omitted.""" from agent_control_models import EvaluationResponse, Step mock_response = EvaluationResponse( @@ -424,6 +431,12 @@ async def test_emits_events_without_trace_context(self): client = MagicMock() client.http_client = AsyncMock() step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + ) with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ @@ -438,8 +451,8 @@ async def test_emits_events_without_trace_context(self): ) mock_emit.assert_called_once() call_args = mock_emit.call_args - assert call_args[0][3] is None # trace_id passed as None - assert call_args[0][4] is None # span_id passed as None + assert call_args[0][3] == "a" * 32 + assert call_args[0][4] == "b" * 16 @pytest.mark.asyncio async def test_forwards_trace_headers_to_server(self): @@ -492,6 +505,59 @@ async def test_forwards_trace_headers_to_server(self): assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" assert headers["X-Span-Id"] == 
"eeee5555ffff6666" + @pytest.mark.asyncio + async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): + """Server POST should resolve trace headers from the provider when omitted.""" + from agent_control_models import Step + + controls = [{ + "id": 1, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "deny"}, + "execution": "server", + }, + }] + + mock_http_response = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 1.0, + "matches": None, + "errors": None, + "non_matches": None, + } + mock_http_response.raise_for_status = MagicMock() + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "c" * 32, + "span_id": "d" * 16, + } + ) + + with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): + await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + ) + + call_kwargs = client.http_client.post.call_args + headers = call_kwargs.kwargs.get("headers", {}) + assert headers["X-Trace-Id"] == "c" * 32 + assert headers["X-Span-Id"] == "d" * 16 + # ============================================================================= # control_decorators non_matches dict conversion diff --git a/sdks/python/tests/test_trace_context.py b/sdks/python/tests/test_trace_context.py index e1305711..f08306e0 100644 --- a/sdks/python/tests/test_trace_context.py +++ b/sdks/python/tests/test_trace_context.py @@ -52,3 +52,14 @@ def test_get_trace_context_from_provider_returns_none_for_invalid_shape() -> Non ) assert get_trace_context_from_provider() is None + + +def 
test_get_trace_context_from_provider_returns_none_for_empty_ids() -> None: + set_trace_context_provider( + lambda: { + "trace_id": "", + "span_id": "", + } + ) + + assert get_trace_context_from_provider() is None From ff1e3440c3ac7bce17bbf085b396b10de6b0543e Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 31 Mar 2026 12:27:28 -0700 Subject: [PATCH 6/9] update docstring --- sdks/python/src/agent_control/evaluation.py | 66 +++++++++++++++++-- .../src/agent_control/evaluation_events.py | 37 ++++++++++- .../src/agent_control/telemetry/event_sink.py | 33 +++++++++- .../tests/test_observability_updates.py | 35 ++++++---- .../endpoints/evaluation.py | 47 +++++++++++-- 5 files changed, 190 insertions(+), 28 deletions(-) diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index e63940c6..b9de8d1c 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -17,7 +17,7 @@ from ._state import state from .client import AgentControlClient -from .evaluation_events import build_control_execution_events, deliver_oss_events +from .evaluation_events import build_control_execution_events, enqueue_observability_events from .telemetry import emit_control_events, has_control_event_sink from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name @@ -49,7 +49,20 @@ def _get_applicable_controls( def _build_server_control_lookup( server_control_payloads: list[dict[str, Any]], ) -> dict[int, ControlDefinition]: - """Return best-effort parsed server control definitions keyed by control ID.""" + """Build a best-effort lookup of server control definitions. + + The merged-event path reconstructs server-side events in the SDK after the + server returns a lightweight ``EvaluationResponse``. This helper parses the + cached server control payloads so the shared event builder can reconstruct + those events locally. 
+ + Args: + server_control_payloads: Raw cached server control payloads. + + Returns: + A mapping of control ID to parsed ``ControlDefinition`` for every + payload that can be parsed locally. + """ control_lookup: dict[int, ControlDefinition] = {} for control in server_control_payloads: @@ -107,7 +120,20 @@ def _merge_results( local_result: EvaluationResponse, server_result: EvaluationResponse, ) -> EvaluationResult: - """Merge local and server evaluation results.""" + """Merge local and server evaluation results into one SDK-facing result. + + This helper merges only evaluation semantics. Event reconstruction happens + later so the response shape can stay lightweight regardless of which event + ingestion path is used. + + Args: + local_result: Evaluation response produced by SDK-local controls. + server_result: Evaluation response produced by server-side controls. + + Returns: + A merged ``EvaluationResult`` with combined matches, errors, + non-matches, and the strictest safety/confidence outcome. + """ is_safe = local_result.is_safe and server_result.is_safe confidence = min(local_result.confidence, server_result.confidence) @@ -173,7 +199,33 @@ async def check_evaluation_with_local( span_id: str | None = None, event_agent_name: str | None = None, ) -> EvaluationResult: - """Check if agent interaction is safe, running local controls first.""" + """Evaluate controls with local-first execution and configurable event flow. + + This is the main decision boundary between the two supported event + ingestion styles: + - default behavior: local events are reconstructed and queued immediately in + the SDK, while server-side events are still emitted by the server + - merged-event behavior: local and server events are reconstructed in the + SDK and emitted once through a registered sink + + In both cases, the evaluation result itself stays lightweight and event + reconstruction happens after evaluation completes. + + Args: + client: Configured AgentControl client. 
+ agent_name: Agent name to evaluate against. + step: Step payload to evaluate. + stage: Evaluation stage, ``pre`` or ``post``. + controls: Cached control payloads used to split local vs server + execution. + trace_id: Optional explicit trace ID. + span_id: Optional explicit span ID. + event_agent_name: Optional override for the agent name stamped on + reconstructed events. + + Returns: + A merged evaluation result across local and server execution. + """ normalized_name = ensure_agent_name(agent_name) resolved_trace_id = trace_id resolved_span_id = span_id @@ -264,7 +316,9 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if applicable_local_controls: engine = ControlEngine(applicable_local_controls, context="sdk") local_result = await engine.process(request) - local_control_lookup = {control.id: control.control for control in applicable_local_controls} + local_control_lookup = { + control.id: control.control for control in applicable_local_controls + } local_events = build_control_execution_events( local_result, request, @@ -275,7 +329,7 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: ) if not merged_emission_enabled: - deliver_oss_events(local_events) + enqueue_observability_events(local_events) if not local_result.is_safe: result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) diff --git a/sdks/python/src/agent_control/evaluation_events.py b/sdks/python/src/agent_control/evaluation_events.py index f74e1d32..c21ac08a 100644 --- a/sdks/python/src/agent_control/evaluation_events.py +++ b/sdks/python/src/agent_control/evaluation_events.py @@ -125,7 +125,28 @@ def build_control_execution_events( span_id: str | None, agent_name: str | None, ) -> list[ControlExecutionEvent]: - """Construct final ControlExecutionEvents from an evaluation response.""" + """Reconstruct control execution events from an evaluation response. 
+ + This is the shared reconstruction step used by both supported ingestion + styles: + - the default SDK observability path, where reconstructed local events are + queued into the existing SDK batcher + - the merged-event path, where local and server events are reconstructed in + the SDK and emitted together through a registered sink + + Args: + response: Evaluation response containing matches, errors, and + non-matches. + request: Original evaluation request used to derive stage and + ``applies_to``. + control_lookup: Parsed controls keyed by control ID. + trace_id: Optional trace ID for correlation. + span_id: Optional span ID for correlation. + agent_name: Optional override for the agent name stamped on events. + + Returns: + A list of reconstructed ``ControlExecutionEvent`` objects. + """ resolved_trace_id, resolved_span_id = _resolve_event_trace_context(trace_id, span_id) resolved_agent_name = agent_name or request.agent_name now = datetime.now(UTC) @@ -170,8 +191,18 @@ def build_control_execution_events( return events -def deliver_oss_events(events: list[ControlExecutionEvent]) -> None: - """Send reconstructed events through the existing OSS SDK observability path.""" +def enqueue_observability_events(events: list[ControlExecutionEvent]) -> None: + """Enqueue reconstructed events through the existing SDK observability path. + + This preserves the default SDK behavior of forwarding local events through + the existing observability batcher rather than a custom merged-event sink. + + Args: + events: Reconstructed control execution events to enqueue. + + Returns: + None. 
+ """ if not is_observability_enabled(): return diff --git a/sdks/python/src/agent_control/telemetry/event_sink.py b/sdks/python/src/agent_control/telemetry/event_sink.py index 19062604..cb9910c3 100644 --- a/sdks/python/src/agent_control/telemetry/event_sink.py +++ b/sdks/python/src/agent_control/telemetry/event_sink.py @@ -10,13 +10,33 @@ def set_control_event_sink(sink: ControlEventSink | None) -> None: - """Register a sink for merged control execution events.""" + """Register a sink for merged control execution events. + + Registering a sink enables the optional merged-event path, where the SDK + reconstructs local and server events and emits them together after merging + results. + + Args: + sink: Sink callback to receive merged control execution events, or + ``None`` to clear the current sink. + + Returns: + None. + """ global _control_event_sink _control_event_sink = sink def emit_control_events(events: list[ControlExecutionEvent]) -> None: - """Emit merged control execution events to the registered sink.""" + """Emit merged control execution events to the registered sink. + + Args: + events: Merged control execution events to emit. + + Returns: + None. Sink failures are swallowed so evaluation behavior is not changed + by telemetry issues. + """ if not events or _control_event_sink is None: return @@ -28,7 +48,14 @@ def emit_control_events(events: list[ControlExecutionEvent]) -> None: def has_control_event_sink() -> bool: - """Return whether a merged control event sink is currently registered.""" + """Return whether the optional merged-event path is enabled. + + Args: + None. + + Returns: + ``True`` when a merged control event sink has been registered. 
+ """ return _control_event_sink is not None diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index 3bb8b4cb..3d8c284c 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -7,7 +7,7 @@ from agent_control.evaluation import _ControlAdapter, _merge_results from agent_control.evaluation_events import ( build_control_execution_events, - deliver_oss_events, + enqueue_observability_events, map_applies_to, ) from agent_control.telemetry.trace_context import ( @@ -183,7 +183,7 @@ def test_composite_control_uses_representative_observability_identity(self): assert event.metadata["all_evaluators"] == ["regex"] assert event.metadata["all_selector_paths"] == ["input", "output"] - def test_deliver_oss_events_uses_existing_batcher(self): + def test_enqueue_observability_events_uses_existing_batcher(self): from agent_control_models import ControlExecutionEvent events = [ @@ -203,7 +203,7 @@ def test_deliver_oss_events_uses_existing_batcher(self): with patch("agent_control.evaluation_events.is_observability_enabled", return_value=True), \ patch("agent_control.evaluation_events.add_event") as mock_add: - deliver_oss_events(events) + enqueue_observability_events(events) mock_add.assert_called_once_with(events[0]) @@ -253,7 +253,7 @@ async def test_delivers_local_events_in_oss_mode(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation.deliver_oss_events") as mock_deliver: + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -265,8 +265,8 @@ async def test_delivers_local_events_in_oss_mode(self): event_agent_name="test-agent", ) - mock_deliver.assert_called_once() - delivered_events = 
mock_deliver.call_args.args[0] + mock_enqueue.assert_called_once() + delivered_events = mock_enqueue.call_args.args[0] assert len(delivered_events) == 1 assert delivered_events[0].trace_id == "abc123" assert delivered_events[0].span_id == "def456" @@ -275,9 +275,20 @@ async def test_delivers_local_events_in_oss_mode(self): @pytest.mark.asyncio async def test_resolves_provider_trace_context_for_local_events(self): - from agent_control_models import EvaluationResponse, Step + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step - mock_response = EvaluationResponse(is_safe=True, confidence=1.0) + mock_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + non_matches=[ + ControlMatch( + control_id=1, + control_name="test-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.1), + ) + ], + ) mock_engine = MagicMock() mock_engine.process = AsyncMock(return_value=mock_response) controls = [{ @@ -300,7 +311,7 @@ async def test_resolves_provider_trace_context_for_local_events(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation.deliver_oss_events") as mock_deliver: + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -309,7 +320,7 @@ async def test_resolves_provider_trace_context_for_local_events(self): controls=controls, ) - delivered_events = mock_deliver.call_args.args[0] + delivered_events = mock_enqueue.call_args.args[0] assert delivered_events[0].trace_id == "a" * 32 assert delivered_events[0].span_id == "b" * 16 @@ -500,7 +511,7 @@ async def test_merged_event_sink_emits_reconstructed_local_and_server_events_onc patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ 
patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ patch("agent_control.evaluation.emit_control_events") as mock_emit, \ - patch("agent_control.evaluation.deliver_oss_events") as mock_deliver: + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -512,7 +523,7 @@ async def test_merged_event_sink_emits_reconstructed_local_and_server_events_onc event_agent_name="test-agent", ) - mock_deliver.assert_not_called() + mock_enqueue.assert_not_called() mock_emit.assert_called_once() merged_events = mock_emit.call_args.args[0] assert len(merged_events) == 2 diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index cd71ab07..68a315b4 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -274,10 +274,24 @@ def _build_observability_events( control_lookup: dict, total_duration_ms: float, ) -> list[ControlExecutionEvent]: - """Create observability events for all evaluated controls. - - Uses control_execution_id from the engine response to ensure correlation - between SDK logs and server observability events. + """Build observability events for all evaluated controls. + + This preserves the existing server-side event shape while allowing the + merged-event path to skip server-side ingestion and keep the response + lightweight. + + Args: + response: Raw evaluation response from the engine. + request: Original evaluation request. + trace_id: Trace ID to stamp on emitted events. + span_id: Span ID to stamp on emitted events. + agent_name: Agent name to stamp on emitted events. + applies_to: Observability applies_to value derived from the step type. + control_lookup: Controls keyed by control ID. + total_duration_ms: Total request execution duration in milliseconds. 
+ + Returns: + A list of reconstructed server-side control execution events. """ events: list[ControlExecutionEvent] = [] now = datetime.now(UTC) @@ -398,3 +412,28 @@ async def _ingest_observability_events( f"Dropped {result.dropped} observability events, " f"processed {result.processed}" ) + + +async def _emit_observability_events( + response: EvaluationResponse, + request: EvaluationRequest, + trace_id: str, + span_id: str, + agent_name: str, + applies_to: Literal["llm_call", "tool_call"], + control_lookup: dict, + total_duration_ms: float, + ingestor: EventIngestor | None, +) -> None: + """Backward-compatible wrapper around build + ingest observability helpers.""" + events = _build_observability_events( + response=response, + request=request, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + applies_to=applies_to, + control_lookup=control_lookup, + total_duration_ms=total_duration_ms, + ) + await _ingest_observability_events(events, ingestor) From a8c40c5c372f4b020475328de330404294ee2248 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 31 Mar 2026 12:34:57 -0700 Subject: [PATCH 7/9] TS sdk fix --- sdks/typescript/src/generated/funcs/evaluation-evaluate.ts | 5 +++++ .../models/operations/evaluate-api-v1-evaluation-post.ts | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts b/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts index 47ec39e4..81862c54 100644 --- a/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts +++ b/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts @@ -109,6 +109,11 @@ async function $do( const headers = new Headers(compactMap({ "Content-Type": "application/json", Accept: "application/json", + "X-Agent-Control-Merge-Events": encodeSimple( + "X-Agent-Control-Merge-Events", + payload["X-Agent-Control-Merge-Events"], + { explode: false, charEncoding: "none" }, + ), "X-Span-Id": encodeSimple("X-Span-Id", payload["X-Span-Id"], { explode: 
false, charEncoding: "none", diff --git a/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts b/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts index 026e4065..204841a6 100644 --- a/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts +++ b/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts @@ -9,6 +9,7 @@ import * as models from "../index.js"; export type EvaluateApiV1EvaluationPostRequest = { xTraceId?: string | null | undefined; xSpanId?: string | null | undefined; + xAgentControlMergeEvents?: string | null | undefined; body: models.EvaluationRequest; }; @@ -16,6 +17,7 @@ export type EvaluateApiV1EvaluationPostRequest = { export type EvaluateApiV1EvaluationPostRequest$Outbound = { "X-Trace-Id"?: string | null | undefined; "X-Span-Id"?: string | null | undefined; + "X-Agent-Control-Merge-Events"?: string | null | undefined; body: models.EvaluationRequest$Outbound; }; @@ -27,12 +29,14 @@ export const EvaluateApiV1EvaluationPostRequest$outboundSchema: z.ZodMiniType< z.object({ xTraceId: z.optional(z.nullable(z.string())), xSpanId: z.optional(z.nullable(z.string())), + xAgentControlMergeEvents: z.optional(z.nullable(z.string())), body: models.EvaluationRequest$outboundSchema, }), z.transform((v) => { return remap$(v, { xTraceId: "X-Trace-Id", xSpanId: "X-Span-Id", + xAgentControlMergeEvents: "X-Agent-Control-Merge-Events", }); }), ); From f58778e19a1ceb534b8f2fbd71ebbeba2e6e4cfd Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 31 Mar 2026 19:19:05 -0700 Subject: [PATCH 8/9] address comments --- sdks/python/src/agent_control/evaluation.py | 101 ++++++++--- sdks/python/tests/test_evaluation.py | 1 + .../tests/test_observability_updates.py | 165 ++++++++++++++---- 3 files changed, 208 insertions(+), 59 deletions(-) diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 
b9de8d1c..bc26bf37 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -18,6 +18,7 @@ from ._state import state from .client import AgentControlClient from .evaluation_events import build_control_execution_events, enqueue_observability_events +from .observability import is_observability_enabled from .telemetry import emit_control_events, has_control_event_sink from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name @@ -173,8 +174,36 @@ async def check_evaluation( step: Step, stage: Literal["pre", "post"], ) -> EvaluationResult: - """Check if agent interaction is safe.""" + """Check if agent interaction is safe through the public SDK helper. + + This helper preserves the default server-only evaluation path, but it also + participates in the optional merged-event flow when a control event sink is + registered. In that mode, the SDK asks the server to skip final event + ingestion, reconstructs server events from the lightweight response, and + emits them through the registered sink before returning the parsed result. + + Args: + client: Configured AgentControl client. + agent_name: Agent name to evaluate against. + step: Step payload to evaluate. + stage: Evaluation stage, ``pre`` or ``post``. + + Returns: + The parsed evaluation result returned by the server. 
+ """ normalized_name = ensure_agent_name(agent_name) + merged_emission_enabled = has_control_event_sink() + trace_id = None + span_id = None + headers: dict[str, str] | None = None + + if merged_emission_enabled: + trace_id, span_id = get_trace_and_span_ids() + headers = { + "X-Trace-Id": trace_id, + "X-Span-Id": span_id, + "X-Agent-Control-Merge-Events": "true", + } request = EvaluationRequest( agent_name=normalized_name, @@ -183,10 +212,28 @@ async def check_evaluation( ) request_payload = request.model_dump(mode="json") - response = await client.http_client.post("/api/v1/evaluation", json=request_payload) + response = await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=headers, + ) response.raise_for_status() - return cast(EvaluationResult, EvaluationResult.from_dict(response.json())) + evaluation_response = EvaluationResponse.model_validate(response.json()) + + if merged_emission_enabled: + server_control_lookup = _build_server_control_lookup(state.server_controls or []) + server_events = build_control_execution_events( + evaluation_response, + request, + server_control_lookup, + trace_id, + span_id, + normalized_name, + ) + emit_control_events(server_events) + + return cast(EvaluationResult, EvaluationResult.from_dict(evaluation_response.model_dump())) async def check_evaluation_with_local( @@ -305,6 +352,7 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: return result.model_copy(update={"errors": combined_errors}) merged_emission_enabled = has_control_event_sink() + should_reconstruct_local_events = merged_emission_enabled or is_observability_enabled() local_result: EvaluationResponse | None = None local_events = [] @@ -316,20 +364,21 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if applicable_local_controls: engine = ControlEngine(applicable_local_controls, context="sdk") local_result = await engine.process(request) - local_control_lookup = { - control.id: 
control.control for control in applicable_local_controls - } - local_events = build_control_execution_events( - local_result, - request, - local_control_lookup, - resolved_trace_id, - resolved_span_id, - event_agent_name, - ) + if should_reconstruct_local_events: + local_control_lookup = { + control.id: control.control for control in applicable_local_controls + } + local_events = build_control_execution_events( + local_result, + request, + local_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, + ) - if not merged_emission_enabled: - enqueue_observability_events(local_events) + if not merged_emission_enabled: + enqueue_observability_events(local_events) if not local_result.is_safe: result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) @@ -354,15 +403,17 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: ) response.raise_for_status() server_result = EvaluationResponse.model_validate(response.json()) - server_control_lookup = _build_server_control_lookup(server_control_payloads) - server_events = build_control_execution_events( - server_result, - request, - server_control_lookup, - resolved_trace_id, - resolved_span_id, - event_agent_name, - ) + server_events = [] + if merged_emission_enabled: + server_control_lookup = _build_server_control_lookup(server_control_payloads) + server_events = build_control_execution_events( + server_result, + request, + server_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, + ) if local_result is not None: result = _with_parse_errors(_merge_results(local_result, server_result)) diff --git a/sdks/python/tests/test_evaluation.py b/sdks/python/tests/test_evaluation.py index e9842313..4c7a647b 100644 --- a/sdks/python/tests/test_evaluation.py +++ b/sdks/python/tests/test_evaluation.py @@ -66,6 +66,7 @@ def json(self) -> dict[str, object]: }, "stage": "pre", }, + headers=None, ) diff --git 
a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index 3d8c284c..5b81be21 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -326,6 +326,7 @@ async def test_resolves_provider_trace_context_for_local_events(self): @pytest.mark.asyncio async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): + """Server POST should resolve trace headers from the provider when omitted.""" from agent_control_models import Step controls = [{ @@ -355,7 +356,12 @@ async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): client.http_client = AsyncMock() client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider(lambda: {"trace_id": "c" * 32, "span_id": "d" * 16}) + set_trace_context_provider( + lambda: { + "trace_id": "c" * 32, + "span_id": "d" * 16, + } + ) with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): await evaluation.check_evaluation_with_local( @@ -366,56 +372,147 @@ async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): controls=controls, ) - headers = client.http_client.post.call_args.kwargs["headers"] - assert headers["X-Trace-Id"] == "c" * 32 - assert headers["X-Span-Id"] == "d" * 16 - # Verify POST was called with headers call_kwargs = client.http_client.post.call_args headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" - assert headers["X-Span-Id"] == "eeee5555ffff6666" + assert headers["X-Trace-Id"] == "c" * 32 + assert headers["X-Span-Id"] == "d" * 16 + +class TestCheckEvaluation: @pytest.mark.asyncio - async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): - """Server POST should resolve trace headers from the provider when omitted.""" + async def 
test_default_path_keeps_server_only_behavior(self): from agent_control_models import Step + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": None, + "errors": None, + "non_matches": None, + } + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.has_control_event_sink", return_value=False), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit: + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_emit.assert_not_called() + assert result.is_safe is True + assert result.confidence == 0.9 + + @pytest.mark.asyncio + async def test_merged_event_sink_emits_reconstructed_server_events(self): + from agent_control_models import Step + + controls = [ + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + } + ] + + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": [ + { + "control_id": 2, + "control_name": "server-ctrl", + "action": "allow", + "control_execution_id": "ce-server", + "result": {"matched": True, "confidence": 0.4}, + } + ], + "errors": None, + "non_matches": None, + } + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", 
input="hello") + + with patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + patch("agent_control.evaluation.get_trace_and_span_ids", return_value=("a" * 32, "b" * 16)), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + patch.object(evaluation.state, "server_controls", controls): + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + headers = client.http_client.post.call_args.kwargs["headers"] + assert headers["X-Trace-Id"] == "a" * 32 + assert headers["X-Span-Id"] == "b" * 16 + assert headers["X-Agent-Control-Merge-Events"] == "true" + + mock_emit.assert_called_once() + emitted_events = mock_emit.call_args.args[0] + assert len(emitted_events) == 1 + assert emitted_events[0].control_id == 2 + assert emitted_events[0].trace_id == "a" * 32 + assert emitted_events[0].span_id == "b" * 16 + assert emitted_events[0].metadata["primary_evaluator"] == "regex" + assert result.matches is not None + assert len(result.matches) == 1 + + @pytest.mark.asyncio + async def test_skips_local_event_reconstruction_when_nothing_consumes_events(self): + from agent_control_models import EvaluationResponse, Step + controls = [{ "id": 1, - "name": "server-ctrl", + "name": "local-ctrl", "control": { "condition": { "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "deny"}, - "execution": "server", + "action": {"decision": "allow"}, + "execution": "sdk", }, }] - mock_http_response = MagicMock() - mock_http_response.json.return_value = { - "is_safe": True, - "confidence": 1.0, - "matches": None, - "errors": None, - "non_matches": None, - } - mock_http_response.raise_for_status = MagicMock() + mock_response = EvaluationResponse(is_safe=True, confidence=1.0) + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=mock_response) client = MagicMock() client.http_client = 
AsyncMock() - client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider( - lambda: { - "trace_id": "c" * 32, - "span_id": "d" * 16, - } - ) - with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): - await evaluation.check_evaluation_with_local( + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation.has_control_event_sink", return_value=False), \ + patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ + patch("agent_control.evaluation.build_control_execution_events") as mock_build, \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", step=step, @@ -423,10 +520,10 @@ async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): controls=controls, ) - call_kwargs = client.http_client.post.call_args - headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "c" * 32 - assert headers["X-Span-Id"] == "d" * 16 + mock_build.assert_not_called() + mock_enqueue.assert_not_called() + assert result.is_safe is True + assert result.confidence == 1.0 # ============================================================================= From c0d9a576692f0e8f31f1afcae105d6f56d761207 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 31 Mar 2026 20:12:00 -0700 Subject: [PATCH 9/9] ensure control sink exists for merged mode --- sdks/python/src/agent_control/__init__.py | 2 + sdks/python/src/agent_control/evaluation.py | 30 ++++++++++++- sdks/python/tests/test_init_step_merge.py | 2 + .../tests/test_observability_updates.py | 44 +++++++++++++++++-- sdks/python/tests/test_shutdown.py | 3 ++ 5 files changed, 75 
insertions(+), 6 deletions(-) diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index d57ca0d9..dffdde05 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -477,6 +477,7 @@ async def handle(message: str): # Re-init behavior: always stop existing loop before mutating shared agent/session globals. _stop_policy_refresh_loop() + clear_control_event_sink() # Configure logging if provided (do this early before any logging happens) if log_config: @@ -647,6 +648,7 @@ def _reset_state() -> None: state.server_controls = None state.server_url = None state.api_key = None + clear_control_event_sink() async def ashutdown() -> None: diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index bc26bf37..170a46ef 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -117,6 +117,32 @@ def _has_applicable_prefiltered_server_controls( ) +def _is_merged_event_mode_enabled(agent_name: str) -> bool: + """Return whether SDK-side merged event emission is safe for this request. + + Merged reconstruction depends on initialized SDK session state: + a registered sink, an initialized agent, and a cached server-control + snapshot for the same agent. If any of those prerequisites are missing, + evaluation falls back to the default behavior where the server remains the + final emitter for server-side events. + + Args: + agent_name: Normalized agent name for the current request. + + Returns: + ``True`` when the current SDK session has enough state to reconstruct + and emit merged events safely. 
+ """ + if not has_control_event_sink(): + return False + + current_agent = state.current_agent + if current_agent is None or current_agent.agent_name != agent_name: + return False + + return state.server_controls is not None + + def _merge_results( local_result: EvaluationResponse, server_result: EvaluationResponse, @@ -192,7 +218,7 @@ async def check_evaluation( The parsed evaluation result returned by the server. """ normalized_name = ensure_agent_name(agent_name) - merged_emission_enabled = has_control_event_sink() + merged_emission_enabled = _is_merged_event_mode_enabled(normalized_name) trace_id = None span_id = None headers: dict[str, str] | None = None @@ -351,7 +377,7 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: combined_errors = (result.errors or []) + parse_errors return result.model_copy(update={"errors": combined_errors}) - merged_emission_enabled = has_control_event_sink() + merged_emission_enabled = _is_merged_event_mode_enabled(normalized_name) should_reconstruct_local_events = merged_emission_enabled or is_observability_enabled() local_result: EvaluationResponse | None = None diff --git a/sdks/python/tests/test_init_step_merge.py b/sdks/python/tests/test_init_step_merge.py index 669e1236..cbe66fd9 100644 --- a/sdks/python/tests/test_init_step_merge.py +++ b/sdks/python/tests/test_init_step_merge.py @@ -20,9 +20,11 @@ class DoesNotExist: ... 
@pytest.fixture(autouse=True) def _clean_registry() -> Generator[None, None, None]: """Ensure each test starts with an empty step registry.""" + agent_control._reset_state() clear() yield clear() + agent_control._reset_state() def test_init_passes_merged_steps_to_register_agent( diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index 5b81be21..8dfe53b8 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -398,7 +398,7 @@ async def test_default_path_keeps_server_only_behavior(self): client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - with patch("agent_control.evaluation.has_control_event_sink", return_value=False), \ + with patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=False), \ patch("agent_control.evaluation.emit_control_events") as mock_emit: result = await evaluation.check_evaluation( client=client, @@ -455,7 +455,7 @@ async def test_merged_event_sink_emits_reconstructed_server_events(self): client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - with patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + with patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=True), \ patch("agent_control.evaluation.get_trace_and_span_ids", return_value=("a" * 32, "b" * 16)), \ patch("agent_control.evaluation.emit_control_events") as mock_emit, \ patch.object(evaluation.state, "server_controls", controls): @@ -508,7 +508,7 @@ async def test_skips_local_event_reconstruction_when_nothing_consumes_events(sel with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation.has_control_event_sink", 
return_value=False), \ + patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=False), \ patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ patch("agent_control.evaluation.build_control_execution_events") as mock_build, \ patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: @@ -525,6 +525,42 @@ async def test_skips_local_event_reconstruction_when_nothing_consumes_events(sel assert result.is_safe is True assert result.confidence == 1.0 + @pytest.mark.asyncio + async def test_check_evaluation_falls_back_when_initialized_agent_does_not_match(self): + from agent_control_models import Step + + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": None, + "errors": None, + "non_matches": None, + } + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + patch.object(evaluation.state, "current_agent", MagicMock(agent_name="agent-000000000002")), \ + patch.object(evaluation.state, "server_controls", []): + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_emit.assert_not_called() + assert result.is_safe is True + assert result.confidence == 0.9 + # ============================================================================= # control_decorators non_matches dict conversion @@ -606,7 +642,7 @@ async def test_merged_event_sink_emits_reconstructed_local_and_server_events_onc with 
patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=True), \ patch("agent_control.evaluation.emit_control_events") as mock_emit, \ patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: result = await evaluation.check_evaluation_with_local( diff --git a/sdks/python/tests/test_shutdown.py b/sdks/python/tests/test_shutdown.py index 073745ed..1ee128d9 100644 --- a/sdks/python/tests/test_shutdown.py +++ b/sdks/python/tests/test_shutdown.py @@ -15,6 +15,7 @@ import agent_control.observability as obs_mod from agent_control._state import state from agent_control.observability import EventBatcher +from agent_control.telemetry.event_sink import has_control_event_sink, set_control_event_sink def _make_started_batcher() -> EventBatcher: @@ -64,6 +65,7 @@ def test_shutdown_resets_state(self): state.server_controls = [{"name": "test"}] state.server_url = "http://localhost:8000" state.api_key = "key" + set_control_event_sink(lambda events: None) agent_control.shutdown() @@ -72,6 +74,7 @@ def test_shutdown_resets_state(self): assert state.server_controls is None assert state.server_url is None assert state.api_key is None + assert has_control_event_sink() is False def test_shutdown_idempotent(self): agent_control.shutdown()