diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index defdd5244..a3cd94cbe 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -5,7 +5,7 @@ import logging import uuid -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from a2a.server.agent_execution.context import RequestContext @@ -56,6 +56,12 @@ } +class _RequestStarted: + def __init__(self, request_id: uuid.UUID, request_context: RequestContext): + self.request_id = request_id + self.request_context = request_context + + class _RequestCompleted: def __init__(self, request_id: uuid.UUID): self.request_id = request_id @@ -199,25 +205,13 @@ async def start( logger.debug('TASK (start): %s', task) if task: + self._task_created.set() if task.status.state in TERMINAL_TASK_STATES: raise InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state}' ) - else: - if not create_task_if_missing: - raise TaskNotFoundError - - # New task. Create and save it so it's not "missing" if queried immediately - # (especially important for return_immediately=True) - if self._task_manager.context_id is None: - raise ValueError('Context ID is required for new tasks') - task = self._task_manager._init_task_obj( - self._task_id, - self._task_manager.context_id, - ) - await self._task_manager.save_task_event(task) - if self._push_sender: - await self._push_sender.send_notification(task.id, task) + elif not create_task_if_missing: + raise TaskNotFoundError except Exception: logger.debug( @@ -253,9 +247,9 @@ async def _run_producer(self) -> None: Runs as a detached asyncio.Task. Safe to cancel. """ logger.debug('Producer[%s]: Started', self._task_id) + request_context = None try: - active = True - while active: + while True: ( request_context, request_id, @@ -263,22 +257,11 @@ async def _run_producer(self) -> None: await self._request_lock.acquire() # TODO: Should we create task manager every time? self._task_manager._call_context = request_context.call_context + request_context.current_task = ( await self._task_manager.get_task() ) - message = request_context.message - if message: - request_context.current_task = ( - self._task_manager.update_with_message( - message, - cast('Task', request_context.current_task), - ) - ) - await self._task_manager.save_task_event( - request_context.current_task - ) - self._task_created.set() logger.debug( 'Producer[%s]: Executing agent task %s', self._task_id, @@ -286,6 +269,13 @@ async def _run_producer(self) -> None: ) try: + await self._event_queue_agent.enqueue_event( + cast( + 'Event', + _RequestStarted(request_id, request_context), + ) + ) + await self._agent_executor.execute( request_context, self._event_queue_agent ) @@ -293,32 +283,36 @@ async def _run_producer(self) -> None: 'Producer[%s]: Execution finished successfully', self._task_id, ) - except QueueShutDown: - logger.debug( - 'Producer[%s]: Request queue shut down', self._task_id - ) - raise - except asyncio.CancelledError: - logger.debug('Producer[%s]: Cancelled', self._task_id) - raise - except Exception as e: - logger.exception( - 'Producer[%s]: Execution failed', - self._task_id, - ) - async with self._lock: - await self._mark_task_as_failed(e) - active = False finally: logger.debug( 'Producer[%s]: Enqueuing request completed event', self._task_id, ) - # TODO: Hide from external consumers await self._event_queue_agent.enqueue_event( cast('Event', _RequestCompleted(request_id)) ) self._request_queue.task_done() + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + + except QueueShutDown: + logger.debug('Producer[%s]: Queue shut down', self._task_id) + + except Exception as e: + logger.exception( + 'Producer[%s]: Execution failed', + self._task_id, + ) + # Create task and mark as failed. + if request_context: + await self._task_manager.ensure_task_id( + self._task_id, + request_context.context_id or '', + ) + self._task_created.set() + async with self._lock: + await self._mark_task_as_failed(e) + finally: self._request_queue.shutdown(immediate=True) await self._event_queue_agent.close(immediate=False) @@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 `_is_finished`, unblocking all global subscribers and wait() calls. """ logger.debug('Consumer[%s]: Started', self._task_id) + task_mode = None + message_to_save = None + # TODO: Make helper methods + # TODO: Support Task enqueue try: try: try: @@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 'Consumer[%s]: Waiting for event', self._task_id, ) + new_task = None event = await self._event_queue_agent.dequeue_event() logger.debug( 'Consumer[%s]: Dequeued event %s', @@ -361,17 +360,70 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 self._task_id, ) self._request_lock.release() + elif isinstance(event, _RequestStarted): + logger.debug( + 'Consumer[%s]: Request started', + self._task_id, + ) + message_to_save = event.request_context.message + elif isinstance(event, Message): + if task_mode is not None: + if task_mode: + logger.error( + 'Received Message() object in task mode.' + ) + else: + logger.error( + 'Multiple Message() objects received.' + ) + task_mode = False logger.debug( 'Consumer[%s]: Setting result to Message: %s', self._task_id, event, ) else: + if task_mode is False: + logger.error( + 'Received %s in message mode.', + type(event).__name__, + ) + + if isinstance(event, Task): + new_task = event + await self._task_manager.save_task_event( + new_task + ) + # TODO: Avoid duplicated messages + else: + new_task = ( + await self._task_manager.ensure_task_id( + self._task_id, + event.context_id, + ) + ) + + if message_to_save is not None: + new_task = self._task_manager.update_with_message( + message_to_save, + new_task, + ) + await ( + self._task_manager.save_task_event( + new_task + ) + ) + message_to_save = None + + task_mode = True # Save structural events (like TaskStatusUpdate) to DB. - # TODO: Create task manager every time ? + self._task_manager.context_id = event.context_id - await self._task_manager.process(event) + if not isinstance(event, Task): + await self._task_manager.process(event) + + self._task_created.set() # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states new_task = await self._task_manager.get_task() @@ -379,6 +431,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 raise RuntimeError( f'Task {self.task_id} not found' ) + if isinstance(event, Task): + event = new_task is_interrupted = ( new_task.status.state in INTERRUPTED_TASK_STATES @@ -432,8 +486,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 self._task_id, event ) finally: + if new_task is not None: + new_task_copy = Task() + new_task_copy.CopyFrom(new_task) + new_task = new_task_copy + if isinstance(event, Task): + new_task_copy = Task() + new_task_copy.CopyFrom(event) + event = new_task_copy + + logger.debug( + 'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n', + self._task_id, + event, + new_task, + ) await self._event_queue_subscribers.enqueue_event( - event + cast('Any', (event, new_task)) ) self._event_queue_agent.task_done() except QueueShutDown: @@ -459,6 +528,7 @@ async def subscribe( # noqa: PLR0912, PLR0915 *, request: RequestContext | None = None, include_initial_task: bool = False, + replace_status_update_with_task: bool = False, ) -> AsyncGenerator[Event, None]: """Creates a queue tap and yields events as they are produced. @@ -506,9 +576,25 @@ async def subscribe( # noqa: PLR0912, PLR0915 # Wait for next event or task completion try: - event = await asyncio.wait_for( + dequeued = await asyncio.wait_for( tapped_queue.dequeue_event(), timeout=0.1 ) + event, updated_task = cast('Any', dequeued) + logger.debug( + 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + self._task_id, + event, + updated_task, + ) + if replace_status_update_with_task and isinstance( + event, TaskStatusUpdateEvent + ): + logger.debug( + 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s', + self._task_id, + updated_task, + ) + event = updated_task if self._exception: raise self._exception from None if isinstance(event, _RequestCompleted): @@ -522,6 +608,12 @@ async def subscribe( # noqa: PLR0912, PLR0915 ) return continue + elif isinstance(event, _RequestStarted): + logger.debug( + 'Subscriber[%s]: Request started', + self._task_id, + ) + continue except (asyncio.TimeoutError, TimeoutError): if self._is_finished.is_set(): if self._exception: @@ -545,7 +637,7 @@ async def subscribe( # noqa: PLR0912, PLR0915 # Evaluate if this was the last subscriber on a finished task. await self._maybe_cleanup() - async def cancel(self, call_context: ServerCallContext) -> Task | Message: + async def cancel(self, call_context: ServerCallContext) -> Task: """Cancels the running active task. Concurrency Guarantee: @@ -558,11 +650,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: # TODO: Conflicts with call_context on the pending request. self._task_manager._call_context = call_context - task = await self.get_task() + task = await self._task_manager.get_task() request_context = RequestContext( call_context=call_context, task_id=self._task_id, - context_id=task.context_id, + context_id=task.context_id if task else None, task=task, ) @@ -591,7 +683,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: ) await self._is_finished.wait() - return await self.get_task() + task = await self._task_manager.get_task() + if not task: + raise RuntimeError('Task should have been created') + return task async def _maybe_cleanup(self) -> None: """Triggers cleanup if task is finished and has no subscribers. diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 764bef4b2..2da8ddfd7 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -34,6 +34,9 @@ async def execute( - Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc) - Explain if execute can wait for cancel and if cancel can wait for execute. - Explain behaviour of streaming / not-immediate when execute() returns in active state. + - Possible workflows: + - Enqueue a SINGLE Message object + - Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent. Args: context: The request context containing the message, task ID, etc. diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index ccc9cdd0e..1a8464687 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -242,63 +242,56 @@ async def on_message_send( # noqa: D102 active_task, request_context = await self._setup_active_task( params, context ) + task_id = cast('str', request_context.task_id) - if params.configuration and params.configuration.return_immediately: - await active_task.enqueue_request(request_context) - - task = await active_task.get_task() - if params.configuration: - task = apply_history_length(task, params.configuration) - return task + result: Message | Task | None = None - try: - result_states = TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES - - result = None - async for event in active_task.subscribe(request=request_context): - logger.debug( - 'Processing[%s] event [%s] %s', - request_context.task_id, - type(event).__name__, - event, - ) - if isinstance(event, Message) or ( - isinstance(event, Task) - and event.status.state in result_states - ): - result = event - break - if ( - isinstance(event, TaskStatusUpdateEvent) - and event.status.state in result_states - ): - result = await self.task_store.get(event.task_id, context) - break - - if result is None: + async for raw_event in active_task.subscribe( + request=request_context, + include_initial_task=False, + replace_status_update_with_task=True, + ): + event = raw_event + logger.debug( + 'Processing[%s] event [%s] %s', + params.message.task_id, + type(event).__name__, + event, + ) + if isinstance(event, TaskStatusUpdateEvent): + self._validate_task_id_match(task_id, event.task_id) + event = await active_task.get_task() logger.debug( - 'Missing result for task %s', request_context.task_id + 'Replaced TaskStatusUpdateEvent with Task: %s', event ) - result = await active_task.get_task() - logger.debug( - 'Processing[%s] result: %s', request_context.task_id, result - ) + if isinstance(event, Task) and ( + params.configuration.return_immediately + or event.status.state + in (TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES) + ): + self._validate_task_id_match(task_id, event.id) + result = event + break + + if isinstance(event, Message): + result = event + break - except Exception: - logger.exception('Agent execution failed') - raise + if result is None: + logger.debug('Missing result for task %s', request_context.task_id) + result = await active_task.get_task() if isinstance(result, Task): - self._validate_task_id_match( - cast('str', request_context.task_id), result.id - ) - if params.configuration: - result = apply_history_length(result, params.configuration) + result = apply_history_length(result, params.configuration) + logger.debug( + 'Returning result for task %s: %s', + request_context.task_id, + result, + ) return result - # TODO: Unify with on_message_send @validate_request_params @validate( lambda self: self._agent_card.capabilities.streaming, @@ -313,19 +306,20 @@ async def on_message_send_stream( # noqa: D102 params, context ) - include_initial_task = bool( - params.configuration and params.configuration.return_immediately - ) - task_id = cast('str', request_context.task_id) async for event in active_task.subscribe( - request=request_context, include_initial_task=include_initial_task + request=request_context, + include_initial_task=False, ): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - logger.debug('Sending event [%s] %s', type(event).__name__, event) - yield event + yield apply_history_length(event, params.configuration) + else: + yield event + + if isinstance(event, Message): + break @validate_request_params @validate( diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 905b11af3..143413d5b 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -147,13 +147,12 @@ async def save_task_event( await self._save_task(task) return task - async def ensure_task( - self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ) -> Task: + async def ensure_task_id(self, task_id: str, context_id: str) -> Task: """Ensures a Task object exists in memory, loading from store or creating new if needed. Args: - event: The task-related event triggering the need for a Task object. + task_id: The ID for the new task. + context_id: The context ID for the new task. Returns: An existing or newly created `Task` object. @@ -168,16 +167,29 @@ async def ensure_task( if not task: logger.info( 'Task not found or task_id not set. Creating new task for event (task_id: %s, context_id: %s).', - event.task_id, - event.context_id, + task_id, + context_id, ) # streaming agent did not previously stream task object. # Create a task object with the available information and persist the event - task = self._init_task_obj(event.task_id, event.context_id) + task = self._init_task_obj(task_id, context_id) await self._save_task(task) return task + async def ensure_task( + self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task: + """Ensures a Task object exists in memory, loading from store or creating new if needed. + + Args: + event: The task-related event triggering the need for a Task object. + + Returns: + An existing or newly created `Task` object. + """ + return await self.ensure_task_id(event.task_id, event.context_id) + async def process(self, event: Event) -> Event: """Processes an event, updates the task state if applicable, stores it, and returns the event. diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 1e2253430..4683dc3e9 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -16,11 +16,14 @@ from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import DefaultRequestHandlerV2, GrpcHandler +from a2a.server.request_handlers import ( + DefaultRequestHandlerV2, + GrpcHandler, + GrpcServerCallContextBuilder, +) from a2a.server.request_handlers.default_request_handler import ( LegacyRequestHandler, ) -from a2a.server.request_handlers import GrpcServerCallContextBuilder from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -701,24 +704,12 @@ async def send_message_and_get_first_response(): ) return await asyncio.wait_for(it.__anext__(), timeout=0.1) - if use_legacy: - # Legacy client hangs forever. - with pytest.raises(asyncio.TimeoutError): - await send_message_and_get_first_response() - else: - event = await send_message_and_get_first_response() - task = event.task - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - (message,) = task.history - assert message.message_id == 'test-msg' + # First response should not be there yet. + with pytest.raises(asyncio.TimeoutError): + await send_message_and_get_first_response() tasks = (await client.list_tasks(ListTasksRequest())).tasks - if use_legacy: - # Legacy didn't create a task - assert len(tasks) == 0 - else: - (task,) = tasks - assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(tasks) == 0 # Scenario 17: Cancellation of a working task. @@ -1090,39 +1081,13 @@ async def cancel( ) states = [get_state(event) async for event in it] - if use_legacy: - if streaming: - assert states == [ - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] - else: - assert states == [TaskState.TASK_STATE_WORKING] - elif streaming: - assert states == [ - TaskState.TASK_STATE_SUBMITTED, - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] - else: - assert states == [TaskState.TASK_STATE_SUBMITTED] - - # Test blocking return. - it = client.send_message( - SendMessageRequest( - message=msg, - configuration=SendMessageConfiguration(return_immediately=False), - ) - ) - states = [get_state(event) async for event in it] - if streaming: assert states == [ TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_COMPLETED, ] else: - assert states == [TaskState.TASK_STATE_COMPLETED] + assert states == [TaskState.TASK_STATE_WORKING] # Scenario: Test TASK_STATE_INPUT_REQUIRED. @@ -1305,7 +1270,7 @@ async def cancel( @pytest.mark.timeout(5.0) @pytest.mark.asyncio @pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) -async def test_scenario_parallel_subscribe_attach_detach(use_legacy): +async def test_scenario_parallel_subscribe_attach_detach(use_legacy): # noqa: PLR0915 events = collections.defaultdict(asyncio.Event) class EmitAgent(AgentExecutor): @@ -1434,11 +1399,11 @@ async def collect(): await events['emitted_phase_4'].wait() def get_artifact_updates(evs): - txts = [] - for sr in evs: - if sr.HasField('artifact_update'): - txts.append([p.text for p in sr.artifact_update.artifact.parts]) - return txts + return [ + [p.text for p in sr.artifact_update.artifact.parts] + for sr in evs + if sr.HasField('artifact_update') + ] assert get_artifact_updates(await sub1_task) == [ ['artifact_1'], @@ -1459,3 +1424,137 @@ def get_artifact_updates(evs): ] monitor_task.cancel() + + +# Return message directly. +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'return_immediately', + [False, True], + ids=['no_return_immediately', 'return_immediately'], +) +async def test_scenario_publish_message( + use_legacy, streaming, return_immediately +): + class MessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + Message( + task_id=context.task_id, + context_id=context.context_id, + message_id='msg-1', + role=Role.ROLE_AGENT, + parts=[Part(text='response text')], + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(MessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=return_immediately + ), + ) + ) + events = [event async for event in it] + + (event,) = events + assert event.HasField('message') + assert event.message.parts[0].text == 'response text' + + tasks = (await client.list_tasks(ListTasksRequest())).tasks + assert len(tasks) == 0 + + +# Scenario: Publish ArtifactUpdateEvent +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_publish_artifact(use_legacy, streaming): + class ArtifactAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ArtifactAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events = [event async for event in it] + + if streaming: + last_event = events[-1] + assert get_state(last_event) == TaskState.TASK_STATE_COMPLETED + + artifact_events = [e for e in events if e.HasField('artifact_update')] + assert len(artifact_events) > 0, ( + 'Bug: Streaming should return the artifact update event' + ) + assert ( + artifact_events[0].artifact_update.artifact.artifact_id == 'art-1' + ) + else: + last_event = events[-1] + assert last_event.HasField('task') + assert last_event.task.status.state == TaskState.TASK_STATE_COMPLETED + + assert len(last_event.task.artifacts) > 0, ( + 'Bug: Task should include the published artifact' + ) + assert last_event.task.artifacts[0].artifact_id == 'art-1' diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index d3cc95dc3..3a4a24ff6 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -1047,6 +1047,7 @@ async def execute_mock(req, q): assert events[0] == initial_task +@pytest.mark.timeout(1) @pytest.mark.asyncio async def test_active_task_subscribe_request_parameter(): agent_executor = Mock() diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index 605078201..d48b82461 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -1104,55 +1104,6 @@ async def test_on_message_send_limit_history(): assert task.history is not None and len(task.history) > 1 -@pytest.mark.asyncio -async def test_on_message_send_task_id_mismatch(): - mock_task_store = AsyncMock(spec=TaskStore) - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) - - context_task_id = 'context_task_id_1' - result_task_id = 'DIFFERENT_task_id_1' - - mock_request_context = MagicMock() - mock_request_context.task_id = context_task_id - mock_request_context_builder.build.return_value = mock_request_context - - request_handler = DefaultRequestHandlerV2( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - request_context_builder=mock_request_context_builder, - agent_card=create_default_agent_card(), - ) - params = SendMessageRequest( - message=Message( - role=Role.ROLE_USER, - message_id='msg_id_mismatch', - parts=[Part(text='hello')], - ) - ) - - mock_active_task = MagicMock() - mismatched_task = create_sample_task(task_id=result_task_id) - mock_active_task.wait = AsyncMock(return_value=mismatched_task) - mock_active_task.start = AsyncMock() - mock_active_task.enqueue_request = AsyncMock() - mock_active_task.get_task = AsyncMock(return_value=mismatched_task) - with ( - patch.object( - request_handler._active_task_registry, - 'get_or_create', - return_value=mock_active_task, - ), - patch( - 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', - return_value=None, - ), - ): - with pytest.raises(InternalError) as exc_info: - await request_handler.on_message_send(params, context=MagicMock()) - assert 'Task ID mismatch' in exc_info.value.message - - @pytest.mark.asyncio async def test_on_message_send_stream_task_id_mismatch(): mock_task_store = AsyncMock(spec=TaskStore)