Skip to content

Commit 5d42015

Browse files
committed
Merge remote-tracking branch 'origin/main' into 1.0-dev
2 parents bab2a11 + 2acd838 commit 5d42015

File tree

6 files changed

+119
-3
lines changed

6 files changed

+119
-3
lines changed

src/a2a/client/base_client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from collections.abc import AsyncIterator, Callable
2+
from types import TracebackType
23
from typing import Any
34

5+
from typing_extensions import Self
6+
47
from a2a.client.client import (
58
Client,
69
ClientCallContext,
@@ -45,6 +48,19 @@ def __init__(
4548
self._config = config
4649
self._transport = transport
4750

51+
async def __aenter__(self) -> Self:
52+
"""Enters the async context manager, returning the client itself."""
53+
return self
54+
55+
async def __aexit__(
56+
self,
57+
exc_type: type[BaseException] | None,
58+
exc_val: BaseException | None,
59+
exc_tb: TracebackType | None,
60+
) -> None:
61+
"""Exits the async context manager, ensuring close() is called."""
62+
await self.close()
63+
4864
async def send_message(
4965
self,
5066
request: Message,

src/a2a/client/transports/jsonrpc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ async def send_message_streaming(
188188
if isinstance(response.root, JSONRPCErrorResponse):
189189
raise A2AClientJSONRPCError(response.root)
190190
yield response.root.result
191+
except httpx.TimeoutException as e:
192+
raise A2AClientTimeoutError('Client Request timed out') from e
191193
except httpx.HTTPStatusError as e:
192194
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
193195
except SSEError as e:
@@ -212,7 +214,7 @@ async def _send_request(
212214
)
213215
response.raise_for_status()
214216
return response.json()
215-
except httpx.ReadTimeout as e:
217+
except httpx.TimeoutException as e:
216218
raise A2AClientTimeoutError('Client Request timed out') from e
217219
except httpx.HTTPStatusError as e:
218220
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
@@ -389,6 +391,8 @@ async def resubscribe(
389391
if isinstance(response.root, JSONRPCErrorResponse):
390392
raise A2AClientJSONRPCError(response.root)
391393
yield response.root.result
394+
except httpx.TimeoutException as e:
395+
raise A2AClientTimeoutError('Client Request timed out') from e
392396
except SSEError as e:
393397
raise A2AClientHTTPError(
394398
400, f'Invalid SSE response or protocol error: {e}'

src/a2a/client/transports/rest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from pydantic import BaseModel
1212

1313
from a2a.client.card_resolver import A2ACardResolver
14-
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
14+
from a2a.client.errors import (
15+
A2AClientHTTPError,
16+
A2AClientJSONError,
17+
A2AClientTimeoutError,
18+
)
1519
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1620
from a2a.client.transports.base import ClientTransport
1721
from a2a.extensions.common import update_extension_header
@@ -163,6 +167,8 @@ async def send_message_streaming(
163167
event = a2a_pb2.StreamResponse()
164168
Parse(sse.data, event)
165169
yield proto_utils.FromProto.stream_response(event)
170+
except httpx.TimeoutException as e:
171+
raise A2AClientTimeoutError('Client Request timed out') from e
166172
except httpx.HTTPStatusError as e:
167173
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
168174
except SSEError as e:
@@ -181,6 +187,8 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
181187
response = await self.httpx_client.send(request)
182188
response.raise_for_status()
183189
return response.json()
190+
except httpx.TimeoutException as e:
191+
raise A2AClientTimeoutError('Client Request timed out') from e
184192
except httpx.HTTPStatusError as e:
185193
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
186194
except json.JSONDecodeError as e:
@@ -383,6 +391,8 @@ async def resubscribe(
383391
event = a2a_pb2.StreamResponse()
384392
Parse(sse.data, event)
385393
yield proto_utils.FromProto.stream_response(event)
394+
except httpx.TimeoutException as e:
395+
raise A2AClientTimeoutError('Client Request timed out') from e
386396
except SSEError as e:
387397
raise A2AClientHTTPError(
388398
400, f'Invalid SSE response or protocol error: {e}'

tests/client/test_base_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ async def test_transport_async_context_manager_on_exception() -> None:
8787
transport.close.assert_awaited_once()
8888

8989

90+
@pytest.mark.asyncio
91+
async def test_base_client_async_context_manager(
92+
base_client: BaseClient, mock_transport: AsyncMock
93+
) -> None:
94+
async with base_client as client:
95+
assert client is base_client
96+
mock_transport.close.assert_not_awaited()
97+
mock_transport.close.assert_awaited_once()
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_base_client_async_context_manager_on_exception(
102+
base_client: BaseClient, mock_transport: AsyncMock
103+
) -> None:
104+
with pytest.raises(RuntimeError, match='boom'):
105+
async with base_client:
106+
raise RuntimeError('boom')
107+
mock_transport.close.assert_awaited_once()
108+
109+
90110
@pytest.mark.asyncio
91111
async def test_send_message_streaming(
92112
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message

tests/client/transports/test_jsonrpc_client.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,38 @@ async def test_send_message_client_timeout(
601601

602602
assert 'Client Request timed out' in str(exc_info.value)
603603

604+
@pytest.mark.asyncio
605+
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
606+
async def test_send_message_streaming_timeout(
607+
self,
608+
mock_aconnect_sse: AsyncMock,
609+
mock_httpx_client: AsyncMock,
610+
mock_agent_card: MagicMock,
611+
):
612+
client = JsonRpcTransport(
613+
httpx_client=mock_httpx_client, agent_card=mock_agent_card
614+
)
615+
params = MessageSendParams(
616+
message=create_text_message_object(content='Hello stream')
617+
)
618+
mock_event_source = AsyncMock(spec=EventSource)
619+
mock_event_source.response = MagicMock(spec=httpx.Response)
620+
mock_event_source.response.raise_for_status.return_value = None
621+
mock_event_source.aiter_sse.side_effect = httpx.TimeoutException(
622+
'Read timed out'
623+
)
624+
mock_aconnect_sse.return_value.__aenter__.return_value = (
625+
mock_event_source
626+
)
627+
628+
with pytest.raises(A2AClientTimeoutError) as exc_info:
629+
_ = [
630+
item
631+
async for item in client.send_message_streaming(request=params)
632+
]
633+
634+
assert 'Client Request timed out' in str(exc_info.value)
635+
604636
@pytest.mark.asyncio
605637
async def test_get_task_success(
606638
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock

tests/client/transports/test_rest_client.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from httpx_sse import EventSource, ServerSentEvent
1010

1111
from a2a.client import create_text_message_object
12-
from a2a.client.errors import A2AClientHTTPError
12+
from a2a.client.errors import A2AClientHTTPError, A2AClientTimeoutError
1313
from a2a.client.transports.rest import RestTransport
1414
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1515
from a2a.grpc import a2a_pb2
@@ -50,6 +50,40 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
5050
assert actual_extensions == expected_extensions
5151

5252

53+
class TestRestTransport:
54+
@pytest.mark.asyncio
55+
@patch('a2a.client.transports.rest.aconnect_sse')
56+
async def test_send_message_streaming_timeout(
57+
self,
58+
mock_aconnect_sse: AsyncMock,
59+
mock_httpx_client: AsyncMock,
60+
mock_agent_card: MagicMock,
61+
):
62+
client = RestTransport(
63+
httpx_client=mock_httpx_client, agent_card=mock_agent_card
64+
)
65+
params = MessageSendParams(
66+
message=create_text_message_object(content='Hello stream')
67+
)
68+
mock_event_source = AsyncMock(spec=EventSource)
69+
mock_event_source.response = MagicMock(spec=httpx.Response)
70+
mock_event_source.response.raise_for_status.return_value = None
71+
mock_event_source.aiter_sse.side_effect = httpx.TimeoutException(
72+
'Read timed out'
73+
)
74+
mock_aconnect_sse.return_value.__aenter__.return_value = (
75+
mock_event_source
76+
)
77+
78+
with pytest.raises(A2AClientTimeoutError) as exc_info:
79+
_ = [
80+
item
81+
async for item in client.send_message_streaming(request=params)
82+
]
83+
84+
assert 'Client Request timed out' in str(exc_info.value)
85+
86+
5387
class TestRestTransportExtensions:
5488
@pytest.mark.asyncio
5589
async def test_send_message_with_default_extensions(

0 commit comments

Comments
 (0)