diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index bf9e129a6..defdd5244 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -374,30 +374,33 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 await self._task_manager.process(event) # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states - res = await self._task_manager.get_task() + new_task = await self._task_manager.get_task() + if new_task is None: + raise RuntimeError( + f'Task {self.task_id} not found' + ) is_interrupted = ( - res - and res.status.state + new_task.status.state in INTERRUPTED_TASK_STATES ) is_terminal = ( - res - and res.status.state in TERMINAL_TASK_STATES + new_task.status.state + in TERMINAL_TASK_STATES ) # If we hit a breakpoint or terminal state, lock in the result. - if (is_interrupted or is_terminal) and res: + if is_interrupted or is_terminal: logger.debug( 'Consumer[%s]: Setting first result as Task (state=%s)', self._task_id, - res.status.state, + new_task.status.state, ) if is_terminal: logger.debug( 'Consumer[%s]: Reached terminal state %s', self._task_id, - res.status.state if res else 'unknown', + new_task.status.state, ) if not self._is_finished.is_set(): async with self._lock: @@ -413,7 +416,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 logger.debug( 'Consumer[%s]: Interrupted with state %s', self._task_id, - res.status.state if res else 'unknown', + new_task.status.state, ) if ( diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index a29394795..8414e2d17 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -5,7 +5,7 @@ from pydantic import ValidationError -from a2a.server.events.event_queue import Event, EventQueue, QueueShutDown +from a2a.server.events.event_queue import Event, EventQueueLegacy, QueueShutDown from a2a.types.a2a_pb2 import ( Message, Task, @@ -22,7 +22,7 @@ class EventConsumer: """Consumer to read events from the agent event queue.""" - def __init__(self, queue: EventQueue): + def __init__(self, queue: EventQueueLegacy): """Initializes the EventConsumer. Args: diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 25598d15b..bb4d7b9b4 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -92,73 +92,6 @@ async def enqueue_event(self, event: Event) -> None: Only main queue can enqueue events. Child queues can only dequeue events. """ - @abstractmethod - async def dequeue_event(self) -> Event: - """Pulls an event from the queue.""" - - @abstractmethod - def task_done(self) -> None: - """Signals that a work on dequeued event is complete.""" - - @abstractmethod - async def tap( - self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE - ) -> 'EventQueue': - """Creates a child queue that receives future events. - - Note: The tapped queue may receive some old events if the incoming event - queue is lagging behind and hasn't dispatched them yet. - """ - - @abstractmethod - async def close(self, immediate: bool = False) -> None: - """Closes the queue. - - For parent queue: it closes the main queue and all its child queues. - For child queue: it closes only child queue. - - It is safe to call it multiple times. - If immediate is True, the queue will be closed without waiting for all events to be processed. - If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). - - WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events - in any of the child sinks and the consumer has crashed without draining its queue. - It is highly recommended to wrap graceful shutdowns with a timeout, e.g., - `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. - """ - - @abstractmethod - def is_closed(self) -> bool: - """[DEPRECATED] Checks if the queue is closed. - - NOTE: Relying on this for enqueue logic introduces race conditions. - It is maintained primarily for backwards compatibility, workarounds for - Python 3.10/3.12 async queues in consumers, and for the test suite. - """ - - @abstractmethod - async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself. - - WARNING: See `__aexit__` for important deadlock risks associated with - exiting this context manager if unconsumed events remain. - """ - - @abstractmethod - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exits the async context manager, ensuring close() is called. - - WARNING: The context manager calls `close(immediate=False)` by default. - If a consumer exits the `async with` block early (e.g., due to an exception - or an explicit `break`) while unconsumed events remain in the queue, - `__aexit__` will deadlock waiting for `task_done()` to be called on those events. - """ - @trace_class(kind=SpanKind.SERVER) class EventQueueLegacy(EventQueue): @@ -180,7 +113,7 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: self._queue: AsyncQueue[Event] = _create_async_queue( maxsize=max_queue_size ) - self._children: list[EventQueue] = [] + self._children: list[EventQueueLegacy] = [] self._is_closed = False self._lock = asyncio.Lock() logger.debug('EventQueue initialized.') diff --git a/src/a2a/server/events/event_queue_v2.py b/src/a2a/server/events/event_queue_v2.py index de12c21d1..224cb8e56 100644 --- a/src/a2a/server/events/event_queue_v2.py +++ b/src/a2a/server/events/event_queue_v2.py @@ -193,19 +193,29 @@ async def enqueue_event(self, event: Event) -> None: return async def dequeue_event(self) -> Event: - """Dequeues an event from the default internal sink queue.""" + """Pulls an event from the default internal sink queue.""" if self._default_sink is None: raise ValueError('No default sink available.') return await self._default_sink.dequeue_event() def task_done(self) -> None: - """Signals that a formerly enqueued task is complete via the default internal sink queue.""" + """Signals that a work on dequeued event is complete via the default internal sink queue.""" if self._default_sink is None: raise ValueError('No default sink available.') self._default_sink.task_done() async def close(self, immediate: bool = False) -> None: - """Closes the queue for future push events and also closes all child sinks.""" + """Closes the queue and all its child sinks. + + It is safe to call it multiple times. + If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + + WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events + in any of the child sinks and the consumer has crashed without draining its queue. + It is highly recommended to wrap graceful shutdowns with a timeout, e.g., + `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. + """ logger.debug('Closing EventQueueSource: immediate=%s', immediate) async with self._lock: # No more tap() allowed. @@ -230,7 +240,12 @@ async def close(self, immediate: bool = False) -> None: ) def is_closed(self) -> bool: - """Checks if the queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. + It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. + """ return self._is_closed async def test_only_join_incoming_queue(self) -> None: @@ -238,7 +253,11 @@ async def test_only_join_incoming_queue(self) -> None: await self._join_incoming_queue() async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. + """ return self async def __aexit__( @@ -247,7 +266,13 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. + """ await self.close() @@ -290,26 +315,35 @@ async def enqueue_event(self, event: Event) -> None: raise RuntimeError('Cannot enqueue to a sink-only queue') async def dequeue_event(self) -> Event: - """Dequeues an event from the sink queue.""" + """Pulls an event from the sink queue.""" logger.debug('Attempting to dequeue event (waiting).') event = await self._queue.get() logger.debug('Dequeued event: %s', event) return event def task_done(self) -> None: - """Signals that a formerly enqueued task is complete in this sink queue.""" + """Signals that a work on dequeued event is complete in this sink queue.""" logger.debug('Marking task as done in EventQueueSink.') self._queue.task_done() async def tap( self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE ) -> 'EventQueueSink': - """Taps the event queue to create a new child queue that receives future events.""" + """Creates a child queue that receives future events. + + Note: The tapped queue may receive some old events if the incoming event + queue is lagging behind and hasn't dispatched them yet. + """ # Delegate tap to the parent source so all sinks are flat under the source return await self._parent.tap(max_queue_size=max_queue_size) async def close(self, immediate: bool = False) -> None: - """Closes the child sink queue.""" + """Closes the child sink queue. + + It is safe to call it multiple times. + If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + """ logger.debug('Closing EventQueueSink.') async with self._lock: self._is_closed = True @@ -323,11 +357,20 @@ async def close(self, immediate: bool = False) -> None: await self._queue.join() def is_closed(self) -> bool: - """Checks if the sink queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. + It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. + """ return self._is_closed async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. + """ return self async def __aexit__( @@ -336,5 +379,11 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. + """ await self.close() diff --git a/src/a2a/server/events/in_memory_queue_manager.py b/src/a2a/server/events/in_memory_queue_manager.py index ddff52419..0beb354f9 100644 --- a/src/a2a/server/events/in_memory_queue_manager.py +++ b/src/a2a/server/events/in_memory_queue_manager.py @@ -1,6 +1,6 @@ import asyncio -from a2a.server.events.event_queue import EventQueue, EventQueueLegacy +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, QueueManager, @@ -23,10 +23,10 @@ class InMemoryQueueManager(QueueManager): def __init__(self) -> None: """Initializes the InMemoryQueueManager.""" - self._task_queue: dict[str, EventQueue] = {} + self._task_queue: dict[str, EventQueueLegacy] = {} self._lock = asyncio.Lock() - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue for a task ID. Raises: @@ -37,22 +37,22 @@ async def add(self, task_id: str, queue: EventQueue) -> None: raise TaskQueueExists self._task_queue[task_id] = queue - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID. Returns: - The `EventQueue` instance for the `task_id`, or `None` if not found. + The `EventQueueLegacy` instance for the `task_id`, or `None` if not found. """ async with self._lock: if task_id not in self._task_queue: return None return self._task_queue[task_id] - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Taps the event queue for a task ID to create a child queue. Returns: - A new child `EventQueue` instance, or `None` if the task ID is not found. + A new child `EventQueueLegacy` instance, or `None` if the task ID is not found. """ async with self._lock: if task_id not in self._task_queue: @@ -71,11 +71,11 @@ async def close(self, task_id: str) -> None: queue = self._task_queue.pop(task_id) await queue.close() - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one. Returns: - A new or child `EventQueue` instance for the `task_id`. + A new or child `EventQueueLegacy` instance for the `task_id`. """ async with self._lock: if task_id not in self._task_queue: diff --git a/src/a2a/server/events/queue_manager.py b/src/a2a/server/events/queue_manager.py index ed69aae68..b3ec204a5 100644 --- a/src/a2a/server/events/queue_manager.py +++ b/src/a2a/server/events/queue_manager.py @@ -1,21 +1,21 @@ from abc import ABC, abstractmethod -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy class QueueManager(ABC): """Interface for managing the event queue lifecycles per task.""" @abstractmethod - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue associated with a task ID.""" @abstractmethod - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID.""" @abstractmethod - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Creates a child event queue (tap) for an existing task ID.""" @abstractmethod @@ -23,7 +23,7 @@ async def close(self, task_id: str) -> None: """Closes and removes the event queue for a task ID.""" @abstractmethod - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a queue if one doesn't exist, otherwise taps the existing one.""" diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index e6b992250..fea5184d6 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -14,7 +14,6 @@ from a2a.server.events import ( Event, EventConsumer, - EventQueue, EventQueueLegacy, InMemoryQueueManager, QueueManager, @@ -241,7 +240,7 @@ async def on_cancel_task( return result async def _run_event_stream( - self, request: RequestContext, queue: EventQueue + self, request: RequestContext, queue: EventQueueLegacy ) -> None: """Runs the agent's `execute` method and closes the queue afterwards. @@ -256,7 +255,9 @@ async def _setup_message_execution( self, params: SendMessageRequest, context: ServerCallContext, - ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: + ) -> tuple[ + TaskManager, str, EventQueueLegacy, ResultAggregator, asyncio.Task + ]: """Common setup logic for both streaming and non-streaming message handling. Returns: diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index cfd315265..d7d20768b 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -49,11 +49,11 @@ def create_sample_task( @pytest.fixture def mock_event_queue(): - return AsyncMock(spec=EventQueue) + return AsyncMock(spec=EventQueueLegacy) @pytest.fixture -def event_consumer(mock_event_queue: EventQueue): +def event_consumer(mock_event_queue: EventQueueLegacy): return EventConsumer(queue=mock_event_queue) diff --git a/tests/server/events/test_inmemory_queue_manager.py b/tests/server/events/test_inmemory_queue_manager.py index b51334a95..9716b13bf 100644 --- a/tests/server/events/test_inmemory_queue_manager.py +++ b/tests/server/events/test_inmemory_queue_manager.py @@ -5,7 +5,7 @@ import pytest from a2a.server.events import InMemoryQueueManager -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, TaskQueueExists, @@ -21,7 +21,7 @@ def queue_manager(self) -> InMemoryQueueManager: @pytest.fixture def event_queue(self) -> MagicMock: """Fixture to create a mock EventQueue.""" - queue = MagicMock(spec=EventQueue) + queue = MagicMock(spec=EventQueueLegacy) # Mock the tap method to return itself queue.tap.return_value = queue @@ -119,7 +119,7 @@ async def test_create_or_tap_new_queue( task_id = 'test_task_id' result = await queue_manager.create_or_tap(task_id) - assert isinstance(result, EventQueue) + assert isinstance(result, EventQueueLegacy) assert queue_manager._task_queue[task_id] == result @pytest.mark.asyncio @@ -142,7 +142,7 @@ async def test_concurrency( """Test concurrent access to the queue manager.""" async def add_task(task_id): - queue = EventQueue() + queue = EventQueueLegacy() await queue_manager.add(task_id, queue) return task_id diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 59e965116..294f5aefe 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -22,7 +22,12 @@ SimpleRequestContextBuilder, ) from a2a.server.context import ServerCallContext -from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager +from a2a.server.events import ( + EventQueue, + EventQueueLegacy, + InMemoryQueueManager, + QueueManager, +) from a2a.server.request_handlers import ( LegacyRequestHandler as DefaultRequestHandler, ) @@ -380,7 +385,7 @@ async def test_on_cancel_task_cancels_running_agent(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -425,7 +430,7 @@ async def test_on_cancel_task_completes_during_cancellation(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -472,7 +477,7 @@ async def test_on_cancel_task_invalid_result_type(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -1452,7 +1457,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea mock_request_context_builder.build.return_value = mock_request_context # Queue used by _run_event_stream; must support close() - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler( @@ -1683,7 +1688,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(agent_card): mock_request_context.context_id = context_id mock_request_context_builder.build.return_value = mock_request_context - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler(