Skip to content

Commit 0571504

Browse files
committed
DefaultRequestHandlerV2: Unification of on_message methods.
1 parent 6c807d5 commit 0571504

File tree

7 files changed

+370
-215
lines changed

7 files changed

+370
-215
lines changed

src/a2a/server/agent_execution/active_task.py

Lines changed: 150 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import uuid
77

8-
from typing import TYPE_CHECKING, cast
8+
from typing import TYPE_CHECKING, Any, cast
99

1010
from a2a.server.agent_execution.context import RequestContext
1111

@@ -56,6 +56,12 @@
5656
}
5757

5858

59+
class _RequestStarted:
60+
def __init__(self, request_id: uuid.UUID, request_context: RequestContext):
61+
self.request_id = request_id
62+
self.request_context = request_context
63+
64+
5965
class _RequestCompleted:
6066
def __init__(self, request_id: uuid.UUID):
6167
self.request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205
logger.debug('TASK (start): %s', task)
200206

201207
if task:
208+
self._task_created.set()
202209
if task.status.state in TERMINAL_TASK_STATES:
203210
raise InvalidParamsError(
204211
message=f'Task {task.id} is in terminal state: {task.status.state}'
205212
)
206-
else:
207-
if not create_task_if_missing:
208-
raise TaskNotFoundError
209-
210-
# New task. Create and save it so it's not "missing" if queried immediately
211-
# (especially important for return_immediately=True)
212-
if self._task_manager.context_id is None:
213-
raise ValueError('Context ID is required for new tasks')
214-
task = self._task_manager._init_task_obj(
215-
self._task_id,
216-
self._task_manager.context_id,
217-
)
218-
await self._task_manager.save_task_event(task)
219-
if self._push_sender:
220-
await self._push_sender.send_notification(task.id, task)
213+
elif not create_task_if_missing:
214+
raise TaskNotFoundError
221215

222216
except Exception:
223217
logger.debug(
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247
Runs as a detached asyncio.Task. Safe to cancel.
254248
"""
255249
logger.debug('Producer[%s]: Started', self._task_id)
250+
request_context = None
256251
try:
257-
active = True
258-
while active:
252+
while True:
259253
(
260254
request_context,
261255
request_id,
262256
) = await self._request_queue.get()
263257
await self._request_lock.acquire()
264258
# TODO: Should we create task manager every time?
265259
self._task_manager._call_context = request_context.call_context
260+
266261
request_context.current_task = (
267262
await self._task_manager.get_task()
268263
)
269264

270-
message = request_context.message
271-
if message:
272-
request_context.current_task = (
273-
self._task_manager.update_with_message(
274-
message,
275-
cast('Task', request_context.current_task),
276-
)
277-
)
278-
await self._task_manager.save_task_event(
279-
request_context.current_task
280-
)
281-
self._task_created.set()
282265
logger.debug(
283266
'Producer[%s]: Executing agent task %s',
284267
self._task_id,
285268
request_context.current_task,
286269
)
287270

288271
try:
272+
await self._event_queue_agent.enqueue_event(
273+
cast(
274+
'Event',
275+
_RequestStarted(request_id, request_context),
276+
)
277+
)
278+
289279
await self._agent_executor.execute(
290280
request_context, self._event_queue_agent
291281
)
292282
logger.debug(
293283
'Producer[%s]: Execution finished successfully',
294284
self._task_id,
295285
)
296-
except QueueShutDown:
297-
logger.debug(
298-
'Producer[%s]: Request queue shut down', self._task_id
299-
)
300-
raise
301-
except asyncio.CancelledError:
302-
logger.debug('Producer[%s]: Cancelled', self._task_id)
303-
raise
304-
except Exception as e:
305-
logger.exception(
306-
'Producer[%s]: Execution failed',
307-
self._task_id,
308-
)
309-
async with self._lock:
310-
await self._mark_task_as_failed(e)
311-
active = False
312286
finally:
313287
logger.debug(
314288
'Producer[%s]: Enqueuing request completed event',
315289
self._task_id,
316290
)
317-
# TODO: Hide from external consumers
318291
await self._event_queue_agent.enqueue_event(
319292
cast('Event', _RequestCompleted(request_id))
320293
)
321294
self._request_queue.task_done()
295+
except asyncio.CancelledError:
296+
logger.debug('Producer[%s]: Cancelled', self._task_id)
297+
298+
except QueueShutDown:
299+
logger.debug('Producer[%s]: Queue shut down', self._task_id)
300+
301+
except Exception as e:
302+
logger.exception(
303+
'Producer[%s]: Execution failed',
304+
self._task_id,
305+
)
306+
# Create task and mark as failed.
307+
if request_context:
308+
await self._task_manager.ensure_task_id(
309+
self._task_id,
310+
request_context.context_id or '',
311+
)
312+
self._task_created.set()
313+
async with self._lock:
314+
await self._mark_task_as_failed(e)
315+
322316
finally:
323317
self._request_queue.shutdown(immediate=True)
324318
await self._event_queue_agent.close(immediate=False)
@@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332
`_is_finished`, unblocking all global subscribers and wait() calls.
339333
"""
340334
logger.debug('Consumer[%s]: Started', self._task_id)
335+
task_mode = None
336+
message_to_save = None
337+
# TODO: Make helper methods
338+
# TODO: Support Task enqueue
341339
try:
342340
try:
343341
try:
@@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347345
'Consumer[%s]: Waiting for event',
348346
self._task_id,
349347
)
348+
new_task = None
350349
event = await self._event_queue_agent.dequeue_event()
351350
logger.debug(
352351
'Consumer[%s]: Dequeued event %s',
@@ -361,24 +360,79 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361360
self._task_id,
362361
)
363362
self._request_lock.release()
363+
elif isinstance(event, _RequestStarted):
364+
logger.debug(
365+
'Consumer[%s]: Request started',
366+
self._task_id,
367+
)
368+
message_to_save = event.request_context.message
369+
364370
elif isinstance(event, Message):
371+
if task_mode is not None:
372+
if task_mode:
373+
logger.error(
374+
'Received Message() object in task mode.'
375+
)
376+
else:
377+
logger.error(
378+
'Multiple Message() objects received.'
379+
)
380+
task_mode = False
365381
logger.debug(
366382
'Consumer[%s]: Setting result to Message: %s',
367383
self._task_id,
368384
event,
369385
)
370386
else:
387+
if task_mode is False:
388+
logger.error(
389+
'Received %s in message mode.',
390+
type(event).__name__,
391+
)
392+
393+
if isinstance(event, Task):
394+
new_task = event
395+
await self._task_manager.save_task_event(
396+
new_task
397+
)
398+
# TODO: Avoid duplicated messages
399+
else:
400+
new_task = (
401+
await self._task_manager.ensure_task_id(
402+
self._task_id,
403+
event.context_id,
404+
)
405+
)
406+
407+
if message_to_save is not None:
408+
new_task = self._task_manager.update_with_message(
409+
message_to_save,
410+
new_task,
411+
)
412+
await (
413+
self._task_manager.save_task_event(
414+
new_task
415+
)
416+
)
417+
message_to_save = None
418+
419+
task_mode = True
371420
# Save structural events (like TaskStatusUpdate) to DB.
372-
# TODO: Create task manager every time ?
421+
373422
self._task_manager.context_id = event.context_id
374-
await self._task_manager.process(event)
423+
if not isinstance(event, Task):
424+
await self._task_manager.process(event)
425+
426+
self._task_created.set()
375427

376428
# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
377429
new_task = await self._task_manager.get_task()
378430
if new_task is None:
379431
raise RuntimeError(
380432
f'Task {self.task_id} not found'
381433
)
434+
if isinstance(event, Task):
435+
event = new_task
382436
is_interrupted = (
383437
new_task.status.state
384438
in INTERRUPTED_TASK_STATES
@@ -432,8 +486,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432486
self._task_id, event
433487
)
434488
finally:
489+
if new_task is not None:
490+
new_task_copy = Task()
491+
new_task_copy.CopyFrom(new_task)
492+
new_task = new_task_copy
493+
if isinstance(event, Task):
494+
new_task_copy = Task()
495+
new_task_copy.CopyFrom(event)
496+
event = new_task_copy
497+
498+
logger.debug(
499+
'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n',
500+
self._task_id,
501+
event,
502+
new_task,
503+
)
435504
await self._event_queue_subscribers.enqueue_event(
436-
event
505+
cast('Any', (event, new_task))
437506
)
438507
self._event_queue_agent.task_done()
439508
except QueueShutDown:
@@ -459,6 +528,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459528
*,
460529
request: RequestContext | None = None,
461530
include_initial_task: bool = False,
531+
replace_status_update_with_task: bool = False,
462532
) -> AsyncGenerator[Event, None]:
463533
"""Creates a queue tap and yields events as they are produced.
464534
@@ -506,9 +576,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506576

507577
# Wait for next event or task completion
508578
try:
509-
event = await asyncio.wait_for(
579+
dequeued = await asyncio.wait_for(
510580
tapped_queue.dequeue_event(), timeout=0.1
511581
)
582+
event, updated_task = cast('Any', dequeued)
583+
logger.debug(
584+
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
585+
self._task_id,
586+
event,
587+
updated_task,
588+
)
589+
if replace_status_update_with_task and isinstance(
590+
event, TaskStatusUpdateEvent
591+
):
592+
logger.debug(
593+
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
594+
self._task_id,
595+
updated_task,
596+
)
597+
event = updated_task
512598
if self._exception:
513599
raise self._exception from None
514600
if isinstance(event, _RequestCompleted):
@@ -522,6 +608,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522608
)
523609
return
524610
continue
611+
elif isinstance(event, _RequestStarted):
612+
logger.debug(
613+
'Subscriber[%s]: Request started',
614+
self._task_id,
615+
)
616+
continue
525617
except (asyncio.TimeoutError, TimeoutError):
526618
if self._is_finished.is_set():
527619
if self._exception:
@@ -545,7 +637,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545637
# Evaluate if this was the last subscriber on a finished task.
546638
await self._maybe_cleanup()
547639

548-
async def cancel(self, call_context: ServerCallContext) -> Task | Message:
640+
async def cancel(self, call_context: ServerCallContext) -> Task:
549641
"""Cancels the running active task.
550642
551643
Concurrency Guarantee:
@@ -558,11 +650,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558650
# TODO: Conflicts with call_context on the pending request.
559651
self._task_manager._call_context = call_context
560652

561-
task = await self.get_task()
653+
task = await self._task_manager.get_task()
562654
request_context = RequestContext(
563655
call_context=call_context,
564656
task_id=self._task_id,
565-
context_id=task.context_id,
657+
context_id=task.context_id if task else None,
566658
task=task,
567659
)
568660

@@ -591,7 +683,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591683
)
592684

593685
await self._is_finished.wait()
594-
return await self.get_task()
686+
task = await self._task_manager.get_task()
687+
if not task:
688+
raise RuntimeError('Task should have been created')
689+
return task
595690

596691
async def _maybe_cleanup(self) -> None:
597692
"""Triggers cleanup if task is finished and has no subscribers.

src/a2a/server/agent_execution/agent_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ async def execute(
3434
- Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc)
3535
- Explain if execute can wait for cancel and if cancel can wait for execute.
3636
- Explain behaviour of streaming / not-immediate when execute() returns in active state.
37+
- Possible workflows:
38+
- Enqueue a SINGLE Message object
39+
- Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent.
3740
3841
Args:
3942
context: The request context containing the message, task ID, etc.

0 commit comments

Comments
 (0)