diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md index 6ead3c39d58d..7745922fba25 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md @@ -4,6 +4,20 @@ ### Features Added +- `invocations_ws` (WebSocket) protocol support on `InvocationAgentServerHost`. + Register a handler with the new `@app.ws_handler` decorator to host a + full-duplex WebSocket endpoint at `/invocations_ws` on the same host that + serves `POST /invocations`. The SDK calls `await websocket.accept()` before + invoking the handler, runs WebSocket Ping/Pong keep-alive in the background + (default 30 s; configurable via the new `ws_ping_interval` constructor + argument), closes the connection cleanly on handler return, and maps + uncaught exceptions to RFC 6455 close code `1011`. Each connection emits a + structured close-event log line carrying `ws.session_id`, `ws.close_code`, + and `ws.duration_ms`, and the same fields are recorded as OpenTelemetry + span attributes. `/readiness`, OTEL export, graceful shutdown, and the + `x-platform-server` identity header continue to be inherited from + `azure-ai-agentserver-core`. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/README.md b/sdk/agentserver/azure-ai-agentserver-invocations/README.md index 5e9dfe515657..63ec794bbc8f 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/README.md +++ b/sdk/agentserver/azure-ai-agentserver-invocations/README.md @@ -1,6 +1,9 @@ # Azure AI Agent Server Invocations client library for Python -The `azure-ai-agentserver-invocations` package provides the invocation protocol endpoints for Azure AI Hosted Agent containers. It plugs into the [`azure-ai-agentserver-core`](https://pypi.org/project/azure-ai-agentserver-core/) host framework and adds the full invocation lifecycle: `POST /invocations`, `GET /invocations/{id}`, `POST /invocations/{id}/cancel`, and `GET /invocations/docs/openapi.json`. +The `azure-ai-agentserver-invocations` package provides the invocation protocol endpoints for Azure AI Hosted Agent containers. It plugs into the [`azure-ai-agentserver-core`](https://pypi.org/project/azure-ai-agentserver-core/) host framework and supports two transports on the same host: + +- **HTTP** (`invocations` protocol) — `POST /invocations`, `GET /invocations/{id}`, `POST /invocations/{id}/cancel`, `GET /invocations/docs/openapi.json`. +- **WebSocket** (`invocations_ws` protocol) — full-duplex streaming at `/invocations_ws`, registered with `@app.ws_handler`. ## Getting started @@ -25,6 +28,7 @@ This automatically installs `azure-ai-agentserver-core` as a dependency. - `@app.invoke_handler` — **Required.** Handles `POST /invocations`. - `@app.get_invocation_handler` — Optional. Handles `GET /invocations/{id}`. - `@app.cancel_invocation_handler` — Optional. Handles `POST /invocations/{id}/cancel`. +- `@app.ws_handler` — Optional. Handles WebSocket connections at `/invocations_ws`. ### Protocol endpoints @@ -34,6 +38,7 @@ This automatically installs `azure-ai-agentserver-core` as a dependency. | `GET` | `/invocations/{invocation_id}` | No | Retrieve invocation status or result | | `POST` | `/invocations/{invocation_id}/cancel` | No | Cancel a running invocation | | `GET` | `/invocations/docs/openapi.json` | No | Serve the agent's OpenAPI 3.x spec | +| `WS` | `/invocations_ws` | No | Full-duplex WebSocket transport (`invocations_ws` protocol) | ### Request and response headers @@ -182,6 +187,57 @@ app = InvocationAgentServerHost(openapi_spec={ }) ``` +## WebSocket protocol (`invocations_ws`) + +The same `InvocationAgentServerHost` object also exposes a WebSocket transport at `/invocations_ws`. Container authors do not install or import a second package — registering an `@app.ws_handler` is the only step. A multi-protocol agent shares one host, one session, and one process. + +### Quick start + +```python +from azure.ai.agentserver.invocations import InvocationAgentServerHost +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.websockets import WebSocket + +app = InvocationAgentServerHost() + + +@app.invoke_handler # POST /invocations (HTTP) +async def invoke(request: Request) -> Response: + payload = await request.json() + return JSONResponse({"echo": payload}) + + +@app.ws_handler # /invocations_ws (WebSocket) +async def ws(websocket: WebSocket) -> None: + async for message in websocket.iter_text(): + await websocket.send_text(message) + + +app.run() +``` + +### What the SDK does for `@app.ws_handler` + +- Registers `/invocations_ws` on the same Starlette host as `/invocations` and `/readiness`. +- Calls `await websocket.accept()` before invoking your handler. +- Runs WebSocket Ping/Pong keep-alive in the background — default 30 s, configurable via `InvocationAgentServerHost(ws_ping_interval=...)`. Set `ws_ping_interval=0` to disable. Frames are sent at the WebSocket protocol layer (RFC 6455 opcode `0x9`/`0xA`) by the underlying Hypercorn server, which keeps the connection alive across Azure APIM and Azure Load Balancer's ~4 minute idle timeout without any extra application traffic. +- Closes the connection cleanly on handler return (close code `1000`) or maps an uncaught handler exception to close code `1011`. +- Emits a structured close-event log line carrying `ws.session_id`, `ws.close_code`, and `ws.duration_ms`. The same fields are recorded as OpenTelemetry span attributes so the connection lifetime is visible end-to-end. +- Inherits `/readiness`, OpenTelemetry export, graceful shutdown, and the `x-platform-server` identity header from `azure-ai-agentserver-core`. + +### Handler signature + +The handler receives a Starlette [`WebSocket`][starlette-ws] and returns `None`. The full WebSocket API — `iter_text`, `iter_bytes`, `iter_json`, `send_text`, `send_bytes`, `send_json`, `close`, `headers`, `query_params`, `client`, `state` — is available, so application protocols on top of `invocations_ws` are entirely under your control. + +[starlette-ws]: https://www.starlette.io/websockets/ + +### Reference: configuration + +| Constructor argument | Default | Description | +|---|---|---| +| `ws_ping_interval` | `30.0` (seconds) | WebSocket protocol Ping interval. `0` disables keep-alive. Negative or non-finite values are rejected. | + ## Troubleshooting ### Reporting issues @@ -196,6 +252,8 @@ Visit the [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ |---|---| | [simple_invoke_agent](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-invocations/samples/simple_invoke_agent/) | Minimal synchronous request-response | | [async_invoke_agent](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-invocations/samples/async_invoke_agent/) | Long-running operations with polling and cancellation | +| [ws_invoke_agent](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/) | Combined `POST /invocations` (HTTP) and `/invocations_ws` (WebSocket) host | +| [ws_bidirectional_streaming_agent](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/) | Full-duplex `/invocations_ws` agent: server-pushed heartbeats + concurrent token streams + mid-flight cancel | ## Contributing diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_constants.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_constants.py index 62f8600a44bd..611772000efc 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_constants.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_constants.py @@ -19,3 +19,31 @@ class InvocationConstants: ATTR_SPAN_SESSION_ID = "azure.ai.agentserver.invocations.session_id" ATTR_SPAN_ERROR_CODE = "azure.ai.agentserver.invocations.error.code" ATTR_SPAN_ERROR_MESSAGE = "azure.ai.agentserver.invocations.error.message" + + +class InvocationsWSConstants: + """invocations_ws (WebSocket) protocol constants. + + Route, span attribute keys, and ping/pong defaults for the + WebSocket endpoint hosted alongside the HTTP invocations protocol. + """ + + # Route + ROUTE_PATH = "/invocations_ws" + + # Default WebSocket Ping interval in seconds. + # Azure APIM and Azure Load Balancer drop idle WebSocket connections + # after ~4 minutes; 30 s gives a comfortable safety margin. + DEFAULT_PING_INTERVAL_S = 30.0 + + # Close codes (RFC 6455) + CLOSE_NORMAL = 1000 # handler returned cleanly + CLOSE_INTERNAL_ERROR = 1011 # handler raised an unhandled exception + CLOSE_SERVICE_RESTART = 1012 # graceful shutdown drained the connection + + # Span attribute keys + ATTR_SPAN_SESSION_ID = "ws.session_id" + ATTR_SPAN_CLOSE_CODE = "ws.close_code" + ATTR_SPAN_DURATION_MS = "ws.duration_ms" + ATTR_SPAN_ERROR_CODE = "azure.ai.agentserver.invocations_ws.error.code" + ATTR_SPAN_ERROR_MESSAGE = "azure.ai.agentserver.invocations_ws.error.message" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py index bf3120974fa0..8c706e55e5b4 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py @@ -31,6 +31,7 @@ ) from ._constants import InvocationConstants +from ._invocation_ws import _WSHandlerMixin logger = logging.getLogger("azure.ai.agentserver") @@ -93,13 +94,18 @@ def _sanitize_id(value: str, fallback: str) -> str: return value -class InvocationAgentServerHost(AgentServerHost): +class InvocationAgentServerHost(_WSHandlerMixin, AgentServerHost): """Invocation protocol host for Azure AI Hosted Agents. A :class:`~azure.ai.agentserver.core.AgentServerHost` subclass that adds the invocation protocol endpoints. Use the decorator methods to wire handler functions to the endpoints. + The same host object also exposes the ``invocations_ws`` (WebSocket) + transport at :data:`/invocations_ws` — register a handler with the + :meth:`ws_handler` decorator. Multi-protocol agents share a single + host, session, and process. + For multi-protocol agents, compose via cooperative inheritance:: class MyHost(InvocationAgentServerHost, ResponsesAgentServerHost): @@ -108,18 +114,29 @@ class MyHost(InvocationAgentServerHost, ResponsesAgentServerHost): Usage:: from azure.ai.agentserver.invocations import InvocationAgentServerHost + from starlette.websockets import WebSocket app = InvocationAgentServerHost() - @app.invoke_handler + @app.invoke_handler # POST /invocations async def handle(request): return JSONResponse({"ok": True}) + @app.ws_handler # /invocations_ws + async def ws(websocket: WebSocket) -> None: + async for message in websocket.iter_text(): + await websocket.send_text(message) + app.run() :param openapi_spec: Optional OpenAPI spec dict. When provided, the spec is served at ``GET /invocations/docs/openapi.json``. :type openapi_spec: Optional[dict[str, Any]] + :param ws_ping_interval: Seconds between WebSocket protocol Ping frames + on ``/invocations_ws``. ``None`` (default) selects 30 s; ``0`` + disables keep-alive. Configured on the underlying Hypercorn + server so the framing is opcode 0x9 / 0xA, not application JSON. + :type ws_ping_interval: Optional[float] """ _INSTRUMENTATION_SCOPE = "Azure.AI.AgentServer.Invocations" @@ -128,6 +145,7 @@ def __init__( self, *, openapi_spec: Optional[dict[str, Any]] = None, + ws_ping_interval: Optional[float] = None, **kwargs: Any, ) -> None: self._invoke_fn: Optional[Callable] = None @@ -135,8 +153,11 @@ def __init__( self._cancel_invocation_fn: Optional[Callable] = None self._openapi_spec = openapi_spec + # Initialise WS handler slots (raises ValueError on a bad interval). + self._init_ws_state(ws_ping_interval) + # Build invocation routes and pass to parent via routes kwarg - invocation_routes = [ + invocation_routes: list[Any] = [ Route( "/invocations/docs/openapi.json", self._get_openapi_spec_endpoint, @@ -161,6 +182,7 @@ def __init__( methods=["POST"], name="cancel_invocation", ), + self._build_ws_route(self._ws_endpoint), ] # Merge with any routes from sibling mixins via cooperative init @@ -169,10 +191,53 @@ def __init__( # --- Invocations startup configuration logging --- logger.info( - "Invocations protocol: openapi_spec_configured=%s", + "Invocations protocol: openapi_spec_configured=%s, " + "ws_ping_interval=%s", self._openapi_spec is not None, + ( + "disabled" + if self._ws_ping_interval == 0 + else f"{self._ws_ping_interval}s" + ), ) + # ------------------------------------------------------------------ + # Hypercorn server config (WebSocket Ping/Pong keep-alive) + # ------------------------------------------------------------------ + + def _build_hypercorn_config(self, host: str, port: int) -> object: + """Extend the base Hypercorn config with the WebSocket Ping interval. + + Hypercorn sends WS protocol Ping frames every + ``websocket_ping_interval`` seconds on every active WebSocket + connection — exactly the keep-alive the ``invocations_ws`` spec + requires. ``ws_ping_interval=0`` leaves the default + ``None`` (disabled). + + :param host: Network interface to bind. + :type host: str + :param port: Port to bind. + :type port: int + :return: The configured Hypercorn config. + :rtype: hypercorn.config.Config + """ + config = super()._build_hypercorn_config(host, port) + if self._ws_ping_interval and self._ws_ping_interval > 0: + try: + # ``websocket_ping_interval`` is a float-or-None on + # Hypercorn ≥0.14; assigning a positive float enables + # protocol-level Ping frames. + config.websocket_ping_interval = self._ws_ping_interval # type: ignore[attr-defined] + except Exception: # pylint: disable=broad-exception-caught + # Hypercorn <0.14 does not support per-server WS ping — + # leave the default and warn so operators can upgrade. + logger.warning( + "Hypercorn does not support websocket_ping_interval; " + "WebSocket keep-alive will be best-effort.", + exc_info=True, + ) + return config + # ------------------------------------------------------------------ # Handler decorators # ------------------------------------------------------------------ diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation_ws.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation_ws.py new file mode 100644 index 000000000000..7d3230ad4d1b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation_ws.py @@ -0,0 +1,465 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""invocations_ws (WebSocket) protocol support for ``InvocationAgentServerHost``. + +Implements the ``@app.ws_handler`` decorator and the ``/invocations_ws`` +route described in the ``invocations_ws`` protocol spec. The SDK wraps +the user handler with: + +* ``await websocket.accept()`` before the handler runs; +* WebSocket protocol-level Ping/Pong keep-alive (default 30 s) so idle + connections survive Azure APIM / Azure Load Balancer's ~4-minute idle + timeout; +* a clean close on handler return (code 1000) or a 1011 close on uncaught + handler exceptions; +* a structured close-event log line and OTel span attributes carrying + ``ws.session_id``, ``ws.close_code``, and ``ws.duration_ms``. +""" +from __future__ import annotations + +import inspect +import logging +import math +import time +import uuid +from collections.abc import Awaitable, Callable +from typing import Any, Optional + +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState + +from azure.ai.agentserver.core import end_span, record_error # pylint: disable=no-name-in-module + +from ._constants import InvocationsWSConstants + +logger = logging.getLogger("azure.ai.agentserver") + + +WSHandler = Callable[[WebSocket], Awaitable[None]] + + +def _safe_set_attrs(span: Any, attrs: dict[str, Any]) -> None: + """Best-effort ``span.set_attribute`` for a batch of attributes. + + :param span: The OTel span (or ``None`` when tracing is disabled). + :type span: any + :param attrs: Mapping of attribute keys to values. + :type attrs: dict[str, any] + """ + if span is None: + return + try: + for key, value in attrs.items(): + span.set_attribute(key, value) + except Exception: # pylint: disable=broad-exception-caught + logger.debug("Failed to set WS span attributes: %s", list(attrs.keys()), exc_info=True) + + +class _WSHandlerMixin: + """Mixin that adds the ``@app.ws_handler`` decorator and ``/invocations_ws`` route. + + Designed to be mixed into :class:`InvocationAgentServerHost` so the same + host object exposes both ``POST /invocations`` (HTTP) and + ``/invocations_ws`` (WebSocket) on the same Starlette application. + + Subclasses must: + + 1. Inherit from :class:`~azure.ai.agentserver.core.AgentServerHost` + (transitively) so ``self.config`` and ``self.request_span`` are + available. + 2. Append the route returned by :meth:`_build_ws_route` to their + ``routes`` list before calling ``super().__init__``. + """ + + # Slots populated by __init__. + _ws_fn: Optional[WSHandler] + _ws_ping_interval: float + + def _init_ws_state(self, ws_ping_interval: Optional[float]) -> None: + """Initialize WS handler slots. + + :param ws_ping_interval: Seconds between WS protocol Ping frames. + ``None`` selects the default (30 s); ``0`` disables keep-alive. + :type ws_ping_interval: Optional[float] + :raises ValueError: If *ws_ping_interval* is negative or non-finite. + """ + self._ws_fn = None + if ws_ping_interval is None: + resolved = InvocationsWSConstants.DEFAULT_PING_INTERVAL_S + else: + try: + resolved = float(ws_ping_interval) + except (TypeError, ValueError) as exc: + raise ValueError( + f"ws_ping_interval must be a number, got {ws_ping_interval!r}" + ) from exc + # Reject negative / NaN / inf — those are programming errors that + # would silently mis-configure the keep-alive. + if math.isnan(resolved) or math.isinf(resolved) or resolved < 0.0: + raise ValueError( + f"ws_ping_interval must be a non-negative finite number, " + f"got {ws_ping_interval!r}" + ) + self._ws_ping_interval = resolved + + # ------------------------------------------------------------------ + # Public configuration accessor + # ------------------------------------------------------------------ + + @property + def ws_ping_interval(self) -> float: + """Configured WebSocket Ping interval in seconds (``0`` = disabled). + + :return: The configured interval, or ``0`` when keep-alive is disabled. + :rtype: float + """ + return self._ws_ping_interval + + # ------------------------------------------------------------------ + # Decorator + # ------------------------------------------------------------------ + + def ws_handler(self, fn: WSHandler) -> WSHandler: + """Register an async function as the ``/invocations_ws`` handler. + + The SDK calls ``await websocket.accept()`` before invoking *fn* and + cleanly closes the connection on return (code 1000) or maps an + uncaught exception to close code 1011. + + Usage:: + + from starlette.websockets import WebSocket + + @app.ws_handler + async def handle(websocket: WebSocket) -> None: + async for msg in websocket.iter_text(): + await websocket.send_text(msg) + + :param fn: An async function accepting a Starlette + :class:`~starlette.websockets.WebSocket` and returning ``None``. + :type fn: Callable[[WebSocket], Awaitable[None]] + :return: The original function (unmodified). + :rtype: Callable[[WebSocket], Awaitable[None]] + :raises TypeError: If *fn* is not an ``async def`` function. + """ + if not inspect.iscoroutinefunction(fn): + raise TypeError( + f"ws_handler expects an async function, got {type(fn).__name__}. " + "Use 'async def' to define your handler." + ) + self._ws_fn = fn + return fn + + # ------------------------------------------------------------------ + # Route factory (for cooperative __init__) + # ------------------------------------------------------------------ + + @staticmethod + def _build_ws_route(endpoint: Callable[[WebSocket], Awaitable[None]]) -> Any: + """Return a :class:`~starlette.routing.WebSocketRoute` for ``/invocations_ws``. + + Imported lazily to avoid hard-coding the route type in the public + module body and keep the import surface symmetric with the HTTP + ``Route`` import in :mod:`._invocation`. + + :param endpoint: The async endpoint to wire to the route. + :type endpoint: Callable[[WebSocket], Awaitable[None]] + :return: A configured ``WebSocketRoute`` instance. + :rtype: ~starlette.routing.WebSocketRoute + """ + from starlette.routing import WebSocketRoute # pylint: disable=import-outside-toplevel + + return WebSocketRoute( + InvocationsWSConstants.ROUTE_PATH, + endpoint, + name="invocations_ws", + ) + + # ------------------------------------------------------------------ + # Endpoint + # ------------------------------------------------------------------ + + async def _ws_endpoint(self, websocket: WebSocket) -> None: + """ASGI endpoint for ``/invocations_ws``. + + Wraps the user-registered handler with: accept, span lifecycle, + graceful close on success, 1011 close on failure, and a structured + close event log + span attributes. + + :param websocket: The incoming Starlette WebSocket. + :type websocket: ~starlette.websockets.WebSocket + """ + # Per-connection identifiers. Session ID is generated server-side; + # the spec carries it in the close-event metric so an operator can + # correlate logs/metrics for a given long-lived connection. + session_id = str(uuid.uuid4()) + start_ns = time.monotonic_ns() + + # Open the OTel span before accepting so any tracecontext header + # the client sent is honoured for parenting child spans inside the + # user handler. ``end_on_exit=False`` so we can attach the close + # code / duration before ending. + span_ctx = self.request_span( # type: ignore[attr-defined] + websocket.headers, + session_id, + "websocket_session", + operation_name="websocket_session", + session_id=session_id, + end_on_exit=False, + ) + otel_span = span_ctx.__enter__() + _safe_set_attrs(otel_span, {InvocationsWSConstants.ATTR_SPAN_SESSION_ID: session_id}) + + if self._ws_fn is None: + await self._reject_no_handler(websocket, span_ctx, otel_span, session_id, start_ns) + return + + # Accept the upgrade *before* invoking the user handler — per spec. + try: + await websocket.accept() + except Exception as exc: # pylint: disable=broad-exception-caught + await self._finalize_session( + websocket=None, + span_ctx=span_ctx, + otel_span=otel_span, + session_id=session_id, + start_ns=start_ns, + close_code=InvocationsWSConstants.CLOSE_INTERNAL_ERROR, + handler_exc=exc, + error_code="accept_failed", + ) + logger.error( + "WebSocket accept failed for session %s: %s", + session_id, exc, exc_info=True, + ) + return + + close_code, handler_exc = await self._invoke_user_handler(websocket, session_id) + await self._finalize_session( + websocket=websocket, + span_ctx=span_ctx, + otel_span=otel_span, + session_id=session_id, + start_ns=start_ns, + close_code=close_code, + handler_exc=handler_exc, + error_code="internal_error" if handler_exc is not None else None, + ) + + async def _invoke_user_handler( + self, websocket: WebSocket, session_id: str, + ) -> tuple[int, Optional[BaseException]]: + """Run the registered user handler and classify the outcome. + + :param websocket: The accepted WebSocket to pass to the handler. + :type websocket: ~starlette.websockets.WebSocket + :param session_id: Per-connection session ID for diagnostic logs. + :type session_id: str + :return: ``(close_code, exception_or_None)``. ``close_code`` is the + RFC 6455 code that should be sent to the client; ``exception`` + is set only for an *unhandled* exception (so the caller can map + it to span error events and a 1011 close). + :rtype: tuple[int, Optional[BaseException]] + """ + assert self._ws_fn is not None # checked by caller + try: + await self._ws_fn(websocket) + return InvocationsWSConstants.CLOSE_NORMAL, None + except WebSocketDisconnect as exc: + # Client (or proxy) closed first — surface their code, not 1011. + return ( + int(exc.code) if exc.code else InvocationsWSConstants.CLOSE_NORMAL, + None, + ) + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error( + "WebSocket handler raised for session %s: %s", + session_id, exc, exc_info=True, + ) + return InvocationsWSConstants.CLOSE_INTERNAL_ERROR, exc + + async def _reject_no_handler( + self, + websocket: WebSocket, + span_ctx: Any, + otel_span: Any, + session_id: str, + start_ns: int, + ) -> None: + """Refuse a WS upgrade when no ``@ws_handler`` is registered. + + :param websocket: The pending WebSocket awaiting upgrade. + :type websocket: ~starlette.websockets.WebSocket + :param span_ctx: The active ``request_span`` context manager. + :type span_ctx: any + :param otel_span: The current OTel span (or ``None``). + :type otel_span: any + :param session_id: Per-connection session ID. + :type session_id: str + :param start_ns: ``time.monotonic_ns()`` at connection start. + :type start_ns: int + """ + logger.error( + "WebSocket connection on %s rejected: no @ws_handler registered", + InvocationsWSConstants.ROUTE_PATH, + ) + duration_ms = (time.monotonic_ns() - start_ns) // 1_000_000 + self._emit_close_event( + otel_span, + session_id, + InvocationsWSConstants.CLOSE_INTERNAL_ERROR, + duration_ms, + error_code="not_implemented", + error_message="No ws_handler registered.", + ) + try: + span_ctx.__exit__(None, None, None) + finally: + end_span(otel_span) + await websocket.close( + code=InvocationsWSConstants.CLOSE_INTERNAL_ERROR, + reason="No ws_handler registered", + ) + + async def _finalize_session( + self, + *, + websocket: Optional[WebSocket], + span_ctx: Any, + otel_span: Any, + session_id: str, + start_ns: int, + close_code: int, + handler_exc: Optional[BaseException], + error_code: Optional[str], + ) -> None: + """Close the WS (best-effort), emit metrics, and end the span. + + Called from both the success path and the accept-failure path. + + :keyword websocket: The connected WebSocket, or ``None`` when the + ASGI ``accept`` itself failed (no socket to close). + :paramtype websocket: Optional[~starlette.websockets.WebSocket] + :keyword span_ctx: The active ``request_span`` context manager. + :paramtype span_ctx: any + :keyword otel_span: The current OTel span (or ``None`` when tracing is off). + :paramtype otel_span: any + :keyword session_id: Per-connection session ID. + :paramtype session_id: str + :keyword start_ns: ``time.monotonic_ns()`` at connection start. + :paramtype start_ns: int + :keyword close_code: The RFC 6455 code to report to the client. + :paramtype close_code: int + :keyword handler_exc: Unhandled exception raised by the user handler, + or ``None`` for a clean close. + :paramtype handler_exc: Optional[BaseException] + :keyword error_code: Short error tag for span / log; ``None`` for success. + :paramtype error_code: Optional[str] + """ + duration_ms = (time.monotonic_ns() - start_ns) // 1_000_000 + + # Best-effort clean close: only send a close frame if the + # application hasn't already done so (e.g. the user handler + # may have called ``websocket.close`` itself, or the client + # may have disconnected). + if ( + websocket is not None + and websocket.application_state != WebSocketState.DISCONNECTED + ): + reason = ( + "Internal server error" + if close_code == InvocationsWSConstants.CLOSE_INTERNAL_ERROR + else "" + ) + try: + await websocket.close(code=close_code, reason=reason) + except Exception: # pylint: disable=broad-exception-caught + # Connection already gone — nothing to recover here. + logger.debug( + "Error closing WebSocket session %s", session_id, exc_info=True, + ) + + self._emit_close_event( + otel_span, + session_id, + close_code, + duration_ms, + error_code=error_code, + error_message=str(handler_exc) if handler_exc is not None else None, + ) + + if handler_exc is not None: + try: + record_error(otel_span, handler_exc) + finally: + try: + span_ctx.__exit__( + type(handler_exc), handler_exc, handler_exc.__traceback__, + ) + finally: + end_span(otel_span) + else: + try: + span_ctx.__exit__(None, None, None) + finally: + end_span(otel_span) + + # ------------------------------------------------------------------ + # Close event + # ------------------------------------------------------------------ + + @staticmethod + def _emit_close_event( + otel_span: Any, + session_id: str, + close_code: int, + duration_ms: int, + *, + error_code: Optional[str] = None, + error_message: Optional[str] = None, + ) -> None: + """Record close-event span attributes and emit a structured log line. + + The log record carries the ``ws.session_id``, ``ws.close_code``, + and ``ws.duration_ms`` fields listed in the spec via the standard + ``logging`` ``extra`` dict — a structured-logging formatter or an + OTel logging bridge can pick them up directly without having to + parse the message. + + :param otel_span: The connection span (or ``None`` when tracing is off). + :type otel_span: any + :param session_id: Per-connection session ID. + :type session_id: str + :param close_code: The RFC 6455 close code reported to the client. + :type close_code: int + :param duration_ms: Connection duration in milliseconds (monotonic). + :type duration_ms: int + :keyword error_code: Optional short error tag for span attribute. + :paramtype error_code: Optional[str] + :keyword error_message: Optional human-readable error message for + span attribute (NOT included in the log line — exception details + are logged separately by ``logger.error(..., exc_info=True)``). + :paramtype error_message: Optional[str] + """ + attrs: dict[str, Any] = { + InvocationsWSConstants.ATTR_SPAN_SESSION_ID: session_id, + InvocationsWSConstants.ATTR_SPAN_CLOSE_CODE: close_code, + InvocationsWSConstants.ATTR_SPAN_DURATION_MS: duration_ms, + } + if error_code: + attrs[InvocationsWSConstants.ATTR_SPAN_ERROR_CODE] = error_code + if error_message: + attrs[InvocationsWSConstants.ATTR_SPAN_ERROR_MESSAGE] = error_message + _safe_set_attrs(otel_span, attrs) + + logger.info( + "invocations_ws connection closed: session_id=%s code=%s duration_ms=%s", + session_id, + close_code, + duration_ms, + extra={ + "ws.session_id": session_id, + "ws.close_code": close_code, + "ws.duration_ms": duration_ms, + }, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/requirements.txt new file mode 100644 index 000000000000..bc5cf4644e14 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-invocations diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/ws_bidirectional_streaming_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/ws_bidirectional_streaming_agent.py new file mode 100644 index 000000000000..ae9f435d9c36 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_bidirectional_streaming_agent/ws_bidirectional_streaming_agent.py @@ -0,0 +1,264 @@ +"""Bidirectional streaming agent over the ``invocations_ws`` (WebSocket) protocol. + +Unlike the request/reply echo in :mod:`samples.ws_invoke_agent`, this sample +exercises the *full-duplex* nature of WebSocket: the server and the client +send and receive on the same socket **concurrently and independently**. + +The handler runs three groups of coroutines in parallel: + +1. ``_reader`` — consumes inbound JSON frames (prompts and control + messages) and dispatches them. Multiple prompts + may arrive while previous ones are still streaming. +2. ``_heartbeat`` — server-initiated push: emits a ``heartbeat`` event + every few seconds without any client input. Proves + the outbound direction is not gated on inbound traffic. +3. ``_stream_tokens`` — one task per prompt; streams generated tokens back + to the client at its own pace. Multiple generations + can run in parallel; ``cancel`` control messages + interrupt them mid-flight. + +Wire protocol (JSON over text frames) +------------------------------------- + +Inbound (client -> server):: + + {"type": "prompt", "id": "p1", "text": "..."} + {"type": "cancel", "id": "p1"} + {"type": "bye"} # graceful shutdown + +Outbound (server -> client):: + + {"type": "ready"} # sent on connect + {"type": "heartbeat", "ts": 1715200000} # periodic, server-initiated + {"type": "token", "id": "p1", "token": "..."} + {"type": "done", "id": "p1"} + {"type": "cancelled", "id": "p1"} + {"type": "error", "id": "p1", "message": "..."} + +Run it:: + + python ws_bidirectional_streaming_agent.py + +Drive it with the ``websockets`` CLI; the server keeps streaming heartbeats +and tokens while you type the next prompt:: + + python -m websockets ws://localhost:8088/invocations_ws + > {"type": "prompt", "id": "p1", "text": "Tell me a story"} + > {"type": "prompt", "id": "p2", "text": "And another"} + > {"type": "cancel", "id": "p1"} + > {"type": "bye"} +""" +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import time +from collections.abc import AsyncGenerator + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.websockets import WebSocket, WebSocketDisconnect + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + + +logger = logging.getLogger("ws_bidirectional_streaming_agent") + +app = InvocationAgentServerHost() + + +# Simulated tokens — in production these would come from a model. +_SIMULATED_TOKENS = [ + "Once", " upon", " a", " time", ",", " in", " a", " land", + " of", " full", "-", "duplex", " sockets", ",", " a", " server", + " and", " a", " client", " spoke", " at", " the", " same", " time", ".", +] + +_HEARTBEAT_INTERVAL_S = 2.0 +_TOKEN_DELAY_S = 0.2 + + +# --------------------------------------------------------------------------- +# HTTP — same host, kept for parity with the rest of the samples. +# --------------------------------------------------------------------------- + +@app.invoke_handler # POST /invocations +async def handle_invoke(request: Request) -> Response: + """Echo the JSON payload back over HTTP.""" + payload = await request.json() + return JSONResponse({"echo": payload}) + + +# --------------------------------------------------------------------------- +# WebSocket — true bidirectional streaming. +# --------------------------------------------------------------------------- + +async def _generate_tokens(text: str) -> AsyncGenerator[str, None]: + """Yield simulated tokens with a small per-token delay. + + Replace this with a real streaming model call (e.g. Azure OpenAI) in + production. + + :param text: The user prompt (unused in this demo). + :type text: str + :return: An async generator of token strings. + :rtype: AsyncGenerator[str, None] + """ + del text # demo: ignore prompt content + for token in _SIMULATED_TOKENS: + await asyncio.sleep(_TOKEN_DELAY_S) + yield token + + +async def _stream_tokens( + websocket: WebSocket, prompt_id: str, text: str, +) -> None: + """Stream tokens for one prompt; cancellable via ``asyncio.CancelledError``. + + :param websocket: The accepted WebSocket. + :type websocket: ~starlette.websockets.WebSocket + :param prompt_id: Caller-supplied prompt identifier (echoed in events). + :type prompt_id: str + :param text: The user prompt to "generate" against. + :type text: str + """ + try: + async for token in _generate_tokens(text): + await websocket.send_json( + {"type": "token", "id": prompt_id, "token": token}, + ) + await websocket.send_json({"type": "done", "id": prompt_id}) + except asyncio.CancelledError: + # Best-effort: tell the client we honoured their cancel. Suppress + # any send error (the socket may already be closed) and re-raise + # so the caller observes the cancellation. + with contextlib.suppress(Exception): + await websocket.send_json({"type": "cancelled", "id": prompt_id}) + raise + + +async def _heartbeat(websocket: WebSocket) -> None: + """Push a ``heartbeat`` event every ``_HEARTBEAT_INTERVAL_S`` seconds. + + Demonstrates server-initiated traffic that does **not** wait for any + inbound message — the defining property of full-duplex. + + :param websocket: The accepted WebSocket. + :type websocket: ~starlette.websockets.WebSocket + """ + while True: + await asyncio.sleep(_HEARTBEAT_INTERVAL_S) + await websocket.send_json( + {"type": "heartbeat", "ts": int(time.time())}, + ) + + +async def _reader( + websocket: WebSocket, + in_flight: "dict[str, asyncio.Task[None]]", +) -> None: + """Consume inbound frames and dispatch prompt / cancel / bye control messages. + + Returns (instead of raising) on a ``bye`` message or a clean client + disconnect. Returning lets the caller cancel the heartbeat task and + end the connection. + + :param websocket: The accepted WebSocket. + :type websocket: ~starlette.websockets.WebSocket + :param in_flight: Map of ``prompt_id`` -> generation task, used to + honour ``cancel`` messages. + :type in_flight: dict[str, asyncio.Task[None]] + """ + try: + async for raw in websocket.iter_text(): + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await websocket.send_json( + {"type": "error", "message": "invalid JSON"}, + ) + continue + + msg_type = msg.get("type") + + if msg_type == "prompt": + prompt_id = str(msg.get("id", "")) + text = str(msg.get("text", "")) + if not prompt_id: + await websocket.send_json( + {"type": "error", "message": "prompt requires 'id'"}, + ) + continue + # Schedule the generation as an independent task so it + # runs in parallel with the reader (and any other + # in-flight generations). + task = asyncio.create_task( + _stream_tokens(websocket, prompt_id, text), + name=f"stream-{prompt_id}", + ) + in_flight[prompt_id] = task + task.add_done_callback( + lambda _t, k=prompt_id: in_flight.pop(k, None), + ) + + elif msg_type == "cancel": + prompt_id = str(msg.get("id", "")) + task = in_flight.get(prompt_id) + if task is not None and not task.done(): + task.cancel() + + elif msg_type == "bye": + return + + else: + await websocket.send_json( + {"type": "error", "message": f"unknown type: {msg_type!r}"}, + ) + except WebSocketDisconnect: + # Client closed first — let the caller unwind normally. + return + + +async def _cancel_and_wait(tasks: "list[asyncio.Task[None]]") -> None: + """Cancel every task in *tasks* and wait for them to actually finish. + + :param tasks: Tasks to cancel; already-done tasks are ignored. + :type tasks: list[asyncio.Task[None]] + """ + for task in tasks: + if not task.done(): + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +@app.ws_handler # /invocations_ws +async def handle_ws(websocket: WebSocket) -> None: + """Run reader, heartbeat, and per-prompt generation tasks concurrently. + + The SDK has already accepted the connection by the time this function + runs. When this coroutine returns the SDK closes the socket with + code ``1000``; if it raises, the SDK maps the exception to ``1011``. + """ + await websocket.send_json({"type": "ready"}) + + in_flight: "dict[str, asyncio.Task[None]]" = {} + heartbeat_task = asyncio.create_task( + _heartbeat(websocket), name="heartbeat", + ) + + try: + await _reader(websocket, in_flight) + except WebSocketDisconnect: + # Client went away mid-read — fall through to cleanup. + logger.info("client disconnected during streaming") + finally: + await _cancel_and_wait( + [heartbeat_task, *in_flight.values()], + ) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/requirements.txt new file mode 100644 index 000000000000..bc5cf4644e14 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-invocations diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/ws_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/ws_invoke_agent.py new file mode 100644 index 000000000000..2f36c48a3b20 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/samples/ws_invoke_agent/ws_invoke_agent.py @@ -0,0 +1,58 @@ +"""Echo agent over the ``invocations_ws`` (WebSocket) protocol. + +Exposes the same host on: + +* ``POST /invocations`` — HTTP, JSON request/response; +* ``/invocations_ws`` — WebSocket, full-duplex streaming; +* ``GET /readiness`` — readiness probe inherited from + ``azure-ai-agentserver-core``. + +Usage:: + + # Start the agent + python ws_invoke_agent.py + + # HTTP turn + curl -X POST http://localhost:8088/invocations \\ + -H "Content-Type: application/json" \\ + -d '{"name": "Alice"}' + # -> {"echo": {"name": "Alice"}} + + # WebSocket turn (with the `websockets` client library) + python -m websockets ws://localhost:8088/invocations_ws + # > hello + # < hello +""" +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.websockets import WebSocket + +from azure.ai.agentserver.invocations import InvocationAgentServerHost + + +app = InvocationAgentServerHost() + + +@app.invoke_handler # POST /invocations +async def handle_invoke(request: Request) -> Response: + """Echo the JSON payload back over HTTP.""" + payload = await request.json() + return JSONResponse({"echo": payload}) + + +@app.ws_handler # /invocations_ws +async def handle_ws(websocket: WebSocket) -> None: + """Echo every text frame back over the WebSocket connection. + + The SDK has already accepted the connection by the time this function + runs, sends WebSocket Ping frames every 30 s in the background to keep + Azure APIM / Azure Load Balancer from idling the socket out, and will + close the connection cleanly on return. An uncaught exception here + is mapped to RFC 6455 close code 1011. + """ + async for message in websocket.iter_text(): + await websocket.send_text(message) + + +if __name__ == "__main__": + app.run() diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_ws_handler.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_ws_handler.py new file mode 100644 index 000000000000..df0affa4fd64 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_ws_handler.py @@ -0,0 +1,339 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the ``@app.ws_handler`` decorator and ``/invocations_ws`` route. + +The tests use Starlette's :class:`~starlette.testclient.TestClient` which +supports WebSocket connections in-process, so the SDK's accept / handler +dispatch / close-code mapping / close-event log all run end-to-end without +needing a real network listener. +""" +import logging + +import pytest +from starlette.testclient import TestClient +from starlette.websockets import WebSocket, WebSocketDisconnect + +from azure.ai.agentserver.invocations import InvocationAgentServerHost +from azure.ai.agentserver.invocations._constants import InvocationsWSConstants + + +# --------------------------------------------------------------------------- +# Factory helpers +# --------------------------------------------------------------------------- + +def _make_echo_ws(**kwargs) -> InvocationAgentServerHost: + """Build a host whose ws handler echoes every text frame back.""" + app = InvocationAgentServerHost(**kwargs) + + @app.ws_handler + async def echo(websocket: WebSocket) -> None: + async for message in websocket.iter_text(): + await websocket.send_text(message) + + return app + + +def _make_failing_ws(exc_factory=ValueError) -> InvocationAgentServerHost: + """Build a host whose ws handler raises after one received frame.""" + app = InvocationAgentServerHost() + + @app.ws_handler + async def boom(websocket: WebSocket) -> None: + await websocket.receive_text() + raise exc_factory("boom") + + return app + + +# --------------------------------------------------------------------------- +# Route registration +# --------------------------------------------------------------------------- + +def test_ws_route_is_registered(): + """The /invocations_ws route exists alongside HTTP routes.""" + app = InvocationAgentServerHost() + paths = [getattr(r, "path", None) for r in app.routes] + assert "/invocations_ws" in paths + assert "/invocations" in paths + assert "/readiness" in paths + + +def test_readiness_still_works_with_ws_registered(): + """Adding the WS route doesn't break /readiness.""" + app = _make_echo_ws() + client = TestClient(app) + resp = client.get("/readiness") + assert resp.status_code == 200 + # x-platform-server header still applied via core middleware + assert "x-platform-server" in resp.headers + + +# --------------------------------------------------------------------------- +# Decorator validation +# --------------------------------------------------------------------------- + +def test_ws_handler_rejects_sync_function(): + """``@app.ws_handler`` must be applied to ``async def`` callables.""" + app = InvocationAgentServerHost() + + with pytest.raises(TypeError, match="async function"): + @app.ws_handler # type: ignore[arg-type] + def sync_handler(websocket): # noqa: ARG001 + pass + + +def test_ws_handler_returns_function_unchanged(): + """The decorator must return the original function unmodified.""" + app = InvocationAgentServerHost() + + async def handler(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.close() + + result = app.ws_handler(handler) + assert result is handler + + +# --------------------------------------------------------------------------- +# Accept happens automatically +# --------------------------------------------------------------------------- + +def test_ws_sdk_accepts_connection_before_handler_runs(): + """The SDK calls ``websocket.accept()`` before invoking the user handler. + + The handler in this test never calls ``accept`` itself; the connection + must still be usable, which proves the SDK accepted it on the user's + behalf. + """ + app = InvocationAgentServerHost() + + @app.ws_handler + async def handler(websocket: WebSocket) -> None: + await websocket.send_text("ready") + + client = TestClient(app) + with client.websocket_connect("/invocations_ws") as ws: + assert ws.receive_text() == "ready" + + +# --------------------------------------------------------------------------- +# Echo round-trip +# --------------------------------------------------------------------------- + +def test_ws_echo_round_trip(): + """End-to-end: send a frame, receive it echoed back.""" + app = _make_echo_ws() + client = TestClient(app) + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("hello") + assert ws.receive_text() == "hello" + ws.send_text("world") + assert ws.receive_text() == "world" + + +# --------------------------------------------------------------------------- +# Handler exception → close code 1011 +# --------------------------------------------------------------------------- + +def test_ws_handler_exception_maps_to_close_code_1011(): + """Uncaught handler exceptions must surface as RFC 6455 close code 1011.""" + app = _make_failing_ws() + client = TestClient(app) + + with pytest.raises(WebSocketDisconnect) as excinfo: + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("trigger") + # Server closes after the handler raises; receiving forces + # the close frame to surface as WebSocketDisconnect. + ws.receive_text() + + assert excinfo.value.code == InvocationsWSConstants.CLOSE_INTERNAL_ERROR + assert excinfo.value.code == 1011 + + +def test_ws_clean_return_uses_close_code_1000(): + """A handler that returns normally yields a 1000 (normal) close code.""" + app = InvocationAgentServerHost() + + @app.ws_handler + async def handler(websocket: WebSocket) -> None: + # Receive once, then return — SDK closes cleanly. + await websocket.receive_text() + + client = TestClient(app) + with pytest.raises(WebSocketDisconnect) as excinfo: + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("done") + ws.receive_text() # Forces the close to surface. + + assert excinfo.value.code == InvocationsWSConstants.CLOSE_NORMAL + + +# --------------------------------------------------------------------------- +# No handler registered +# --------------------------------------------------------------------------- + +def test_ws_with_no_handler_registered_closes_with_1011(): + """If no @ws_handler is registered, the SDK closes with 1011.""" + app = InvocationAgentServerHost() + client = TestClient(app) + + with pytest.raises(WebSocketDisconnect) as excinfo: + with client.websocket_connect("/invocations_ws"): + pass + + assert excinfo.value.code == InvocationsWSConstants.CLOSE_INTERNAL_ERROR + + +# --------------------------------------------------------------------------- +# Close-event log line carries ws.session_id, ws.close_code, ws.duration_ms +# --------------------------------------------------------------------------- + +def _records_with_ws_extras(records): + """Filter log records that carry the spec's ws.* extras.""" + out = [] + for rec in records: + if hasattr(rec, "ws.session_id") and hasattr(rec, "ws.close_code"): + out.append(rec) + return out + + +def test_ws_close_event_log_contains_required_fields(caplog): + """The close-event log line carries ws.session_id, ws.close_code, ws.duration_ms.""" + app = _make_echo_ws() + client = TestClient(app) + + with caplog.at_level(logging.INFO, logger="azure.ai.agentserver"): + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("ping") + assert ws.receive_text() == "ping" + + matches = _records_with_ws_extras(caplog.records) + assert matches, "expected a structured close-event log record" + rec = matches[-1] + + session_id = getattr(rec, "ws.session_id") + close_code = getattr(rec, "ws.close_code") + duration_ms = getattr(rec, "ws.duration_ms") + + assert isinstance(session_id, str) and session_id # generated UUID + assert close_code == InvocationsWSConstants.CLOSE_NORMAL + assert isinstance(duration_ms, int) + assert duration_ms >= 0 + + +def test_ws_close_event_on_handler_exception_records_1011(caplog): + """Handler raising → close-event log records ws.close_code = 1011.""" + app = _make_failing_ws() + client = TestClient(app) + + with caplog.at_level(logging.INFO, logger="azure.ai.agentserver"): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("trigger") + ws.receive_text() + + matches = _records_with_ws_extras(caplog.records) + assert matches + assert getattr(matches[-1], "ws.close_code") == 1011 + + +# --------------------------------------------------------------------------- +# Hypercorn ws_ping_interval wiring +# --------------------------------------------------------------------------- + +def test_ws_ping_interval_default_is_30_seconds(): + """Default ping interval matches the spec (30 s).""" + app = InvocationAgentServerHost() + assert app.ws_ping_interval == InvocationsWSConstants.DEFAULT_PING_INTERVAL_S + assert app.ws_ping_interval == 30.0 + + +def test_ws_ping_interval_custom_value(): + """``ws_ping_interval`` is honoured.""" + app = InvocationAgentServerHost(ws_ping_interval=15) + assert app.ws_ping_interval == 15.0 + + +def test_ws_ping_interval_zero_disables_keepalive(): + """``ws_ping_interval=0`` disables WS-level keep-alive.""" + app = InvocationAgentServerHost(ws_ping_interval=0) + assert app.ws_ping_interval == 0.0 + + +def test_ws_ping_interval_negative_rejected(): + """Negative intervals are programming errors.""" + with pytest.raises(ValueError, match="non-negative"): + InvocationAgentServerHost(ws_ping_interval=-1) + + +def test_ws_ping_interval_propagates_to_hypercorn_config(): + """The configured interval lands on the Hypercorn server config.""" + app = InvocationAgentServerHost(ws_ping_interval=20) + config = app._build_hypercorn_config("0.0.0.0", 8088) # noqa: SLF001 + # Hypercorn ≥0.14 exposes this attribute. + assert getattr(config, "websocket_ping_interval", None) == 20.0 + + +def test_ws_ping_interval_zero_does_not_override_hypercorn_default(): + """Zero leaves Hypercorn's default (None = disabled) intact.""" + app = InvocationAgentServerHost(ws_ping_interval=0) + config = app._build_hypercorn_config("0.0.0.0", 8088) # noqa: SLF001 + # Hypercorn default is None — our wiring leaves it unset for 0. + assert getattr(config, "websocket_ping_interval", None) is None + + +# --------------------------------------------------------------------------- +# Coexistence with HTTP /invocations +# --------------------------------------------------------------------------- + +def test_http_and_ws_share_same_host(): + """Both transports work on the same app object — single session, single process.""" + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + app = InvocationAgentServerHost() + + @app.invoke_handler + async def http_handle(request: Request) -> Response: + body = await request.json() + return JSONResponse({"http": body}) + + @app.ws_handler + async def ws_handle(websocket: WebSocket) -> None: + async for msg in websocket.iter_text(): + await websocket.send_text(f"ws:{msg}") + + client = TestClient(app) + + # HTTP route still works + resp = client.post("/invocations", json={"hello": "world"}) + assert resp.status_code == 200 + assert resp.json() == {"http": {"hello": "world"}} + + # WS route works on the same host + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("hi") + assert ws.receive_text() == "ws:hi" + + +# --------------------------------------------------------------------------- +# Client-initiated disconnect +# --------------------------------------------------------------------------- + +def test_ws_client_disconnect_does_not_log_as_error(caplog): + """A client-initiated disconnect is a normal close, not a 1011 error.""" + app = _make_echo_ws() + client = TestClient(app) + + with caplog.at_level(logging.INFO, logger="azure.ai.agentserver"): + with client.websocket_connect("/invocations_ws") as ws: + ws.send_text("hello") + ws.receive_text() + # __exit__ sends websocket.disconnect — the SDK should treat + # this as normal, not raise from the handler. + + error_records = [r for r in caplog.records if r.levelno >= logging.ERROR] + # No ERROR-level records should be emitted for a clean client disconnect. + assert not error_records, [r.getMessage() for r in error_records]