Skip to content

Commit 6c807d5

Browse files
fix: fix JSONRPC error handling (#957)
# Description Do one iteration to catch exceptions occurred beforehand to return an error instead of sending headers for SSE.
1 parent a669521 commit 6c807d5

3 files changed

Lines changed: 86 additions & 7 deletions

File tree

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dunders
4545
ES256
4646
euo
4747
EUR
48+
evt
4849
excinfo
4950
FastAPI
5051
fernet

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
HTTP_EXTENSION_HEADER,
1616
)
1717
from a2a.server.context import ServerCallContext
18+
from a2a.server.events import Event
1819
from a2a.server.jsonrpc_models import (
1920
InternalError,
2021
InvalidParamsError,
@@ -376,20 +377,32 @@ async def _process_streaming_request(
376377
if stream is None:
377378
raise UnsupportedOperationError(message='Stream not supported')
378379

380+
# Eagerly fetch the first event to trigger validation/upfront errors
381+
try:
382+
first_event = await anext(stream)
383+
except StopAsyncIteration:
384+
first_event = None
385+
379386
async def _wrap_stream(
380-
st: AsyncGenerator,
387+
st: AsyncGenerator, first_evt: Event | None
381388
) -> AsyncGenerator[dict[str, Any], None]:
389+
def _map_event(evt: Event) -> dict[str, Any]:
390+
stream_response = proto_utils.to_stream_response(evt)
391+
result = MessageToDict(
392+
stream_response, preserving_proto_field_name=False
393+
)
394+
return JSONRPC20Response(result=result, _id=request_id).data
395+
382396
try:
397+
if first_evt is not None:
398+
yield _map_event(first_evt)
399+
383400
async for event in st:
384-
stream_response = proto_utils.to_stream_response(event)
385-
result = MessageToDict(
386-
stream_response, preserving_proto_field_name=False
387-
)
388-
yield JSONRPC20Response(result=result, _id=request_id).data
401+
yield _map_event(event)
389402
except A2AError as e:
390403
yield build_error_response(request_id, e)
391404

392-
return _wrap_stream(stream)
405+
return _wrap_stream(stream, first_event)
393406

394407
async def _handle_send_message(
395408
self, request_obj: SendMessageRequest, context: ServerCallContext

tests/integration/test_client_server_integration.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,71 @@ async def mock_generator(*args, **kwargs):
10191019
await client.close()
10201020

10211021

1022+
@pytest.mark.asyncio
1023+
@pytest.mark.parametrize(
1024+
'error_cls,handler_attr,client_method,request_params',
1025+
[
1026+
pytest.param(
1027+
UnsupportedOperationError,
1028+
'on_subscribe_to_task',
1029+
'subscribe',
1030+
SubscribeToTaskRequest(id='some-id'),
1031+
id='subscribe',
1032+
),
1033+
],
1034+
)
1035+
async def test_server_rejects_stream_on_validation_error(
1036+
transport_setups, error_cls, handler_attr, client_method, request_params
1037+
) -> None:
1038+
"""Verify that the server returns an error directly and doesn't open a stream on validation error."""
1039+
client = transport_setups.client
1040+
handler = transport_setups.handler
1041+
1042+
async def mock_generator(*args, **kwargs):
1043+
raise error_cls('Validation failed')
1044+
yield
1045+
1046+
getattr(handler, handler_attr).side_effect = mock_generator
1047+
1048+
transport = client._transport
1049+
1050+
if isinstance(transport, (RestTransport, JsonRpcTransport)):
1051+
# Spy on httpx client to check response headers
1052+
original_send = transport.httpx_client.send
1053+
response_headers = {}
1054+
1055+
async def mock_send(*args, **kwargs):
1056+
resp = await original_send(*args, **kwargs)
1057+
response_headers['Content-Type'] = resp.headers.get('Content-Type')
1058+
return resp
1059+
1060+
transport.httpx_client.send = mock_send
1061+
1062+
try:
1063+
with pytest.raises(error_cls):
1064+
async for _ in getattr(client, client_method)(
1065+
request=request_params
1066+
):
1067+
pass
1068+
finally:
1069+
transport.httpx_client.send = original_send
1070+
1071+
# Verify that the response content type was NOT text/event-stream
1072+
assert not response_headers.get('Content-Type', '').startswith(
1073+
'text/event-stream'
1074+
)
1075+
else:
1076+
# For gRPC, we just verify it raises the error
1077+
with pytest.raises(error_cls):
1078+
async for _ in getattr(client, client_method)(
1079+
request=request_params
1080+
):
1081+
pass
1082+
1083+
getattr(handler, handler_attr).side_effect = None
1084+
await client.close()
1085+
1086+
10221087
@pytest.mark.asyncio
10231088
@pytest.mark.parametrize(
10241089
'request_kwargs, expected_error_code',

0 commit comments

Comments
 (0)