diff --git a/models/src/agent_control_models/evaluation.py b/models/src/agent_control_models/evaluation.py index 07ab4810..458c91a5 100644 --- a/models/src/agent_control_models/evaluation.py +++ b/models/src/agent_control_models/evaluation.py @@ -127,8 +127,6 @@ class EvaluationResponse(BaseModel): default=None, description="List of controls that were evaluated but did not match (if any)", ) - - class EvaluationResult(EvaluationResponse): """ Client-side result model for evaluation analysis. diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 24353411..dffdde05 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -95,8 +95,12 @@ async def handle_input(user_message: str) -> str: sync_shutdown_observability, ) from .telemetry import ( + clear_control_event_sink, clear_trace_context_provider, + emit_control_events, get_trace_context_from_provider, + has_control_event_sink, + set_control_event_sink, set_trace_context_provider, ) from .tracing import ( @@ -473,6 +477,7 @@ async def handle(message: str): # Re-init behavior: always stop existing loop before mutating shared agent/session globals. _stop_policy_refresh_loop() + clear_control_event_sink() # Configure logging if provided (do this early before any logging happens) if log_config: @@ -643,6 +648,7 @@ def _reset_state() -> None: state.server_controls = None state.server_url = None state.api_key = None + clear_control_event_sink() async def ashutdown() -> None: @@ -1307,6 +1313,10 @@ async def main(): "set_trace_context_provider", "get_trace_context_from_provider", "clear_trace_context_provider", + "set_control_event_sink", + "has_control_event_sink", + "emit_control_events", + "clear_control_event_sink", # Observability "init_observability", "add_event", diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index d76af177..170a46ef 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -1,14 +1,12 @@ """Evaluation check operations for Agent Control SDK.""" from dataclasses import dataclass -from datetime import UTC, datetime from typing import Any, Literal, cast from agent_control_engine import list_evaluators from agent_control_engine.core import ControlEngine from agent_control_models import ( ControlDefinition, - ControlExecutionEvent, ControlMatch, EvaluationRequest, EvaluationResponse, @@ -19,139 +17,12 @@ from ._state import state from .client import AgentControlClient -from .observability import add_event, get_logger, is_observability_enabled +from .evaluation_events import build_control_execution_events, enqueue_observability_events +from .observability import is_observability_enabled +from .telemetry import emit_control_events, has_control_event_sink from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name -_logger = get_logger(__name__) - -# Fallback IDs used when trace context is missing. -# All-zero values are invalid trace/span IDs per OpenTelemetry. -_FALLBACK_TRACE_ID = "0" * 32 -_FALLBACK_SPAN_ID = "0" * 16 -_trace_warning_logged = False - - -def _observability_metadata( - control_def: ControlDefinition, -) -> tuple[str | None, str | None, dict[str, object]]: - """Return representative event fields plus full composite context.""" - identity = control_def.observability_identity() - return ( - identity.selector_path, - identity.evaluator_name, - { - "primary_evaluator": identity.evaluator_name, - "primary_selector_path": identity.selector_path, - "leaf_count": identity.leaf_count, - "all_evaluators": identity.all_evaluators, - "all_selector_paths": identity.all_selector_paths, - }, - ) - - -def _map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: - return "tool_call" if step_type == "tool" else "llm_call" - - -def _emit_local_events( - local_result: "EvaluationResponse", - request: "EvaluationRequest", - local_controls: list["_ControlAdapter"], - trace_id: str | None, - span_id: str | None, - agent_name: str | None, -) -> None: - """Emit observability events for locally-evaluated controls. - - Mirrors the server's _emit_observability_events() so that SDK-evaluated - controls are visible in the observability pipeline. - - When trace_id/span_id are missing, fallback all-zero IDs are used so events - are still recorded (but clearly marked as uncorrelated). - - Only runs when observability is enabled. - """ - if not is_observability_enabled(): - return - - global _trace_warning_logged # noqa: PLW0603 - if not trace_id or not span_id: - if not _trace_warning_logged: - _logger.warning( - "Emitting local control events without trace context; " - "events will use fallback IDs and cannot be correlated with traces. " - "Pass trace_id/span_id for full observability." - ) - _trace_warning_logged = True - trace_id = trace_id or _FALLBACK_TRACE_ID - span_id = span_id or _FALLBACK_SPAN_ID - - applies_to = _map_applies_to(request.step.type) - control_lookup = {c.id: c for c in local_controls} - now = datetime.now(UTC) - resolved_agent_name = agent_name or request.agent_name - - def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: - if not matches: - return - for match in matches: - ctrl = control_lookup.get(match.control_id) - event_metadata = dict(match.result.metadata or {}) - selector_path = None - evaluator_name = None - if ctrl: - selector_path, evaluator_name, identity_metadata = _observability_metadata( - ctrl.control - ) - event_metadata.update(identity_metadata) - add_event( - ControlExecutionEvent( - control_execution_id=match.control_execution_id, - trace_id=trace_id, - span_id=span_id, - agent_name=resolved_agent_name, - control_id=match.control_id, - control_name=match.control_name, - check_stage=request.stage, - applies_to=applies_to, - action=match.action, - matched=matched, - confidence=match.result.confidence, - timestamp=now, - evaluator_name=evaluator_name, - selector_path=selector_path, - error_message=match.result.error if not matched else None, - metadata=event_metadata, - ) - ) - - _emit_matches(local_result.matches, matched=True) - _emit_matches(local_result.errors, matched=False) - _emit_matches(local_result.non_matches, matched=False) - - -async def check_evaluation( - client: AgentControlClient, - agent_name: str, - step: "Step", - stage: Literal["pre", "post"], -) -> EvaluationResult: - """Check if agent interaction is safe.""" - normalized_name = ensure_agent_name(agent_name) - - request = EvaluationRequest( - agent_name=normalized_name, - step=step, - stage=stage, - ) - request_payload = request.model_dump(mode="json") - - response = await client.http_client.post("/api/v1/evaluation", json=request_payload) - response.raise_for_status() - - return cast(EvaluationResult, EvaluationResult.from_dict(response.json())) - @dataclass class _ControlAdapter: @@ -159,7 +30,7 @@ class _ControlAdapter: id: int name: str - control: "ControlDefinition" + control: ControlDefinition def _get_applicable_controls( @@ -176,6 +47,35 @@ def _get_applicable_controls( return cast(list[_ControlAdapter], applicable_controls) +def _build_server_control_lookup( + server_control_payloads: list[dict[str, Any]], +) -> dict[int, ControlDefinition]: + """Build a best-effort lookup of server control definitions. + + The merged-event path reconstructs server-side events in the SDK after the + server returns a lightweight ``EvaluationResponse``. This helper parses the + cached server control payloads so the shared event builder can reconstruct + those events locally. + + Args: + server_control_payloads: Raw cached server control payloads. + + Returns: + A mapping of control ID to parsed ``ControlDefinition`` for every + payload that can be parsed locally. + """ + control_lookup: dict[int, ControlDefinition] = {} + + for control in server_control_payloads: + try: + control_lookup[control["id"]] = ControlDefinition.model_validate(control["control"]) + except Exception: + # The server remains authoritative for malformed/unparseable controls. + continue + + return control_lookup + + def _has_applicable_prefiltered_server_controls( server_control_payloads: list[dict[str, Any]], request: EvaluationRequest, @@ -217,11 +117,50 @@ def _has_applicable_prefiltered_server_controls( ) +def _is_merged_event_mode_enabled(agent_name: str) -> bool: + """Return whether SDK-side merged event emission is safe for this request. + + Merged reconstruction depends on initialized SDK session state: + a registered sink, an initialized agent, and a cached server-control + snapshot for the same agent. If any of those prerequisites are missing, + evaluation falls back to the default behavior where the server remains the + final emitter for server-side events. + + Args: + agent_name: Normalized agent name for the current request. + + Returns: + ``True`` when the current SDK session has enough state to reconstruct + and emit merged events safely. + """ + if not has_control_event_sink(): + return False + + current_agent = state.current_agent + if current_agent is None or current_agent.agent_name != agent_name: + return False + + return state.server_controls is not None + + def _merge_results( - local_result: "EvaluationResponse", - server_result: "EvaluationResponse", -) -> "EvaluationResult": - """Merge local and server evaluation results.""" + local_result: EvaluationResponse, + server_result: EvaluationResponse, +) -> EvaluationResult: + """Merge local and server evaluation results into one SDK-facing result. + + This helper merges only evaluation semantics. Event reconstruction happens + later so the response shape can stay lightweight regardless of which event + ingestion path is used. + + Args: + local_result: Evaluation response produced by SDK-local controls. + server_result: Evaluation response produced by server-side controls. + + Returns: + A merged ``EvaluationResult`` with combined matches, errors, + non-matches, and the strictest safety/confidence outcome. + """ is_safe = local_result.is_safe and server_result.is_safe confidence = min(local_result.confidence, server_result.confidence) @@ -255,41 +194,110 @@ def _merge_results( ) +async def check_evaluation( + client: AgentControlClient, + agent_name: str, + step: Step, + stage: Literal["pre", "post"], +) -> EvaluationResult: + """Check if agent interaction is safe through the public SDK helper. + + This helper preserves the default server-only evaluation path, but it also + participates in the optional merged-event flow when a control event sink is + registered. In that mode, the SDK asks the server to skip final event + ingestion, reconstructs server events from the lightweight response, and + emits them through the registered sink before returning the parsed result. + + Args: + client: Configured AgentControl client. + agent_name: Agent name to evaluate against. + step: Step payload to evaluate. + stage: Evaluation stage, ``pre`` or ``post``. + + Returns: + The parsed evaluation result returned by the server. + """ + normalized_name = ensure_agent_name(agent_name) + merged_emission_enabled = _is_merged_event_mode_enabled(normalized_name) + trace_id = None + span_id = None + headers: dict[str, str] | None = None + + if merged_emission_enabled: + trace_id, span_id = get_trace_and_span_ids() + headers = { + "X-Trace-Id": trace_id, + "X-Span-Id": span_id, + "X-Agent-Control-Merge-Events": "true", + } + + request = EvaluationRequest( + agent_name=normalized_name, + step=step, + stage=stage, + ) + request_payload = request.model_dump(mode="json") + + response = await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=headers, + ) + response.raise_for_status() + + evaluation_response = EvaluationResponse.model_validate(response.json()) + + if merged_emission_enabled: + server_control_lookup = _build_server_control_lookup(state.server_controls or []) + server_events = build_control_execution_events( + evaluation_response, + request, + server_control_lookup, + trace_id, + span_id, + normalized_name, + ) + emit_control_events(server_events) + + return cast(EvaluationResult, EvaluationResult.from_dict(evaluation_response.model_dump())) + + async def check_evaluation_with_local( client: AgentControlClient, agent_name: str, - step: "Step", + step: Step, stage: Literal["pre", "post"], controls: list[dict[str, Any]], trace_id: str | None = None, span_id: str | None = None, event_agent_name: str | None = None, ) -> EvaluationResult: - """ - Check if agent interaction is safe, running local controls first. + """Evaluate controls with local-first execution and configurable event flow. - This function executes controls with execution="sdk" locally in the SDK, - then calls the server for execution="server" controls. If a local control - denies, it short-circuits and returns immediately without calling the server. + This is the main decision boundary between the two supported event + ingestion styles: + - default behavior: local events are reconstructed and queued immediately in + the SDK, while server-side events are still emitted by the server + - merged-event behavior: local and server events are reconstructed in the + SDK and emitted once through a registered sink - Note on parse errors: If a local control fails to parse/validate, it is - skipped (logged as WARNING) and the error is included in result.errors. - This does NOT affect is_safe or confidence—callers concerned with safety - should check result.errors for any parse failures. + In both cases, the evaluation result itself stays lightweight and event + reconstruction happens after evaluation completes. Args: - client: AgentControlClient instance - agent_name: Normalized agent identifier - step: Step payload to evaluate - stage: 'pre' for pre-execution check, 'post' for post-execution check - controls: List of control dicts from initAgent response - (each has 'id', 'name', 'control' keys) + client: Configured AgentControl client. + agent_name: Agent name to evaluate against. + step: Step payload to evaluate. + stage: Evaluation stage, ``pre`` or ``post``. + controls: Cached control payloads used to split local vs server + execution. + trace_id: Optional explicit trace ID. + span_id: Optional explicit span ID. + event_agent_name: Optional override for the agent name stamped on + reconstructed events. Returns: - EvaluationResult with safety analysis (merged from local + server) - - Raises: - httpx.HTTPError: If server request fails + A merged evaluation result across local and server execution. """ normalized_name = ensure_agent_name(agent_name) resolved_trace_id = trace_id @@ -299,7 +307,6 @@ async def check_evaluation_with_local( resolved_trace_id = trace_id or current_trace_id resolved_span_id = span_id or current_span_id - # Partition controls by local flag local_controls: list[_ControlAdapter] = [] parse_errors: list[ControlMatch] = [] available_evaluators = list_evaluators() @@ -344,12 +351,6 @@ async def check_evaluation_with_local( except Exception as exc: control_id = control.get("id", -1) control_name = control.get("name", "unknown") - _logger.warning( - "Skipping invalid local control '%s' (id=%s): %s", - control_name, - control_id, - exc, - ) parse_errors.append( ControlMatch( control_id=control_id, @@ -374,16 +375,13 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if not parse_errors: return result combined_errors = (result.errors or []) + parse_errors - return EvaluationResult( - is_safe=result.is_safe, - confidence=result.confidence, - reason=result.reason, - matches=result.matches, - errors=combined_errors, - non_matches=result.non_matches, - ) + return result.model_copy(update={"errors": combined_errors}) + + merged_emission_enabled = _is_merged_event_mode_enabled(normalized_name) + should_reconstruct_local_events = merged_emission_enabled or is_observability_enabled() local_result: EvaluationResponse | None = None + local_events = [] applicable_local_controls = _get_applicable_controls( local_controls, request, @@ -392,27 +390,27 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if applicable_local_controls: engine = ControlEngine(applicable_local_controls, context="sdk") local_result = await engine.process(request) + if should_reconstruct_local_events: + local_control_lookup = { + control.id: control.control for control in applicable_local_controls + } + local_events = build_control_execution_events( + local_result, + request, + local_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, + ) - _emit_local_events( - local_result, - request, - applicable_local_controls, - resolved_trace_id, - resolved_span_id, - agent_name=event_agent_name, - ) + if not merged_emission_enabled: + enqueue_observability_events(local_events) if not local_result.is_safe: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) - ) + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if merged_emission_enabled: + emit_control_events(local_events) + return result if _has_applicable_prefiltered_server_controls(server_control_payloads, request): request_payload = request.model_dump(mode="json", exclude_none=True) @@ -421,6 +419,8 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: headers["X-Trace-Id"] = resolved_trace_id if resolved_span_id: headers["X-Span-Id"] = resolved_span_id + if merged_emission_enabled: + headers["X-Agent-Control-Merge-Events"] = "true" response = await client.http_client.post( "/api/v1/evaluation", @@ -429,32 +429,34 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: ) response.raise_for_status() server_result = EvaluationResponse.model_validate(response.json()) + server_events = [] + if merged_emission_enabled: + server_control_lookup = _build_server_control_lookup(server_control_payloads) + server_events = build_control_execution_events( + server_result, + request, + server_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, + ) if local_result is not None: - return _with_parse_errors(_merge_results(local_result, server_result)) - - return _with_parse_errors( - EvaluationResult( - is_safe=server_result.is_safe, - confidence=server_result.confidence, - reason=server_result.reason, - matches=server_result.matches, - errors=server_result.errors, - non_matches=server_result.non_matches, - ) - ) + result = _with_parse_errors(_merge_results(local_result, server_result)) + if merged_emission_enabled: + emit_control_events(local_events + server_events) + return result + + result = _with_parse_errors(EvaluationResult.model_validate(server_result.model_dump())) + if merged_emission_enabled: + emit_control_events(server_events) + return result if local_result is not None: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) - ) + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if merged_emission_enabled: + emit_control_events(local_events) + return result return _with_parse_errors(EvaluationResult(is_safe=True, confidence=1.0)) @@ -471,58 +473,10 @@ async def evaluate_controls( trace_id: str | None = None, span_id: str | None = None, ) -> EvaluationResult: - """ - Evaluate controls for a step. - - This convenience function evaluates controls (both local SDK-executed and - server-executed) for a given step. - - Args: - step_name: Name of the step (e.g., "chat", "search_db") - input: Input data for the step (for pre-stage evaluation) - output: Output data from the step (for post-stage evaluation) - context: Additional context metadata - step_type: Type of step - "llm" or "tool" (default: "llm") - stage: When to evaluate - "pre" or "post" (default: "pre") - agent_name: Agent name (required) - trace_id: Optional OpenTelemetry trace ID for observability - span_id: Optional OpenTelemetry span ID for observability - - Returns: - EvaluationResult with is_safe, confidence, reason, matches, errors - - Raises: - httpx.HTTPError: If server request fails - - Example: - import agent_control - - # Evaluate controls for an agent - result = await agent_control.evaluate_controls( - "chat", - input="User message here", - stage="pre", - agent_name="customer-service-bot" - ) - - # With trace/span IDs for observability - result = await agent_control.evaluate_controls( - "chat", - input="User message", - stage="pre", - agent_name="customer-service-bot", - trace_id="4bf92f3577b34da6a3ce929d0e0e4736", - span_id="00f067aa0ba902b7" - ) - """ - # Ensure server_url is set (for mypy type narrowing) + """Evaluate controls for a step.""" if state.server_url is None: - raise RuntimeError( - "Server URL not configured. Call agent_control.init() first." - ) + raise RuntimeError("Server URL not configured. Call agent_control.init() first.") - # Build Step dict (input and output are required by Step model) - # Tool steps require dict input/output, LLM steps use strings default_value = {} if step_type == "tool" else "" step_dict: dict[str, Any] = { "type": step_type, @@ -533,15 +487,11 @@ async def evaluate_controls( if context is not None: step_dict["context"] = context - # Convert to Step object if models available - step_obj = Step(**step_dict) # type: ignore - - # Get controls from server cache + step_obj = Step(**step_dict) # type: ignore[arg-type] resolved_controls = state.server_controls or [] - # Evaluate using local + server controls async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client: - result = await check_evaluation_with_local( + return await check_evaluation_with_local( client=client, agent_name=agent_name, step=step_obj, @@ -551,5 +501,3 @@ async def evaluate_controls( span_id=span_id, event_agent_name=agent_name, ) - - return result diff --git a/sdks/python/src/agent_control/evaluation_events.py b/sdks/python/src/agent_control/evaluation_events.py new file mode 100644 index 00000000..c21ac08a --- /dev/null +++ b/sdks/python/src/agent_control/evaluation_events.py @@ -0,0 +1,210 @@ +"""Derived control-execution event reconstruction for SDK evaluation flows.""" + +from datetime import UTC, datetime +from typing import Literal + +from agent_control_models import ( + ControlDefinition, + ControlExecutionEvent, + ControlMatch, + EvaluationRequest, + EvaluationResponse, +) + +from .observability import add_event, get_logger, is_observability_enabled + +_logger = get_logger(__name__) + +# All-zero values are invalid trace/span IDs per OpenTelemetry and make it +# obvious that the event could not be correlated to an external trace. +_FALLBACK_TRACE_ID = "0" * 32 +_FALLBACK_SPAN_ID = "0" * 16 +_trace_warning_logged = False + + +def observability_metadata( + control_def: ControlDefinition, +) -> tuple[str | None, str | None, dict[str, object]]: + """Return representative event fields plus full composite context.""" + identity = control_def.observability_identity() + return ( + identity.selector_path, + identity.evaluator_name, + { + "primary_evaluator": identity.evaluator_name, + "primary_selector_path": identity.selector_path, + "leaf_count": identity.leaf_count, + "all_evaluators": identity.all_evaluators, + "all_selector_paths": identity.all_selector_paths, + }, + ) + + +def map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: + """Map Agent Control step types to observability applies_to values.""" + return "tool_call" if step_type == "tool" else "llm_call" + + +def _resolve_event_trace_context( + trace_id: str | None, + span_id: str | None, +) -> tuple[str, str]: + """Return event IDs, applying fallback IDs and a one-time warning if needed.""" + global _trace_warning_logged # noqa: PLW0603 + + if trace_id and span_id: + return trace_id, span_id + + if not _trace_warning_logged: + _logger.warning( + "Emitting control events without trace context; events will use fallback " + "IDs and cannot be correlated with traces. Pass trace_id/span_id for " + "full observability." + ) + _trace_warning_logged = True + + return trace_id or _FALLBACK_TRACE_ID, span_id or _FALLBACK_SPAN_ID + + +def _build_events_for_matches( + matches: list[ControlMatch] | None, + *, + matched: bool, + request: EvaluationRequest, + control_lookup: dict[int, ControlDefinition], + trace_id: str, + span_id: str, + agent_name: str, + now: datetime, +) -> list[ControlExecutionEvent]: + if not matches: + return [] + + applies_to = map_applies_to(request.step.type) + events: list[ControlExecutionEvent] = [] + + for match in matches: + control_def = control_lookup.get(match.control_id) + event_metadata = dict(match.result.metadata or {}) + selector_path = None + evaluator_name = None + + if control_def is not None: + selector_path, evaluator_name, identity_metadata = observability_metadata(control_def) + event_metadata.update(identity_metadata) + + events.append( + ControlExecutionEvent( + control_execution_id=match.control_execution_id, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + control_id=match.control_id, + control_name=match.control_name, + check_stage=request.stage, + applies_to=applies_to, + action=match.action, + matched=matched, + confidence=match.result.confidence, + timestamp=now, + evaluator_name=evaluator_name, + selector_path=selector_path, + error_message=match.result.error if not matched else None, + metadata=event_metadata, + ) + ) + + return events + + +def build_control_execution_events( + response: EvaluationResponse, + request: EvaluationRequest, + control_lookup: dict[int, ControlDefinition], + trace_id: str | None, + span_id: str | None, + agent_name: str | None, +) -> list[ControlExecutionEvent]: + """Reconstruct control execution events from an evaluation response. + + This is the shared reconstruction step used by both supported ingestion + styles: + - the default SDK observability path, where reconstructed local events are + queued into the existing SDK batcher + - the merged-event path, where local and server events are reconstructed in + the SDK and emitted together through a registered sink + + Args: + response: Evaluation response containing matches, errors, and + non-matches. + request: Original evaluation request used to derive stage and + ``applies_to``. + control_lookup: Parsed controls keyed by control ID. + trace_id: Optional trace ID for correlation. + span_id: Optional span ID for correlation. + agent_name: Optional override for the agent name stamped on events. + + Returns: + A list of reconstructed ``ControlExecutionEvent`` objects. + """ + resolved_trace_id, resolved_span_id = _resolve_event_trace_context(trace_id, span_id) + resolved_agent_name = agent_name or request.agent_name + now = datetime.now(UTC) + + events: list[ControlExecutionEvent] = [] + events.extend( + _build_events_for_matches( + response.matches, + matched=True, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + events.extend( + _build_events_for_matches( + response.errors, + matched=False, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + events.extend( + _build_events_for_matches( + response.non_matches, + matched=False, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + return events + + +def enqueue_observability_events(events: list[ControlExecutionEvent]) -> None: + """Enqueue reconstructed events through the existing SDK observability path. + + This preserves the default SDK behavior of forwarding local events through + the existing observability batcher rather than a custom merged-event sink. + + Args: + events: Reconstructed control execution events to enqueue. + + Returns: + None. + """ + if not is_observability_enabled(): + return + + for event in events: + add_event(event) diff --git a/sdks/python/src/agent_control/telemetry/__init__.py b/sdks/python/src/agent_control/telemetry/__init__.py index 8d2ccf90..6e40b8a2 100644 --- a/sdks/python/src/agent_control/telemetry/__init__.py +++ b/sdks/python/src/agent_control/telemetry/__init__.py @@ -1,5 +1,12 @@ -"""Telemetry interfaces for provider-agnostic tracing.""" +"""Telemetry interfaces for provider-agnostic tracing and event emission.""" +from .event_sink import ( + ControlEventSink, + clear_control_event_sink, + emit_control_events, + has_control_event_sink, + set_control_event_sink, +) from .trace_context import ( TraceContext, TraceContextProvider, @@ -9,9 +16,14 @@ ) __all__ = [ + "ControlEventSink", "TraceContext", "TraceContextProvider", + "clear_control_event_sink", "clear_trace_context_provider", + "emit_control_events", "get_trace_context_from_provider", + "has_control_event_sink", + "set_control_event_sink", "set_trace_context_provider", ] diff --git a/sdks/python/src/agent_control/telemetry/event_sink.py b/sdks/python/src/agent_control/telemetry/event_sink.py new file mode 100644 index 00000000..cb9910c3 --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/event_sink.py @@ -0,0 +1,65 @@ +"""Provider-agnostic sink for merged control execution events.""" + +from collections.abc import Callable + +from agent_control_models import ControlExecutionEvent + +ControlEventSink = Callable[[list[ControlExecutionEvent]], None] + +_control_event_sink: ControlEventSink | None = None + + +def set_control_event_sink(sink: ControlEventSink | None) -> None: + """Register a sink for merged control execution events. + + Registering a sink enables the optional merged-event path, where the SDK + reconstructs local and server events and emits them together after merging + results. + + Args: + sink: Sink callback to receive merged control execution events, or + ``None`` to clear the current sink. + + Returns: + None. + """ + global _control_event_sink + _control_event_sink = sink + + +def emit_control_events(events: list[ControlExecutionEvent]) -> None: + """Emit merged control execution events to the registered sink. + + Args: + events: Merged control execution events to emit. + + Returns: + None. Sink failures are swallowed so evaluation behavior is not changed + by telemetry issues. + """ + if not events or _control_event_sink is None: + return + + try: + _control_event_sink(events) + except Exception: + # Sink failures should not break control evaluation. + pass + + +def has_control_event_sink() -> bool: + """Return whether the optional merged-event path is enabled. + + Args: + None. + + Returns: + ``True`` when a merged control event sink has been registered. + """ + return _control_event_sink is not None + + +def clear_control_event_sink() -> None: + """Clear the registered control event sink.""" + global _control_event_sink + _control_event_sink = None diff --git a/sdks/python/src/agent_control/telemetry/trace_context.py b/sdks/python/src/agent_control/telemetry/trace_context.py index a871fb29..be545725 100644 --- a/sdks/python/src/agent_control/telemetry/trace_context.py +++ b/sdks/python/src/agent_control/telemetry/trace_context.py @@ -38,8 +38,6 @@ def get_trace_context_from_provider() -> TraceContext | None: trace_id = trace_context.get("trace_id") span_id = trace_context.get("span_id") - if not isinstance(trace_id, str) or not isinstance(span_id, str): - return None if not trace_id or not span_id: return None diff --git a/sdks/python/tests/test_evaluation.py b/sdks/python/tests/test_evaluation.py index e9842313..4c7a647b 100644 --- a/sdks/python/tests/test_evaluation.py +++ b/sdks/python/tests/test_evaluation.py @@ -66,6 +66,7 @@ def json(self) -> dict[str, object]: }, "stage": "pre", }, + headers=None, ) diff --git a/sdks/python/tests/test_event_sink.py b/sdks/python/tests/test_event_sink.py new file mode 100644 index 00000000..8013f4d6 --- /dev/null +++ b/sdks/python/tests/test_event_sink.py @@ -0,0 +1,59 @@ +"""Tests for the telemetry merged control event sink interface.""" + +from datetime import UTC, datetime + +from agent_control.telemetry.event_sink import ( + clear_control_event_sink, + emit_control_events, + set_control_event_sink, +) +from agent_control_models import ControlExecutionEvent + + +def _event() -> ControlExecutionEvent: + return ControlExecutionEvent( + control_execution_id="ce-1", + trace_id="a" * 32, + span_id="b" * 16, + agent_name="test-agent", + control_id=1, + control_name="pii_check", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=0.95, + timestamp=datetime.now(UTC), + metadata={}, + ) + + +def teardown_function() -> None: + clear_control_event_sink() + + +def test_emit_control_events_calls_registered_sink() -> None: + seen: list[list[ControlExecutionEvent]] = [] + + def _sink(events: list[ControlExecutionEvent]) -> None: + seen.append(events) + + event = _event() + set_control_event_sink(_sink) + + emit_control_events([event]) + + assert seen == [[event]] + + +def test_emit_control_events_noops_without_sink() -> None: + emit_control_events([_event()]) + + +def test_emit_control_events_swallows_sink_failures() -> None: + def _sink(_events: list[ControlExecutionEvent]) -> None: + raise RuntimeError("boom") + + set_control_event_sink(_sink) + + emit_control_events([_event()]) diff --git a/sdks/python/tests/test_init_step_merge.py b/sdks/python/tests/test_init_step_merge.py index 669e1236..cbe66fd9 100644 --- a/sdks/python/tests/test_init_step_merge.py +++ b/sdks/python/tests/test_init_step_merge.py @@ -20,9 +20,11 @@ class DoesNotExist: ... @pytest.fixture(autouse=True) def _clean_registry() -> Generator[None, None, None]: """Ensure each test starts with an empty step registry.""" + agent_control._reset_state() clear() yield clear() + agent_control._reset_state() def test_init_passes_merged_steps_to_register_agent( diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index bb11a5ae..8dfe53b8 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -1,14 +1,18 @@ -"""Tests for observability updates: event emission, non_matches propagation, applies_to mapping.""" +"""Tests for reconstructed control-execution events in SDK evaluation flows.""" from unittest.mock import AsyncMock, MagicMock, patch import pytest from agent_control import evaluation -from agent_control.evaluation import ( - _ControlAdapter, - _emit_local_events, - _map_applies_to, - _merge_results, +from agent_control.evaluation import _ControlAdapter, _merge_results +from agent_control.evaluation_events import ( + build_control_execution_events, + enqueue_observability_events, + map_applies_to, +) +from agent_control.telemetry.trace_context import ( + clear_trace_context_provider, + set_trace_context_provider, ) from agent_control.telemetry.trace_context import ( clear_trace_context_provider, @@ -16,37 +20,19 @@ ) from agent_control_models import ControlDefinition -# ============================================================================= -# _map_applies_to tests -# ============================================================================= - class TestMapAppliesTo: - """Tests for _map_applies_to helper.""" - def test_maps_tool_to_tool_call(self): - assert _map_applies_to("tool") == "tool_call" + assert map_applies_to("tool") == "tool_call" def test_maps_llm_to_llm_call(self): - assert _map_applies_to("llm") == "llm_call" - - def test_maps_unknown_to_llm_call(self): - """Unknown types default to llm_call (matches server pattern).""" - assert _map_applies_to("unknown") == "llm_call" - assert _map_applies_to("") == "llm_call" - - -# ============================================================================= -# _merge_results tests -# ============================================================================= + assert map_applies_to("llm") == "llm_call" class TestMergeResults: - """Tests for _merge_results combining non_matches.""" - def _make_response(self, **kwargs): - """Create a mock EvaluationResponse.""" from agent_control_models import EvaluationResponse + defaults = { "is_safe": True, "confidence": 1.0, @@ -60,6 +46,7 @@ def _make_response(self, **kwargs): def _make_match(self, control_id, control_name="ctrl", action="allow", matched=True): from agent_control_models import ControlMatch, EvaluatorResult + return ControlMatch( control_id=control_id, control_name=control_name, @@ -67,69 +54,55 @@ def _make_match(self, control_id, control_name="ctrl", action="allow", matched=T result=EvaluatorResult(matched=matched, confidence=0.9), ) - def test_combines_non_matches(self): - """non_matches from both sides should be combined.""" - nm1 = self._make_match(1, "ctrl-1", matched=False) - nm2 = self._make_match(2, "ctrl-2", matched=False) - - local = self._make_response(non_matches=[nm1]) - server = self._make_response(non_matches=[nm2]) - - result = _merge_results(local, server) - assert result.non_matches is not None - assert len(result.non_matches) == 2 - ids = {nm.control_id for nm in result.non_matches} - assert ids == {1, 2} + def test_combines_matches_errors_and_non_matches(self): + local = self._make_response( + matches=[self._make_match(1)], + errors=[self._make_match(2, matched=False)], + ) + server = self._make_response(non_matches=[self._make_match(3, matched=False)]) - def test_non_matches_none_when_both_empty(self): - local = self._make_response() - server = self._make_response() result = _merge_results(local, server) - assert result.non_matches is None - def test_non_matches_from_one_side(self): - nm = self._make_match(1, matched=False) - local = self._make_response(non_matches=[nm]) - server = self._make_response() - result = _merge_results(local, server) - assert result.non_matches is not None - assert len(result.non_matches) == 1 - - def test_still_combines_matches_and_errors(self): - m1 = self._make_match(1, "m1") - m2 = self._make_match(2, "m2") - e1 = self._make_match(3, "e1", matched=False) + assert [match.control_id for match in result.matches or []] == [1] + assert [match.control_id for match in result.errors or []] == [2] + assert [match.control_id for match in result.non_matches or []] == [3] - local = self._make_response(matches=[m1], errors=[e1]) - server = self._make_response(matches=[m2]) - - result = _merge_results(local, server) - assert len(result.matches) == 2 - assert len(result.errors) == 1 +class TestBuildControlExecutionEvents: + def _make_control(self, id, name, condition): + return _ControlAdapter( + id=id, + name=name, + control=ControlDefinition( + execution="sdk", + condition=condition, + action={"decision": "allow"}, + ), + ) -# ============================================================================= -# _emit_local_events tests -# ============================================================================= + def _make_request(self, step_type="llm"): + from agent_control_models import EvaluationRequest + step_input = {"query": "hello"} if step_type == "tool" else "hello" + return EvaluationRequest( + agent_name="agent-000000000001", + step={"type": step_type, "name": "test-step", "input": step_input}, + stage="pre", + ) -class TestEmitLocalEvents: - """Tests for _emit_local_events helper.""" + def _make_match(self, control_id, control_name="ctrl", action="allow", matched=True): + from agent_control_models import ControlMatch, EvaluatorResult - def _make_control_adapter(self, id, name, evaluator_name="regex", selector_path="input"): - """Create a _ControlAdapter for testing.""" - control_def = ControlDefinition( - execution="sdk", - condition={ - "evaluator": {"name": evaluator_name, "config": {"pattern": "test"}}, - "selector": {"path": selector_path}, - }, - action={"decision": "deny"}, + return ControlMatch( + control_id=control_id, + control_name=control_name, + action=action, + result=EvaluatorResult(matched=matched, confidence=0.9), ) - return _ControlAdapter(id=id, name=name, control=control_def) def _make_response(self, matches=None, errors=None, non_matches=None): from agent_control_models import EvaluationResponse + return EvaluationResponse( is_safe=not bool(matches), confidence=1.0 if not matches else 0.5, @@ -138,130 +111,45 @@ def _make_response(self, matches=None, errors=None, non_matches=None): non_matches=non_matches, ) - def _make_match(self, control_id, control_name="ctrl", action="deny", matched=True): - from agent_control_models import ControlMatch, EvaluatorResult - return ControlMatch( - control_id=control_id, - control_name=control_name, - action=action, - result=EvaluatorResult(matched=matched, confidence=0.9), - ) - - def _make_request(self, step_type="llm"): - from agent_control_models import EvaluationRequest - # Tool steps require object input, LLM steps accept string - step_input = {"query": "hello"} if step_type == "tool" else "hello" - return EvaluationRequest( - agent_name="agent-000000000001", - step={"type": step_type, "name": "test-step", "input": step_input}, - stage="pre", - ) - - def test_emits_events_when_observability_enabled(self): - """Should call add_event for each match/error/non_match.""" - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - non_match = self._make_match(2, "ctrl-2", matched=False) - response = self._make_response(matches=[match], non_matches=[non_match]) - request = self._make_request() - - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, - [ctrl, self._make_control_adapter(2, "ctrl-2")], - "trace123", "span456", "test-agent", - ) - assert mock_add.call_count == 2 - # Verify event fields for the match - event = mock_add.call_args_list[0][0][0] - assert event.trace_id == "trace123" - assert event.span_id == "span456" - assert event.agent_name == "test-agent" - assert event.matched is True - assert event.evaluator_name == "regex" - assert event.selector_path == "input" - - def test_skips_when_observability_disabled(self): - """Should not call add_event when observability is disabled.""" - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) + def test_builds_events_with_trace_context(self): + response = self._make_response(matches=[self._make_match(1, "ctrl-1")]) request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + ).control + } - with patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) - mock_add.assert_not_called() - - def test_maps_tool_step_to_tool_call(self): - """Should set applies_to='tool_call' for tool steps.""" - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request(step_type="tool") - - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) - event = mock_add.call_args_list[0][0][0] - assert event.applies_to == "tool_call" - - def test_uses_fallback_ids_when_trace_context_missing(self): - """Should emit events with all-zero fallback IDs when trace context is absent.""" - import agent_control.evaluation as eval_mod - from agent_control.evaluation import ( - _FALLBACK_SPAN_ID, - _FALLBACK_TRACE_ID, - _emit_local_events, + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", ) - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request() - - # Reset the once-only warning flag so the warning fires in this test - eval_mod._trace_warning_logged = False + assert len(events) == 1 + event = events[0] + assert event.trace_id == "trace123" + assert event.span_id == "span456" + assert event.agent_name == "test-agent" + assert event.evaluator_name == "regex" + assert event.selector_path == "input" - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add, \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events( - response, request, [ctrl], - None, None, "test-agent", - ) - assert mock_add.call_count == 1 - event = mock_add.call_args_list[0][0][0] - assert event.trace_id == _FALLBACK_TRACE_ID - assert event.span_id == _FALLBACK_SPAN_ID - assert event.trace_id == "0" * 32 - assert event.span_id == "0" * 16 - # Warning should have been logged - mock_logger.warning.assert_called_once() - assert "fallback" in mock_logger.warning.call_args[0][0].lower() - - def test_composite_control_emits_representative_leaf_metadata(self): - """Composite local controls should emit stable representative metadata.""" - # Given: a composite local control and a non-match response for that control - ctrl = _ControlAdapter( - id=1, - name="composite-ctrl", - control=ControlDefinition( - execution="sdk", - condition={ + def test_composite_control_uses_representative_observability_identity(self): + response = self._make_response(non_matches=[self._make_match(1, "ctrl-1", matched=False)]) + request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { "and": [ { "selector": {"path": "input"}, @@ -273,27 +161,20 @@ def test_composite_control_emits_representative_leaf_metadata(self): }, ] }, - action={"decision": "allow"}, - ), - ) - non_match = self._make_match(1, "composite-ctrl", action="allow", matched=False) - response = self._make_response(non_matches=[non_match]) - request = self._make_request() + ).control + } - # When: emitting local observability events - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, - request, - [ctrl], - "trace123", - "span456", - "test-agent", - ) - event = mock_add.call_args_list[0][0][0] + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", + ) - # Then: the first leaf becomes the event identity and full context is preserved + assert len(events) == 1 + event = events[0] assert event.evaluator_name == "regex" assert event.selector_path == "input" assert event.metadata["primary_evaluator"] == "regex" @@ -302,52 +183,45 @@ def test_composite_control_emits_representative_leaf_metadata(self): assert event.metadata["all_evaluators"] == ["regex"] assert event.metadata["all_selector_paths"] == ["input", "output"] - def test_fallback_warning_logged_only_once(self): - """The missing-trace-context warning should fire only on the first call.""" - import agent_control.evaluation as eval_mod - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request() + def test_enqueue_observability_events_uses_existing_batcher(self): + from agent_control_models import ControlExecutionEvent - eval_mod._trace_warning_logged = False - - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event"), \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") - assert mock_logger.warning.call_count == 1 + events = [ + ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=1, + control_name="ctrl-1", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=1.0, + ) + ] + with patch("agent_control.evaluation_events.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation_events.add_event") as mock_add: + enqueue_observability_events(events) -# ============================================================================= -# check_evaluation_with_local event emission + header forwarding -# ============================================================================= + mock_add.assert_called_once_with(events[0]) class TestCheckEvaluationWithLocal: - """Tests for check_evaluation_with_local event emission and non_matches.""" + def teardown_method(self) -> None: + clear_trace_context_provider() def teardown_method(self) -> None: clear_trace_context_provider() @pytest.mark.asyncio - async def test_emits_events_when_trace_context_provided(self): - """Should emit observability events when trace_id and span_id are passed.""" - from agent_control_models import ( - ControlMatch, - EvaluationResponse, - EvaluatorResult, - Step, - ) + async def test_delivers_local_events_in_oss_mode(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step mock_response = EvaluationResponse( is_safe=True, confidence=1.0, - matches=None, - errors=None, non_matches=[ ControlMatch( control_id=1, @@ -357,7 +231,6 @@ async def test_emits_events_when_trace_context_provided(self): ) ], ) - mock_engine = MagicMock() mock_engine.process = AsyncMock(return_value=mock_response) @@ -380,7 +253,7 @@ async def test_emits_events_when_trace_context_provided(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -392,29 +265,32 @@ async def test_emits_events_when_trace_context_provided(self): event_agent_name="test-agent", ) - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][2] is not None # local_controls - assert call_args[0][3] == "abc123" # trace_id - assert call_args[0][4] == "def456" # span_id - assert call_args.kwargs["agent_name"] == "test-agent" - - # Also verify non_matches propagated + mock_enqueue.assert_called_once() + delivered_events = mock_enqueue.call_args.args[0] + assert len(delivered_events) == 1 + assert delivered_events[0].trace_id == "abc123" + assert delivered_events[0].span_id == "def456" assert result.non_matches is not None assert len(result.non_matches) == 1 @pytest.mark.asyncio - async def test_emits_events_without_trace_context(self): - """Should resolve trace context from the provider when IDs are omitted.""" - from agent_control_models import EvaluationResponse, Step + async def test_resolves_provider_trace_context_for_local_events(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step mock_response = EvaluationResponse( - is_safe=True, confidence=1.0, matches=None, errors=None, non_matches=None, + is_safe=True, + confidence=1.0, + non_matches=[ + ControlMatch( + control_id=1, + control_name="test-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.1), + ) + ], ) - mock_engine = MagicMock() mock_engine.process = AsyncMock(return_value=mock_response) - controls = [{ "id": 1, "name": "test-ctrl", @@ -431,35 +307,28 @@ async def test_emits_events_without_trace_context(self): client = MagicMock() client.http_client = AsyncMock() step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider( - lambda: { - "trace_id": "a" * 32, - "span_id": "b" * 16, - } - ) + set_trace_context_provider(lambda: {"trace_id": "a" * 32, "span_id": "b" * 16}) with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", step=step, stage="pre", controls=controls, - # No trace_id/span_id ) - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][3] == "a" * 32 - assert call_args[0][4] == "b" * 16 + + delivered_events = mock_enqueue.call_args.args[0] + assert delivered_events[0].trace_id == "a" * 32 + assert delivered_events[0].span_id == "b" * 16 @pytest.mark.asyncio - async def test_forwards_trace_headers_to_server(self): - """Server POST should include X-Trace-Id and X-Span-Id headers.""" + async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): + """Server POST should resolve trace headers from the provider when omitted.""" from agent_control_models import Step - # Only server controls, no local controls controls = [{ "id": 1, "name": "server-ctrl", @@ -487,6 +356,12 @@ async def test_forwards_trace_headers_to_server(self): client.http_client = AsyncMock() client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "c" * 32, + "span_id": "d" * 16, + } + ) with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): await evaluation.check_evaluation_with_local( @@ -495,68 +370,196 @@ async def test_forwards_trace_headers_to_server(self): step=step, stage="pre", controls=controls, - trace_id="aaaa1111bbbb2222cccc3333dddd4444", - span_id="eeee5555ffff6666", ) - # Verify POST was called with headers call_kwargs = client.http_client.post.call_args headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" - assert headers["X-Span-Id"] == "eeee5555ffff6666" + assert headers["X-Trace-Id"] == "c" * 32 + assert headers["X-Span-Id"] == "d" * 16 + +class TestCheckEvaluation: @pytest.mark.asyncio - async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): - """Server POST should resolve trace headers from the provider when omitted.""" + async def test_default_path_keeps_server_only_behavior(self): from agent_control_models import Step + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": None, + "errors": None, + "non_matches": None, + } + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=False), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit: + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_emit.assert_not_called() + assert result.is_safe is True + assert result.confidence == 0.9 + + @pytest.mark.asyncio + async def test_merged_event_sink_emits_reconstructed_server_events(self): + from agent_control_models import Step + + controls = [ + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + } + ] + + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": [ + { + "control_id": 2, + "control_name": "server-ctrl", + "action": "allow", + "control_execution_id": "ce-server", + "result": {"matched": True, "confidence": 0.4}, + } + ], + "errors": None, + "non_matches": None, + } + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=True), \ + patch("agent_control.evaluation.get_trace_and_span_ids", return_value=("a" * 32, "b" * 16)), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + patch.object(evaluation.state, "server_controls", controls): + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + headers = client.http_client.post.call_args.kwargs["headers"] + assert headers["X-Trace-Id"] == "a" * 32 + assert headers["X-Span-Id"] == "b" * 16 + assert headers["X-Agent-Control-Merge-Events"] == "true" + + mock_emit.assert_called_once() + emitted_events = mock_emit.call_args.args[0] + assert len(emitted_events) == 1 + assert emitted_events[0].control_id == 2 + assert emitted_events[0].trace_id == "a" * 32 + assert emitted_events[0].span_id == "b" * 16 + assert emitted_events[0].metadata["primary_evaluator"] == "regex" + assert result.matches is not None + assert len(result.matches) == 1 + + @pytest.mark.asyncio + async def test_skips_local_event_reconstruction_when_nothing_consumes_events(self): + from agent_control_models import EvaluationResponse, Step + controls = [{ "id": 1, - "name": "server-ctrl", + "name": "local-ctrl", "control": { "condition": { "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "deny"}, - "execution": "server", + "action": {"decision": "allow"}, + "execution": "sdk", }, }] + mock_response = EvaluationResponse(is_safe=True, confidence=1.0) + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=mock_response) + + client = MagicMock() + client.http_client = AsyncMock() + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=False), \ + patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ + patch("agent_control.evaluation.build_control_execution_events") as mock_build, \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + ) + + mock_build.assert_not_called() + mock_enqueue.assert_not_called() + assert result.is_safe is True + assert result.confidence == 1.0 + + @pytest.mark.asyncio + async def test_check_evaluation_falls_back_when_initialized_agent_does_not_match(self): + from agent_control_models import Step + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() mock_http_response.json.return_value = { "is_safe": True, - "confidence": 1.0, + "confidence": 0.9, "matches": None, "errors": None, "non_matches": None, } - mock_http_response.raise_for_status = MagicMock() client = MagicMock() client.http_client = AsyncMock() client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider( - lambda: { - "trace_id": "c" * 32, - "span_id": "d" * 16, - } - ) - with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): - await evaluation.check_evaluation_with_local( + with patch("agent_control.evaluation.has_control_event_sink", return_value=True), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + patch.object(evaluation.state, "current_agent", MagicMock(agent_name="agent-000000000002")), \ + patch.object(evaluation.state, "server_controls", []): + result = await evaluation.check_evaluation( client=client, agent_name="agent-000000000001", step=step, stage="pre", - controls=controls, ) - call_kwargs = client.http_client.post.call_args - headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "c" * 32 - assert headers["X-Span-Id"] == "d" * 16 + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_emit.assert_not_called() + assert result.is_safe is True + assert result.confidence == 0.9 # ============================================================================= @@ -568,45 +571,97 @@ class TestControlDecoratorsNonMatches: """Tests for non_matches dict conversion in control_decorators._evaluate.""" @pytest.mark.asyncio - async def test_non_matches_populated_in_stats(self): - """non_matches should be properly converted to dicts for stats tracking.""" - from agent_control.control_decorators import ControlContext + async def test_merged_event_sink_emits_reconstructed_local_and_server_events_once(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step - # Simulate a result dict with non_matches - result = { + local_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + matches=[ + ControlMatch( + control_id=1, + control_name="local-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.8), + ) + ], + ) + server_response = { "is_safe": True, - "confidence": 1.0, - "matches": None, - "errors": None, - "non_matches": [ - { - "control_id": 1, - "control_name": "ctrl-1", - "action": "allow", - "result": {"matched": False, "confidence": 0.1}, - }, + "confidence": 0.9, + "matches": [ { "control_id": 2, - "control_name": "ctrl-2", - "action": "deny", - "result": {"matched": False, "confidence": 0.2}, - }, + "control_name": "server-ctrl", + "action": "allow", + "control_execution_id": "ce-server", + "result": {"matched": False, "confidence": 0.4}, + } ], + "errors": None, + "non_matches": None, } - ctx = ControlContext( - agent_name="test-agent", - server_url="http://localhost:8000", - func=lambda: None, - args=(), - kwargs={}, - trace_id="trace123", - span_id="span456", - start_time=0, - ) + controls = [ + { + "id": 1, + "name": "local-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "sdk", + }, + }, + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + }, + ] - ctx._update_stats(result) - assert ctx.total_executions == 2 - assert ctx.total_non_matches == 2 - assert ctx.total_matches == 0 - assert ctx.total_errors == 0 + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=local_response) + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = server_response + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation._is_merged_event_mode_enabled", return_value=True), \ + patch("agent_control.evaluation.emit_control_events") as mock_emit, \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + trace_id="abc123", + span_id="def456", + event_agent_name="test-agent", + ) + + mock_enqueue.assert_not_called() + mock_emit.assert_called_once() + merged_events = mock_emit.call_args.args[0] + assert len(merged_events) == 2 + assert {event.control_id for event in merged_events} == {1, 2} + headers = client.http_client.post.call_args.kwargs["headers"] + assert headers["X-Agent-Control-Merge-Events"] == "true" + assert result.matches is not None + assert len(result.matches) == 2 diff --git a/sdks/python/tests/test_shutdown.py b/sdks/python/tests/test_shutdown.py index 073745ed..1ee128d9 100644 --- a/sdks/python/tests/test_shutdown.py +++ b/sdks/python/tests/test_shutdown.py @@ -15,6 +15,7 @@ import agent_control.observability as obs_mod from agent_control._state import state from agent_control.observability import EventBatcher +from agent_control.telemetry.event_sink import has_control_event_sink, set_control_event_sink def _make_started_batcher() -> EventBatcher: @@ -64,6 +65,7 @@ def test_shutdown_resets_state(self): state.server_controls = [{"name": "test"}] state.server_url = "http://localhost:8000" state.api_key = "key" + set_control_event_sink(lambda events: None) agent_control.shutdown() @@ -72,6 +74,7 @@ def test_shutdown_resets_state(self): assert state.server_controls is None assert state.server_url is None assert state.api_key is None + assert has_control_event_sink() is False def test_shutdown_idempotent(self): agent_control.shutdown() diff --git a/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts b/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts index 47ec39e4..81862c54 100644 --- a/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts +++ b/sdks/typescript/src/generated/funcs/evaluation-evaluate.ts @@ -109,6 +109,11 @@ async function $do( const headers = new Headers(compactMap({ "Content-Type": "application/json", Accept: "application/json", + "X-Agent-Control-Merge-Events": encodeSimple( + "X-Agent-Control-Merge-Events", + payload["X-Agent-Control-Merge-Events"], + { explode: false, charEncoding: "none" }, + ), "X-Span-Id": encodeSimple("X-Span-Id", payload["X-Span-Id"], { explode: false, charEncoding: "none", diff --git a/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts b/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts index 026e4065..204841a6 100644 --- a/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts +++ b/sdks/typescript/src/generated/models/operations/evaluate-api-v1-evaluation-post.ts @@ -9,6 +9,7 @@ import * as models from "../index.js"; export type EvaluateApiV1EvaluationPostRequest = { xTraceId?: string | null | undefined; xSpanId?: string | null | undefined; + xAgentControlMergeEvents?: string | null | undefined; body: models.EvaluationRequest; }; @@ -16,6 +17,7 @@ export type EvaluateApiV1EvaluationPostRequest = { export type EvaluateApiV1EvaluationPostRequest$Outbound = { "X-Trace-Id"?: string | null | undefined; "X-Span-Id"?: string | null | undefined; + "X-Agent-Control-Merge-Events"?: string | null | undefined; body: models.EvaluationRequest$Outbound; }; @@ -27,12 +29,14 @@ export const EvaluateApiV1EvaluationPostRequest$outboundSchema: z.ZodMiniType< z.object({ xTraceId: z.optional(z.nullable(z.string())), xSpanId: z.optional(z.nullable(z.string())), + xAgentControlMergeEvents: z.optional(z.nullable(z.string())), body: models.EvaluationRequest$outboundSchema, }), z.transform((v) => { return remap$(v, { xTraceId: "X-Trace-Id", xSpanId: "X-Span-Id", + xAgentControlMergeEvents: "X-Agent-Control-Merge-Events", }); }), ); diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index c92ea315..68a315b4 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -156,6 +156,7 @@ async def evaluate( db: AsyncSession = Depends(get_async_db), x_trace_id: str | None = Header(default=None, alias="X-Trace-Id"), x_span_id: str | None = Header(default=None, alias="X-Span-Id"), + x_merge_events: str | None = Header(default=None, alias="X-Agent-Control-Merge-Events"), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -238,30 +239,32 @@ async def evaluate( # Calculate total execution time total_duration_ms = (time.perf_counter() - start_time) * 1000 - # Emit observability events if enabled - if observability_settings.enabled: + merge_events_requested = (x_merge_events or "").lower() == "true" + response_events = _build_observability_events( + response=raw_response, + request=request, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + applies_to=applies_to, + control_lookup=control_lookup, + total_duration_ms=total_duration_ms, + ) + + # OSS keeps server-side ingestion as the default. Enterprise merged mode + # returns events to the SDK and skips this server-side delivery step. + if observability_settings.enabled and not merge_events_requested: # Get ingestor from app.state (None if not initialized) try: ingestor = get_event_ingestor(req) except RuntimeError: ingestor = None - - await _emit_observability_events( - response=raw_response, - request=request, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - applies_to=applies_to, - control_lookup=control_lookup, - total_duration_ms=total_duration_ms, - ingestor=ingestor, - ) + await _ingest_observability_events(response_events, ingestor) return _sanitize_evaluation_response(raw_response) -async def _emit_observability_events( +def _build_observability_events( response: EvaluationResponse, request: EvaluationRequest, trace_id: str, @@ -270,12 +273,25 @@ async def _emit_observability_events( applies_to: Literal["llm_call", "tool_call"], control_lookup: dict, total_duration_ms: float, - ingestor: EventIngestor | None, -) -> None: - """Create and enqueue observability events for all evaluated controls. - - Uses control_execution_id from the engine response to ensure correlation - between SDK logs and server observability events. + ) -> list[ControlExecutionEvent]: + """Build observability events for all evaluated controls. + + This preserves the existing server-side event shape while allowing the + merged-event path to skip server-side ingestion and keep the response + lightweight. + + Args: + response: Raw evaluation response from the engine. + request: Original evaluation request. + trace_id: Trace ID to stamp on emitted events. + span_id: Span ID to stamp on emitted events. + agent_name: Agent name to stamp on emitted events. + applies_to: Observability applies_to value derived from the step type. + control_lookup: Controls keyed by control ID. + total_duration_ms: Total request execution duration in milliseconds. + + Returns: + A list of reconstructed server-side control execution events. """ events: list[ControlExecutionEvent] = [] now = datetime.now(UTC) @@ -379,11 +395,45 @@ async def _emit_observability_events( ) ) - # Ingest events - if events and ingestor: - result = await ingestor.ingest(events) - if result.dropped > 0: - _logger.warning( - f"Dropped {result.dropped} observability events, " - f"processed {result.processed}" - ) + return events + + +async def _ingest_observability_events( + events: list[ControlExecutionEvent], + ingestor: EventIngestor | None, +) -> None: + """Ingest server-side observability events when OSS batching is active.""" + if not events or ingestor is None: + return + + result = await ingestor.ingest(events) + if result.dropped > 0: + _logger.warning( + f"Dropped {result.dropped} observability events, " + f"processed {result.processed}" + ) + + +async def _emit_observability_events( + response: EvaluationResponse, + request: EvaluationRequest, + trace_id: str, + span_id: str, + agent_name: str, + applies_to: Literal["llm_call", "tool_call"], + control_lookup: dict, + total_duration_ms: float, + ingestor: EventIngestor | None, +) -> None: + """Backward-compatible wrapper around build + ingest observability helpers.""" + events = _build_observability_events( + response=response, + request=request, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + applies_to=applies_to, + control_lookup=control_lookup, + total_duration_ms=total_duration_ms, + ) + await _ingest_observability_events(events, ingestor) diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index 942dca66..dd2035eb 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -3,7 +3,13 @@ import uuid from unittest.mock import AsyncMock, MagicMock -from agent_control_models import ControlMatch, EvaluationRequest, EvaluatorResult, Step +from agent_control_models import ( + ControlExecutionEvent, + ControlMatch, + EvaluationRequest, + EvaluatorResult, + Step, +) from fastapi.testclient import TestClient from agent_control_server.endpoints.evaluation import ( @@ -198,8 +204,10 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani lambda _config: mock_evaluator, ) - emit_mock = AsyncMock() - monkeypatch.setattr(evaluation_module, "_emit_observability_events", emit_mock) + build_mock = MagicMock(return_value=[]) + ingest_mock = AsyncMock() + monkeypatch.setattr(evaluation_module, "_build_observability_events", build_mock) + monkeypatch.setattr(evaluation_module, "_ingest_observability_events", ingest_mock) monkeypatch.setattr(evaluation_module.observability_settings, "enabled", True) # When: sending an evaluation request @@ -220,8 +228,8 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani assert data["errors"][0]["result"]["error"] == SAFE_EVALUATOR_ERROR # And: observability receives the raw engine response with unsanitized diagnostics - emit_mock.assert_awaited_once() - raw_response = emit_mock.await_args.kwargs["response"] + build_mock.assert_called_once() + raw_response = build_mock.call_args.kwargs["response"] assert raw_response.errors is not None raw_error = raw_response.errors[0] assert raw_error.control_name == control_name @@ -229,6 +237,7 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani raw_trace = raw_error.result.metadata["condition_trace"] assert raw_trace["error"] == "RuntimeError: Simulated evaluator crash" assert raw_trace["message"] == "Evaluation failed: RuntimeError: Simulated evaluator crash" + ingest_mock.assert_awaited_once() def test_sanitize_control_match_redacts_nested_condition_trace_errors() -> None: @@ -372,3 +381,43 @@ async def ingest(self, events): # type: ignore[no-untyped-def] del app.state.event_ingestor else: app.state.event_ingestor = previous_ingestor + + +def test_evaluation_skips_ingest_for_merge_mode( + client: TestClient, monkeypatch +) -> None: + """Merged-event mode should skip server-side observability ingestion.""" + agent_name, _ = create_and_assign_policy(client) + + import agent_control_server.endpoints.evaluation as evaluation_module + + event = ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name=agent_name, + control_id=1, + control_name="test-control", + check_stage="pre", + applies_to="llm_call", + action="deny", + matched=True, + confidence=0.9, + ) + build_mock = MagicMock(return_value=[event]) + ingest_mock = AsyncMock() + monkeypatch.setattr(evaluation_module, "_build_observability_events", build_mock) + monkeypatch.setattr(evaluation_module, "_ingest_observability_events", ingest_mock) + monkeypatch.setattr(evaluation_module.observability_settings, "enabled", True) + + payload = Step(type="llm", name="test-step", input="x", output=None) + req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") + resp = client.post( + "/api/v1/evaluation", + json=req.model_dump(mode="json"), + headers={"X-Agent-Control-Merge-Events": "true"}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert "events" not in body + ingest_mock.assert_not_awaited()