Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 149 additions & 55 deletions src/a2a/server/agent_execution/active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import uuid

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

from a2a.server.agent_execution.context import RequestContext

Expand Down Expand Up @@ -56,6 +56,12 @@
}


class _RequestStarted:
def __init__(self, request_id: uuid.UUID, request_context: RequestContext):
self.request_id = request_id
self.request_context = request_context


class _RequestCompleted:
def __init__(self, request_id: uuid.UUID):
self.request_id = request_id
Expand Down Expand Up @@ -199,25 +205,13 @@ async def start(
logger.debug('TASK (start): %s', task)

if task:
self._task_created.set()
if task.status.state in TERMINAL_TASK_STATES:
raise InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
else:
if not create_task_if_missing:
raise TaskNotFoundError

# New task. Create and save it so it's not "missing" if queried immediately
# (especially important for return_immediately=True)
if self._task_manager.context_id is None:
raise ValueError('Context ID is required for new tasks')
task = self._task_manager._init_task_obj(
self._task_id,
self._task_manager.context_id,
)
await self._task_manager.save_task_event(task)
if self._push_sender:
await self._push_sender.send_notification(task.id, task)
elif not create_task_if_missing:
raise TaskNotFoundError

except Exception:
logger.debug(
Expand Down Expand Up @@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
Runs as a detached asyncio.Task. Safe to cancel.
"""
logger.debug('Producer[%s]: Started', self._task_id)
request_context = None
try:
active = True
while active:
while True:
(
request_context,
request_id,
) = await self._request_queue.get()
await self._request_lock.acquire()
# TODO: Should we create task manager every time?
self._task_manager._call_context = request_context.call_context

request_context.current_task = (
await self._task_manager.get_task()
)

message = request_context.message
if message:
request_context.current_task = (
self._task_manager.update_with_message(
message,
cast('Task', request_context.current_task),
)
)
await self._task_manager.save_task_event(
request_context.current_task
)
self._task_created.set()
logger.debug(
'Producer[%s]: Executing agent task %s',
self._task_id,
request_context.current_task,
)

try:
await self._event_queue_agent.enqueue_event(
cast(
'Event',
_RequestStarted(request_id, request_context),
)
)

await self._agent_executor.execute(
request_context, self._event_queue_agent
)
logger.debug(
'Producer[%s]: Execution finished successfully',
self._task_id,
)
except QueueShutDown:
logger.debug(
'Producer[%s]: Request queue shut down', self._task_id
)
raise
except asyncio.CancelledError:
logger.debug('Producer[%s]: Cancelled', self._task_id)
raise
except Exception as e:
logger.exception(
'Producer[%s]: Execution failed',
self._task_id,
)
async with self._lock:
await self._mark_task_as_failed(e)
active = False
finally:
logger.debug(
'Producer[%s]: Enqueuing request completed event',
self._task_id,
)
# TODO: Hide from external consumers
await self._event_queue_agent.enqueue_event(
cast('Event', _RequestCompleted(request_id))
)
self._request_queue.task_done()
except asyncio.CancelledError:
logger.debug('Producer[%s]: Cancelled', self._task_id)

except QueueShutDown:
logger.debug('Producer[%s]: Queue shut down', self._task_id)

except Exception as e:
logger.exception(
'Producer[%s]: Execution failed',
self._task_id,
)
# Create task and mark as failed.
if request_context:
await self._task_manager.ensure_task_id(
self._task_id,
request_context.context_id or '',
)
self._task_created.set()
async with self._lock:
await self._mark_task_as_failed(e)

finally:
self._request_queue.shutdown(immediate=True)
await self._event_queue_agent.close(immediate=False)
Expand All @@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
`_is_finished`, unblocking all global subscribers and wait() calls.
"""
logger.debug('Consumer[%s]: Started', self._task_id)
task_mode = None
message_to_save = None
# TODO: Make helper methods
# TODO: Support Task enqueue
try:
try:
try:
Expand All @@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
'Consumer[%s]: Waiting for event',
self._task_id,
)
new_task = None
event = await self._event_queue_agent.dequeue_event()
logger.debug(
'Consumer[%s]: Dequeued event %s',
Expand All @@ -361,24 +360,78 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
self._task_id,
)
self._request_lock.release()
elif isinstance(event, _RequestStarted):
logger.debug(
'Consumer[%s]: Request started',
self._task_id,
)
message_to_save = event.request_context.message

elif isinstance(event, Message):
if task_mode is not None:
if task_mode:
logger.error(
'Received Message() object in task mode.'
)
else:
logger.error(
'Multiple Message() objects received.'
)
task_mode = False
logger.debug(
'Consumer[%s]: Setting result to Message: %s',
self._task_id,
event,
)
else:
if task_mode is False:
logger.error(
'Received %s in message mode.',
type(event).__name__,
)

if isinstance(event, Task):
new_task = event
await self._task_manager.save_task_event(
new_task
)
else:
Comment on lines +393 to +398
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I run tests from #954 against a new executor history assertions fail with message being duplicated. The utility method we have new_task populates first user message in the history. I am not sure if we should rely on it or maybe avoid saving message with the same id?

Either way it's something that changed from the "legacy" executor.

(maybe it's a known issue as a part of enqueue Task TODO)

new_task = (
await self._task_manager.ensure_task_id(
self._task_id,
event.context_id,
)
)

if message_to_save is not None:
new_task = self._task_manager.update_with_message(
message_to_save,
new_task,
)
await (
self._task_manager.save_task_event(
new_task
)
)
message_to_save = None

task_mode = True
# Save structural events (like TaskStatusUpdate) to DB.
# TODO: Create task manager every time ?

self._task_manager.context_id = event.context_id
await self._task_manager.process(event)
if not isinstance(event, Task):
await self._task_manager.process(event)

self._task_created.set()

# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
new_task = await self._task_manager.get_task()
if new_task is None:
raise RuntimeError(
f'Task {self.task_id} not found'
)
if isinstance(event, Task):
event = new_task
is_interrupted = (
new_task.status.state
in INTERRUPTED_TASK_STATES
Expand Down Expand Up @@ -432,8 +485,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
self._task_id, event
)
finally:
if new_task is not None:
new_task_copy = Task()
new_task_copy.CopyFrom(new_task)
new_task = new_task_copy
if isinstance(event, Task):
new_task_copy = Task()
new_task_copy.CopyFrom(event)
event = new_task_copy

logger.debug(
'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n',
self._task_id,
event,
new_task,
)
await self._event_queue_subscribers.enqueue_event(
event
cast('Any', (event, new_task))
)
self._event_queue_agent.task_done()
except QueueShutDown:
Expand All @@ -459,6 +527,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
*,
request: RequestContext | None = None,
include_initial_task: bool = False,
replace_status_update_with_task: bool = False,
) -> AsyncGenerator[Event, None]:
"""Creates a queue tap and yields events as they are produced.

Expand Down Expand Up @@ -506,9 +575,25 @@ async def subscribe( # noqa: PLR0912, PLR0915

# Wait for next event or task completion
try:
event = await asyncio.wait_for(
dequeued = await asyncio.wait_for(
tapped_queue.dequeue_event(), timeout=0.1
)
event, updated_task = cast('Any', dequeued)
logger.debug(
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
self._task_id,
event,
updated_task,
)
if replace_status_update_with_task and isinstance(
event, TaskStatusUpdateEvent
):
logger.debug(
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
self._task_id,
updated_task,
)
event = updated_task
if self._exception:
raise self._exception from None
if isinstance(event, _RequestCompleted):
Expand All @@ -522,6 +607,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
)
return
continue
elif isinstance(event, _RequestStarted):
logger.debug(
'Subscriber[%s]: Request started',
self._task_id,
)
continue
except (asyncio.TimeoutError, TimeoutError):
if self._is_finished.is_set():
if self._exception:
Expand All @@ -545,7 +636,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
# Evaluate if this was the last subscriber on a finished task.
await self._maybe_cleanup()

async def cancel(self, call_context: ServerCallContext) -> Task | Message:
async def cancel(self, call_context: ServerCallContext) -> Task:
"""Cancels the running active task.

Concurrency Guarantee:
Expand All @@ -558,11 +649,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
# TODO: Conflicts with call_context on the pending request.
self._task_manager._call_context = call_context

task = await self.get_task()
task = await self._task_manager.get_task()
request_context = RequestContext(
call_context=call_context,
task_id=self._task_id,
context_id=task.context_id,
context_id=task.context_id if task else None,
task=task,
)

Expand Down Expand Up @@ -591,7 +682,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
)

await self._is_finished.wait()
return await self.get_task()
task = await self._task_manager.get_task()
if not task:
raise RuntimeError('Task should have been created')
return task

async def _maybe_cleanup(self) -> None:
"""Triggers cleanup if task is finished and has no subscribers.
Expand Down
3 changes: 3 additions & 0 deletions src/a2a/server/agent_execution/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ async def execute(
- Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc)
- Explain if execute can wait for cancel and if cancel can wait for execute.
- Explain behaviour of streaming / not-immediate when execute() returns in active state.
- Possible workflows:
- Enqueue a SINGLE Message object
- Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent.

Args:
context: The request context containing the message, task ID, etc.
Expand Down
Loading
Loading