diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 468868ede..c17801606 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -31,7 +31,6 @@ DefaultServerCallContextBuilder, ServerCallContextBuilder, ) -from a2a.types import A2ARequest from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -349,7 +348,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, else: try: raw_result = await self._process_non_streaming_request( - request_id, specific_request, call_context + specific_request, call_context ) handler_result = JSONRPC20Response( result=raw_result, _id=request_id @@ -385,7 +384,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, async def _process_streaming_request( self, request_id: str | int | None, - request_obj: A2ARequest, + request_obj: Any, context: ServerCallContext, ) -> AsyncGenerator[dict[str, Any], None]: """Processes streaming requests (SendStreamingMessage or SubscribeToTask). @@ -399,11 +398,12 @@ async def _process_streaming_request( An `AsyncGenerator` object to stream results to the client. """ stream: AsyncGenerator | None = None - if isinstance(request_obj, SendMessageRequest): + method = context.state.get('method') + if method == 'SendStreamingMessage': stream = self.request_handler.on_message_send_stream( request_obj, context ) - elif isinstance(request_obj, SubscribeToTaskRequest): + elif method == 'SubscribeToTask': stream = self.request_handler.on_subscribe_to_task( request_obj, context ) @@ -538,55 +538,53 @@ async def _handle_get_extended_agent_card( @validate_version(constants.PROTOCOL_VERSION_1_0) async def _process_non_streaming_request( # noqa: PLR0911 self, - request_id: str | int | None, - request_obj: A2ARequest, + request_obj: Any, context: ServerCallContext, ) -> dict[str, Any] | None: - """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). + """Processes non-streaming requests. Args: - request_id: The ID of the request. request_obj: The proto request message. context: The ServerCallContext for the request. Returns: A dict containing the result or error. """ - match request_obj: - case SendMessageRequest(): + method = context.state.get('method') + match method: + case 'SendMessage': return await self._handle_send_message(request_obj, context) - case CancelTaskRequest(): + case 'CancelTask': return await self._handle_cancel_task(request_obj, context) - case GetTaskRequest(): + case 'GetTask': return await self._handle_get_task(request_obj, context) - case ListTasksRequest(): + case 'ListTasks': return await self._handle_list_tasks(request_obj, context) - case TaskPushNotificationConfig(): + case 'CreateTaskPushNotificationConfig': return await self._handle_create_task_push_notification_config( request_obj, context ) - case GetTaskPushNotificationConfigRequest(): + case 'GetTaskPushNotificationConfig': return await self._handle_get_task_push_notification_config( request_obj, context ) - case ListTaskPushNotificationConfigsRequest(): + case 'ListTaskPushNotificationConfigs': return await self._handle_list_task_push_notification_configs( request_obj, context ) - case DeleteTaskPushNotificationConfigRequest(): - return await self._handle_delete_task_push_notification_config( + case 'DeleteTaskPushNotificationConfig': + await self._handle_delete_task_push_notification_config( request_obj, context ) - case GetExtendedAgentCardRequest(): + return None + case 'GetExtendedAgentCard': return await self._handle_get_extended_agent_card( request_obj, context ) case _: - logger.error( - 'Unhandled validated request type: %s', type(request_obj) - ) + logger.error('Unhandled method: %s', method) raise UnsupportedOperationError( - message=f'Request type {type(request_obj).__name__} is unknown.' + message=f'Method {method} is not supported.' ) def _create_response( diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 31a550de3..f884bb38e 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -1,3 +1,4 @@ +import asyncio import json from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -15,10 +16,19 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( + AgentCapabilities, AgentCard, + Artifact, + ListTaskPushNotificationConfigsResponse, + ListTasksResponse, Message, Part, Role, + Task, + TaskArtifactUpdateEvent, + TaskPushNotificationConfig, + TaskState, + TaskStatus, ) from a2a.server.routes import jsonrpc_dispatcher @@ -259,5 +269,361 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): assert mock_handle.call_args[1]['method'] == 'message/send' +def _make_jsonrpc_request(method: str, params: dict | None = None) -> dict: + """Helper to build a JSON-RPC 2.0 request dict.""" + return { + 'jsonrpc': '2.0', + 'id': '1', + 'method': method, + 'params': params or {}, + } + + +class TestJsonRpcDispatcherMethodRouting: + """Tests that each JSON-RPC method name routes to the correct handler.""" + + @pytest.fixture + def handler(self): + handler = AsyncMock(spec=RequestHandler) + handler.on_message_send.return_value = Message( + message_id='test', + role=Role.ROLE_AGENT, + parts=[Part(text='ok')], + ) + handler.on_cancel_task.return_value = Task( + id='task1', + context_id='ctx1', + status=TaskStatus(state=TaskState.TASK_STATE_CANCELED), + ) + handler.on_get_task.return_value = Task( + id='task1', + context_id='ctx1', + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + ) + handler.on_list_tasks.return_value = ListTasksResponse() + handler.on_create_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(task_id='t1', url='https://example.com') + ) + handler.on_get_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(task_id='t1', url='https://example.com') + ) + handler.on_list_task_push_notification_configs.return_value = ( + ListTaskPushNotificationConfigsResponse() + ) + handler.on_delete_task_push_notification_config.return_value = None + return handler + + @pytest.fixture + def agent_card(self): + return AgentCard( + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), + name='TestAgent', + version='1.0', + ) + + @pytest.fixture + def client(self, handler, agent_card): + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes) + return TestClient(app, headers={'A2A-Version': '1.0'}) + + # --- Non-streaming method routing tests --- + + def test_send_message_routes_to_on_message_send(self, client, handler): + response = client.post( + '/', + json=_make_jsonrpc_request( + 'SendMessage', + { + 'message': { + 'messageId': '1', + 'role': 'ROLE_USER', + 'parts': [{'text': 'hello'}], + } + }, + ), + ) + response.raise_for_status() + + handler.on_message_send.assert_called_once() + call_context = handler.on_message_send.call_args[0][1] + assert call_context.state['method'] == 'SendMessage' + + def test_cancel_task_routes_to_on_cancel_task(self, client, handler): + response = client.post( + '/', + json=_make_jsonrpc_request('CancelTask', {'id': 'task1'}), + ) + response.raise_for_status() + + handler.on_cancel_task.assert_called_once() + call_context = handler.on_cancel_task.call_args[0][1] + assert call_context.state['method'] == 'CancelTask' + + def test_get_task_routes_to_on_get_task(self, client, handler): + response = client.post( + '/', + json=_make_jsonrpc_request('GetTask', {'id': 'task1'}), + ) + response.raise_for_status() + + handler.on_get_task.assert_called_once() + call_context = handler.on_get_task.call_args[0][1] + assert call_context.state['method'] == 'GetTask' + + def test_list_tasks_routes_to_on_list_tasks(self, client, handler): + response = client.post( + '/', + json=_make_jsonrpc_request('ListTasks'), + ) + response.raise_for_status() + + handler.on_list_tasks.assert_called_once() + call_context = handler.on_list_tasks.call_args[0][1] + assert call_context.state['method'] == 'ListTasks' + + def test_create_push_notification_config_routes_correctly( + self, client, handler + ): + response = client.post( + '/', + json=_make_jsonrpc_request( + 'CreateTaskPushNotificationConfig', + {'taskId': 't1', 'url': 'https://example.com'}, + ), + ) + response.raise_for_status() + + handler.on_create_task_push_notification_config.assert_called_once() + call_context = ( + handler.on_create_task_push_notification_config.call_args[0][1] + ) + assert ( + call_context.state['method'] == 'CreateTaskPushNotificationConfig' + ) + + def test_get_push_notification_config_routes_correctly( + self, client, handler + ): + response = client.post( + '/', + json=_make_jsonrpc_request( + 'GetTaskPushNotificationConfig', + {'taskId': 't1', 'id': 'config1'}, + ), + ) + response.raise_for_status() + + handler.on_get_task_push_notification_config.assert_called_once() + call_context = handler.on_get_task_push_notification_config.call_args[ + 0 + ][1] + assert call_context.state['method'] == 'GetTaskPushNotificationConfig' + + def test_list_push_notification_configs_routes_correctly( + self, client, handler + ): + response = client.post( + '/', + json=_make_jsonrpc_request( + 'ListTaskPushNotificationConfigs', + {'taskId': 't1'}, + ), + ) + response.raise_for_status() + + handler.on_list_task_push_notification_configs.assert_called_once() + call_context = handler.on_list_task_push_notification_configs.call_args[ + 0 + ][1] + assert call_context.state['method'] == 'ListTaskPushNotificationConfigs' + + def test_delete_push_notification_config_routes_correctly( + self, client, handler + ): + response = client.post( + '/', + json=_make_jsonrpc_request( + 'DeleteTaskPushNotificationConfig', + {'taskId': 't1', 'id': 'config1'}, + ), + ) + response.raise_for_status() + data = response.json() + assert data.get('result') is None + + handler.on_delete_task_push_notification_config.assert_called_once() + call_context = ( + handler.on_delete_task_push_notification_config.call_args[0][1] + ) + assert ( + call_context.state['method'] == 'DeleteTaskPushNotificationConfig' + ) + + def test_get_extended_agent_card_routes_correctly( + self, handler, agent_card + ): + captured: dict[str, Any] = {} + + async def capture_modifier(card, context): + captured['method'] = context.state.get('method') + return card + + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + extended_card_modifier=capture_modifier, + rpc_url='/', + ) + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes) + client = TestClient(app, headers={'A2A-Version': '1.0'}) + + response = client.post( + '/', + json=_make_jsonrpc_request('GetExtendedAgentCard'), + ) + response.raise_for_status() + data = response.json() + assert 'result' in data + assert data['result']['name'] == 'TestAgent' + assert captured['method'] == 'GetExtendedAgentCard' + + # --- Streaming method routing tests --- + + @pytest.mark.asyncio + async def test_send_streaming_message_routes_to_on_message_send_stream( + self, handler, agent_card + ): + async def stream_generator(): + yield TaskArtifactUpdateEvent( + artifact=Artifact( + artifact_id='a1', + name='result', + parts=[Part(text='streamed')], + ), + task_id='task1', + context_id='ctx1', + append=False, + last_chunk=True, + ) + + handler.on_message_send_stream = MagicMock( + return_value=stream_generator() + ) + + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=handler, + rpc_url='/', + ) + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes) + client = TestClient(app, headers={'A2A-Version': '1.0'}) + + try: + with client.stream( + 'POST', + '/', + json=_make_jsonrpc_request( + 'SendStreamingMessage', + { + 'message': { + 'messageId': '1', + 'role': 'ROLE_USER', + 'parts': [{'text': 'hello'}], + } + }, + ), + ) as response: + assert response.status_code == 200 + assert response.headers['content-type'].startswith( + 'text/event-stream' + ) + content = b'' + for chunk in response.iter_bytes(): + content += chunk + assert b'a1' in content + finally: + client.close() + await asyncio.sleep(0.1) + + handler.on_message_send_stream.assert_called_once() + call_context = handler.on_message_send_stream.call_args[0][1] + assert call_context.state['method'] == 'SendStreamingMessage' + + @pytest.mark.asyncio + async def test_subscribe_to_task_routes_to_on_subscribe_to_task( + self, handler, agent_card + ): + async def stream_generator(): + yield TaskArtifactUpdateEvent( + artifact=Artifact( + artifact_id='a1', + name='result', + parts=[Part(text='streamed')], + ), + task_id='task1', + context_id='ctx1', + append=False, + last_chunk=True, + ) + + handler.on_subscribe_to_task = MagicMock( + return_value=stream_generator() + ) + + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=handler, + rpc_url='/', + ) + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes) + client = TestClient(app, headers={'A2A-Version': '1.0'}) + + try: + with client.stream( + 'POST', + '/', + json=_make_jsonrpc_request( + 'SubscribeToTask', + { + 'id': 'task1', + }, + ), + ) as response: + assert response.status_code == 200 + assert response.headers['content-type'].startswith( + 'text/event-stream' + ) + content = b'' + for chunk in response.iter_bytes(): + content += chunk + assert b'a1' in content + finally: + client.close() + await asyncio.sleep(0.1) + + handler.on_subscribe_to_task.assert_called_once() + call_context = handler.on_subscribe_to_task.call_args[0][1] + assert call_context.state['method'] == 'SubscribeToTask' + + if __name__ == '__main__': pytest.main([__file__])