Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
a2a

Check warning on line 1 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
A2A

Check warning on line 2 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
A2AFastAPI

Check warning on line 3 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
AAgent
Expand Down Expand Up @@ -27,7 +27,7 @@
AUser
autouse
backticks
base64url

Check warning on line 30 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
buf
bufbuild
cla
Expand All @@ -42,9 +42,10 @@
drivername
DSNs
dunders
ES256

Check warning on line 45 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
euo
EUR
evt
excinfo
FastAPI
fernet
Expand All @@ -56,8 +57,8 @@
gle
GVsb
hazmat
HS256

Check warning on line 60 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
HS384

Check warning on line 61 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
ietf
importlib
initdb
Expand Down Expand Up @@ -95,10 +96,10 @@
Oneof
OpenAPI
openapiv
openapiv2

Check warning on line 99 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
opensource
otherurl
pb2

Check warning on line 102 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
poolclass
postgres
POSTGRES
Expand All @@ -118,7 +119,7 @@
respx
resub
rmi
RS256

Check warning on line 122 in .github/actions/spelling/allow.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Ignoring entry because it contains non-alpha characters (non-alpha-in-dictionary)
RUF
SECP256R1
SLF
Expand Down
27 changes: 20 additions & 7 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HTTP_EXTENSION_HEADER,
)
from a2a.server.context import ServerCallContext
from a2a.server.events import Event
from a2a.server.jsonrpc_models import (
InternalError,
InvalidParamsError,
Expand Down Expand Up @@ -376,20 +377,32 @@ async def _process_streaming_request(
if stream is None:
raise UnsupportedOperationError(message='Stream not supported')

# Eagerly fetch the first event to trigger validation/upfront errors
try:
first_event = await anext(stream)
except StopAsyncIteration:
first_event = None

async def _wrap_stream(
st: AsyncGenerator,
st: AsyncGenerator, first_evt: Event | None
) -> AsyncGenerator[dict[str, Any], None]:
def _map_event(evt: Event) -> dict[str, Any]:
stream_response = proto_utils.to_stream_response(evt)
result = MessageToDict(
stream_response, preserving_proto_field_name=False
)
return JSONRPC20Response(result=result, _id=request_id).data

try:
if first_evt is not None:
yield _map_event(first_evt)

async for event in st:
stream_response = proto_utils.to_stream_response(event)
result = MessageToDict(
stream_response, preserving_proto_field_name=False
)
yield JSONRPC20Response(result=result, _id=request_id).data
yield _map_event(event)
except A2AError as e:
yield build_error_response(request_id, e)

return _wrap_stream(stream)
return _wrap_stream(stream, first_event)

async def _handle_send_message(
self, request_obj: SendMessageRequest, context: ServerCallContext
Expand Down
65 changes: 65 additions & 0 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,71 @@ async def mock_generator(*args, **kwargs):
await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'error_cls,handler_attr,client_method,request_params',
[
pytest.param(
UnsupportedOperationError,
'on_subscribe_to_task',
'subscribe',
SubscribeToTaskRequest(id='some-id'),
id='subscribe',
),
],
)
async def test_server_rejects_stream_on_validation_error(
transport_setups, error_cls, handler_attr, client_method, request_params
) -> None:
"""Verify that the server returns an error directly and doesn't open a stream on validation error."""
client = transport_setups.client
handler = transport_setups.handler

async def mock_generator(*args, **kwargs):
raise error_cls('Validation failed')
yield

getattr(handler, handler_attr).side_effect = mock_generator

transport = client._transport

if isinstance(transport, (RestTransport, JsonRpcTransport)):
# Spy on httpx client to check response headers
original_send = transport.httpx_client.send
response_headers = {}

async def mock_send(*args, **kwargs):
resp = await original_send(*args, **kwargs)
response_headers['Content-Type'] = resp.headers.get('Content-Type')
return resp

transport.httpx_client.send = mock_send

try:
with pytest.raises(error_cls):
async for _ in getattr(client, client_method)(
request=request_params
):
pass
finally:
transport.httpx_client.send = original_send

# Verify that the response content type was NOT text/event-stream
assert not response_headers.get('Content-Type', '').startswith(
'text/event-stream'
)
else:
# For gRPC, we just verify it raises the error
with pytest.raises(error_cls):
async for _ in getattr(client, client_method)(
request=request_params
):
pass

getattr(handler, handler_attr).side_effect = None
await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'request_kwargs, expected_error_code',
Expand Down
Loading