Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/a2a/client/transports/http_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from a2a.client.errors import A2AClientError, A2AClientTimeoutError


def _default_sse_error_handler(sse_data: str) -> NoReturn:
raise A2AClientError(f'SSE stream error event received: {sse_data}')


@contextmanager
def handle_http_exceptions(
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
Expand Down Expand Up @@ -71,9 +75,22 @@ async def send_http_stream_request(
url: str,
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
| None = None,
sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler,
**kwargs: Any,
) -> AsyncGenerator[str]:
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.

Args:
httpx_client: The async HTTP client.
method: The HTTP method (e.g. 'POST', 'GET').
url: The URL to send the request to.
status_error_handler: Handler for HTTP status errors. Should raise an
appropriate domain-specific exception.
sse_error_handler: Handler for SSE error events. Called with the
raw SSE data string when an ``event: error`` SSE event is received.
Should raise an appropriate domain-specific exception.
**kwargs: Additional keyword arguments forwarded to ``aconnect_sse``.
"""
with handle_http_exceptions(status_error_handler):
async with _SSEEventSource(
httpx_client, method, url, **kwargs
Expand All @@ -97,6 +114,8 @@ async def send_http_stream_request(
async for sse in event_source.aiter_sse():
if not sse.data:
continue
if sse.event == 'error':
sse_error_handler(sse.data)
yield sse.data


Expand Down
10 changes: 9 additions & 1 deletion src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, NoReturn
from uuid import uuid4

import httpx
Expand Down Expand Up @@ -350,6 +350,7 @@ async def _send_stream_request(
'POST',
self.url,
None,
self._handle_sse_error,
json=rpc_request_payload,
**http_kwargs,
):
Expand All @@ -360,3 +361,10 @@ async def _send_stream_request(
json_rpc_response.result, StreamResponse()
)
yield response

def _handle_sse_error(self, sse_data: str) -> NoReturn:
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
json_rpc_response = JSONRPC20Response.from_json(sse_data)
if json_rpc_response.error:
raise self._create_jsonrpc_error(json_rpc_response.error)
raise A2AClientError(f'SSE stream error: {sse_data}')
83 changes: 53 additions & 30 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@
logger = logging.getLogger(__name__)


def _parse_rest_error(
error_payload: dict[str, Any],
fallback_message: str,
) -> Exception | None:
"""Parses a REST error payload and returns the appropriate A2AError.

Args:
error_payload: The parsed JSON error payload.
fallback_message: Message to use if the payload has no ``message``.

Returns:
The mapped A2AError if a known reason was found, otherwise ``None``.
"""
error_data = error_payload.get('error', {})
message = error_data.get('message', fallback_message)
details = error_data.get('details', [])
if not isinstance(details, list):
return None

# The `details` array can contain multiple different error objects.
# We extract the first `ErrorInfo` object because it contains the
# specific `reason` code needed to map this back to a Python A2AError.
for d in details:
if (
isinstance(d, dict)
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
):
reason = d.get('reason')
metadata = d.get('metadata') or {}
if isinstance(reason, str):
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if exception_cls:
exc = exception_cls(message)
if metadata:
exc.data = metadata
return exc
break

return None


@trace_class(kind=SpanKind.CLIENT)
class RestTransport(ClientTransport):
"""A REST transport for the A2A client."""
Expand Down Expand Up @@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
"""Handles HTTP status errors and raises the appropriate A2AError."""
try:
error_payload = e.response.json()
error_data = error_payload.get('error', {})

message = error_data.get('message', str(e))
details = error_data.get('details', [])
if not isinstance(details, list):
details = []

# The `details` array can contain multiple different error objects.
# We extract the first `ErrorInfo` object because it contains the
# specific `reason` code needed to map this back to a Python A2AError.
error_info = {}
for d in details:
if (
isinstance(d, dict)
and d.get('@type')
== 'type.googleapis.com/google.rpc.ErrorInfo'
):
error_info = d
break
reason = error_info.get('reason')
metadata = error_info.get('metadata') or {}

if isinstance(reason, str):
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if exception_cls:
exc = exception_cls(message)
if metadata:
exc.data = metadata
raise exc from e
mapped = _parse_rest_error(error_payload, str(e))
if mapped:
raise mapped from e
except (json.JSONDecodeError, ValueError):
pass

# Fallback mappings for status codes if 'type' is missing or unknown
status_code = e.response.status_code
if status_code == httpx.codes.NOT_FOUND:
raise MethodNotFoundError(
Expand All @@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:

raise A2AClientError(f'HTTP Error {status_code}: {e}') from e

def _handle_sse_error(self, sse_data: str) -> NoReturn:
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
error_payload = json.loads(sse_data)
mapped = _parse_rest_error(error_payload, sse_data)
if mapped:
raise mapped
raise A2AClientError(sse_data)

async def _send_stream_request(
self,
method: str,
Expand All @@ -352,6 +374,7 @@ async def _send_stream_request(
method,
f'{self.url}{path}',
self._handle_http_error,
self._handle_sse_error,
json=json,
**http_kwargs,
):
Expand Down
26 changes: 24 additions & 2 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,30 @@ def _create_response(
async def event_generator(
stream: AsyncGenerator[dict[str, Any]],
) -> AsyncGenerator[dict[str, str]]:
async for item in stream:
yield {'data': json.dumps(item)}
try:
async for item in stream:
event: dict[str, str] = {
'data': json.dumps(item),
}
if 'error' in item:
event['event'] = 'error'
yield event
except Exception as e:
logger.exception(
'Unhandled error during JSON-RPC SSE stream'
)
rpc_error: A2AError | JSONRPCError = (
e
if isinstance(e, A2AError | JSONRPCError)
else InternalError(message=str(e))
)
error_response = build_error_response(
context.state.get('request_id'), rpc_error
)
yield {
'event': 'error',
'data': json.dumps(error_response),
}

return EventSourceResponse(
event_generator(handler_result), headers=headers
Expand Down
19 changes: 15 additions & 4 deletions src/a2a/server/routes/rest_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from a2a.utils import constants, proto_utils
from a2a.utils.error_handlers import (
build_rest_error_payload,
rest_error_handler,
rest_stream_error_handler,
)
Expand All @@ -32,20 +33,23 @@


if TYPE_CHECKING:
from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

_package_starlette_installed = True
else:
try:
from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

_package_starlette_installed = True
except ImportError:
EventSourceResponse = Any
ServerSentEvent = Any
Request = Any
JSONResponse = Any
Response = Any
Expand Down Expand Up @@ -135,10 +139,17 @@ async def _handle_streaming(
except StopAsyncIteration:
return EventSourceResponse(iter([]))

async def event_generator() -> AsyncIterator[str]:
yield json.dumps(first_item)
async for item in stream:
yield json.dumps(item)
async def event_generator() -> AsyncIterator[ServerSentEvent]:
yield ServerSentEvent(data=json.dumps(first_item))
try:
async for item in stream:
yield ServerSentEvent(data=json.dumps(item))
except Exception as e:
logger.exception('Error during REST SSE stream')
yield ServerSentEvent(
data=json.dumps(build_rest_error_payload(e)),
event='error',
)

return EventSourceResponse(event_generator())

Expand Down
80 changes: 41 additions & 39 deletions src/a2a/utils/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,64 +54,66 @@
return {'error': payload}


def _create_error_response(error: Exception) -> Response:
"""Helper function to create a JSONResponse for an error."""
def build_rest_error_payload(error: Exception) -> dict[str, Any]:
"""Build a REST error payload dict from an exception.

Returns:
A dict with the error payload in the standard REST error format.
"""
if isinstance(error, A2AError):
mapping = A2A_REST_ERROR_MAPPING.get(
type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR')
)
http_code = mapping.http_code
grpc_status = mapping.grpc_status
reason = mapping.reason
# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
metadata = getattr(error, 'data', None) or {}
return _build_error_payload(
code=mapping.http_code,
status=mapping.grpc_status,
message=getattr(error, 'message', str(error)),
reason=mapping.reason,
metadata=metadata,
)
if isinstance(error, ParseError):
return _build_error_payload(
code=400,
status='INVALID_ARGUMENT',
message=str(error),
reason='INVALID_REQUEST',
metadata={},
)
return _build_error_payload(
code=500,
status='INTERNAL',
message='unknown exception',
)


def _create_error_response(error: Exception) -> Response:
"""Helper function to create a JSONResponse for an error."""
if isinstance(error, A2AError):
log_level = (
logging.ERROR
if isinstance(error, InternalError)
else logging.WARNING
)
logger.log(
log_level,
"Request error: Code=%s, Message='%s'%s",
getattr(error, 'code', 'N/A'),
getattr(error, 'message', str(error)),
f', Data={error.data}' if error.data else '',
)

# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
metadata = getattr(error, 'data', None) or {}

return JSONResponse(
content=_build_error_payload(
code=http_code,
status=grpc_status,
message=getattr(error, 'message', str(error)),
reason=reason,
metadata=metadata,
),
status_code=http_code,
media_type='application/json',
)
if isinstance(error, ParseError):
elif isinstance(error, ParseError):

Check notice on line 106 in src/a2a/utils/error_handlers.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/utils/error_handlers.py (142-155)
logger.warning('Parse error: %s', str(error))
return JSONResponse(
content=_build_error_payload(
code=400,
status='INVALID_ARGUMENT',
message=str(error),
reason='INVALID_REQUEST',
metadata={},
),
status_code=400,
media_type='application/json',
)
logger.exception('Unknown error occurred')
else:
logger.exception('Unknown error occurred')

payload = build_rest_error_payload(error)
# Extract HTTP status code from the payload
http_code = payload.get('error', {}).get('code', 500)
return JSONResponse(
content=_build_error_payload(
code=500,
status='INTERNAL',
message='unknown exception',
),
status_code=500,
content=payload,
status_code=http_code,
media_type='application/json',
)

Expand Down
Loading
Loading