diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 90097440..b3b2d56e 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -45,6 +45,7 @@ dunders ES256 euo EUR +evt excinfo FastAPI fernet diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index d9ea4ff1..60620081 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -15,6 +15,7 @@ HTTP_EXTENSION_HEADER, ) from a2a.server.context import ServerCallContext +from a2a.server.events import Event from a2a.server.jsonrpc_models import ( InternalError, InvalidParamsError, @@ -376,20 +377,32 @@ async def _process_streaming_request( if stream is None: raise UnsupportedOperationError(message='Stream not supported') + # Eagerly fetch the first event to trigger validation/upfront errors + try: + first_event = await anext(stream) + except StopAsyncIteration: + first_event = None + async def _wrap_stream( - st: AsyncGenerator, + st: AsyncGenerator, first_evt: Event | None ) -> AsyncGenerator[dict[str, Any], None]: + def _map_event(evt: Event) -> dict[str, Any]: + stream_response = proto_utils.to_stream_response(evt) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + return JSONRPC20Response(result=result, _id=request_id).data + try: + if first_evt is not None: + yield _map_event(first_evt) + async for event in st: - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield JSONRPC20Response(result=result, _id=request_id).data + yield _map_event(event) except A2AError as e: yield build_error_response(request_id, e) - return _wrap_stream(stream) + return _wrap_stream(stream, first_event) async def _handle_send_message( self, request_obj: SendMessageRequest, context: ServerCallContext diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index c7fa29ea..1ac8a716 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1019,6 +1019,71 @@ async def mock_generator(*args, **kwargs): await client.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls,handler_attr,client_method,request_params', + [ + pytest.param( + UnsupportedOperationError, + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_server_rejects_stream_on_validation_error( + transport_setups, error_cls, handler_attr, client_method, request_params +) -> None: + """Verify that the server returns an error directly and doesn't open a stream on validation error.""" + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + raise error_cls('Validation failed') + yield + + getattr(handler, handler_attr).side_effect = mock_generator + + transport = client._transport + + if isinstance(transport, (RestTransport, JsonRpcTransport)): + # Spy on httpx client to check response headers + original_send = transport.httpx_client.send + response_headers = {} + + async def mock_send(*args, **kwargs): + resp = await original_send(*args, **kwargs) + response_headers['Content-Type'] = resp.headers.get('Content-Type') + return resp + + transport.httpx_client.send = mock_send + + try: + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + finally: + transport.httpx_client.send = original_send + + # Verify that the response content type was NOT text/event-stream + assert not response_headers.get('Content-Type', '').startswith( + 'text/event-stream' + ) + else: + # For gRPC, we just verify it raises the error + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + + getattr(handler, handler_attr).side_effect = None + await client.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'request_kwargs, expected_error_code',