55import logging
66import uuid
77
8- from typing import TYPE_CHECKING , cast
8+ from typing import TYPE_CHECKING , Any , cast
99
1010from a2a .server .agent_execution .context import RequestContext
1111
5656}
5757
5858
59+ class _RequestStarted :
60+ def __init__ (self , request_id : uuid .UUID , request_context : RequestContext ):
61+ self .request_id = request_id
62+ self .request_context = request_context
63+
64+
5965class _RequestCompleted :
6066 def __init__ (self , request_id : uuid .UUID ):
6167 self .request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205 logger .debug ('TASK (start): %s' , task )
200206
201207 if task :
208+ self ._task_created .set ()
202209 if task .status .state in TERMINAL_TASK_STATES :
203210 raise InvalidParamsError (
204211 message = f'Task { task .id } is in terminal state: { task .status .state } '
205212 )
206- else :
207- if not create_task_if_missing :
208- raise TaskNotFoundError
209-
210- # New task. Create and save it so it's not "missing" if queried immediately
211- # (especially important for return_immediately=True)
212- if self ._task_manager .context_id is None :
213- raise ValueError ('Context ID is required for new tasks' )
214- task = self ._task_manager ._init_task_obj (
215- self ._task_id ,
216- self ._task_manager .context_id ,
217- )
218- await self ._task_manager .save_task_event (task )
219- if self ._push_sender :
220- await self ._push_sender .send_notification (task .id , task )
213+ elif not create_task_if_missing :
214+ raise TaskNotFoundError
221215
222216 except Exception :
223217 logger .debug (
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247 Runs as a detached asyncio.Task. Safe to cancel.
254248 """
255249 logger .debug ('Producer[%s]: Started' , self ._task_id )
250+ request_context = None
256251 try :
257- active = True
258- while active :
252+ while True :
259253 (
260254 request_context ,
261255 request_id ,
262256 ) = await self ._request_queue .get ()
263257 await self ._request_lock .acquire ()
264258 # TODO: Should we create task manager every time?
265259 self ._task_manager ._call_context = request_context .call_context
260+
266261 request_context .current_task = (
267262 await self ._task_manager .get_task ()
268263 )
269264
270- message = request_context .message
271- if message :
272- request_context .current_task = (
273- self ._task_manager .update_with_message (
274- message ,
275- cast ('Task' , request_context .current_task ),
276- )
277- )
278- await self ._task_manager .save_task_event (
279- request_context .current_task
280- )
281- self ._task_created .set ()
282265 logger .debug (
283266 'Producer[%s]: Executing agent task %s' ,
284267 self ._task_id ,
285268 request_context .current_task ,
286269 )
287270
288271 try :
272+ await self ._event_queue_agent .enqueue_event (
273+ cast (
274+ 'Event' ,
275+ _RequestStarted (request_id , request_context ),
276+ )
277+ )
278+
289279 await self ._agent_executor .execute (
290280 request_context , self ._event_queue_agent
291281 )
292282 logger .debug (
293283 'Producer[%s]: Execution finished successfully' ,
294284 self ._task_id ,
295285 )
296- except QueueShutDown :
297- logger .debug (
298- 'Producer[%s]: Request queue shut down' , self ._task_id
299- )
300- raise
301- except asyncio .CancelledError :
302- logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
303- raise
304- except Exception as e :
305- logger .exception (
306- 'Producer[%s]: Execution failed' ,
307- self ._task_id ,
308- )
309- async with self ._lock :
310- await self ._mark_task_as_failed (e )
311- active = False
312286 finally :
313287 logger .debug (
314288 'Producer[%s]: Enqueuing request completed event' ,
315289 self ._task_id ,
316290 )
317- # TODO: Hide from external consumers
318291 await self ._event_queue_agent .enqueue_event (
319292 cast ('Event' , _RequestCompleted (request_id ))
320293 )
321294 self ._request_queue .task_done ()
295+ except asyncio .CancelledError :
296+ logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
297+
298+ except QueueShutDown :
299+ logger .debug ('Producer[%s]: Queue shut down' , self ._task_id )
300+
301+ except Exception as e :
302+ logger .exception (
303+ 'Producer[%s]: Execution failed' ,
304+ self ._task_id ,
305+ )
306+ # Create task and mark as failed.
307+ if request_context :
308+ await self ._task_manager .ensure_task_id (
309+ self ._task_id ,
310+ request_context .context_id or '' ,
311+ )
312+ self ._task_created .set ()
313+ async with self ._lock :
314+ await self ._mark_task_as_failed (e )
315+
322316 finally :
323317 self ._request_queue .shutdown (immediate = True )
324318 await self ._event_queue_agent .close (immediate = False )
@@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332 `_is_finished`, unblocking all global subscribers and wait() calls.
339333 """
340334 logger .debug ('Consumer[%s]: Started' , self ._task_id )
335+ task_mode = None
336+ message_to_save = None
337+ # TODO: Make helper methods
338+ # TODO: Support Task enqueue
341339 try :
342340 try :
343341 try :
@@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347345 'Consumer[%s]: Waiting for event' ,
348346 self ._task_id ,
349347 )
348+ new_task = None
350349 event = await self ._event_queue_agent .dequeue_event ()
351350 logger .debug (
352351 'Consumer[%s]: Dequeued event %s' ,
@@ -361,24 +360,79 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361360 self ._task_id ,
362361 )
363362 self ._request_lock .release ()
363+ elif isinstance (event , _RequestStarted ):
364+ logger .debug (
365+ 'Consumer[%s]: Request started' ,
366+ self ._task_id ,
367+ )
368+ message_to_save = event .request_context .message
369+
364370 elif isinstance (event , Message ):
371+ if task_mode is not None :
372+ if task_mode :
373+ logger .error (
374+ 'Received Message() object in task mode.'
375+ )
376+ else :
377+ logger .error (
378+ 'Multiple Message() objects received.'
379+ )
380+ task_mode = False
365381 logger .debug (
366382 'Consumer[%s]: Setting result to Message: %s' ,
367383 self ._task_id ,
368384 event ,
369385 )
370386 else :
387+ if task_mode is False :
388+ logger .error (
389+ 'Received %s in message mode.' ,
390+ type (event ).__name__ ,
391+ )
392+
393+ if isinstance (event , Task ):
394+ new_task = event
395+ await self ._task_manager .save_task_event (
396+ new_task
397+ )
398+ # TODO: Avoid duplicated messages
399+ else :
400+ new_task = (
401+ await self ._task_manager .ensure_task_id (
402+ self ._task_id ,
403+ event .context_id ,
404+ )
405+ )
406+
407+ if message_to_save is not None :
408+ new_task = self ._task_manager .update_with_message (
409+ message_to_save ,
410+ new_task ,
411+ )
412+ await (
413+ self ._task_manager .save_task_event (
414+ new_task
415+ )
416+ )
417+ message_to_save = None
418+
419+ task_mode = True
371420 # Save structural events (like TaskStatusUpdate) to DB.
372- # TODO: Create task manager every time ?
421+
373422 self ._task_manager .context_id = event .context_id
374- await self ._task_manager .process (event )
423+ if not isinstance (event , Task ):
424+ await self ._task_manager .process (event )
425+
426+ self ._task_created .set ()
375427
376428 # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
377429 new_task = await self ._task_manager .get_task ()
378430 if new_task is None :
379431 raise RuntimeError (
380432 f'Task { self .task_id } not found'
381433 )
434+ if isinstance (event , Task ):
435+ event = new_task
382436 is_interrupted = (
383437 new_task .status .state
384438 in INTERRUPTED_TASK_STATES
@@ -432,8 +486,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432486 self ._task_id , event
433487 )
434488 finally :
489+ if new_task is not None :
490+ new_task_copy = Task ()
491+ new_task_copy .CopyFrom (new_task )
492+ new_task = new_task_copy
493+ if isinstance (event , Task ):
494+ new_task_copy = Task ()
495+ new_task_copy .CopyFrom (event )
496+ event = new_task_copy
497+
498+ logger .debug (
499+ 'Consumer[%s]: Enqueuing\n Event: %s\n New Task: %s\n ' ,
500+ self ._task_id ,
501+ event ,
502+ new_task ,
503+ )
435504 await self ._event_queue_subscribers .enqueue_event (
436- event
505+ cast ( 'Any' , ( event , new_task ))
437506 )
438507 self ._event_queue_agent .task_done ()
439508 except QueueShutDown :
@@ -459,6 +528,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459528 * ,
460529 request : RequestContext | None = None ,
461530 include_initial_task : bool = False ,
531+ replace_status_update_with_task : bool = False ,
462532 ) -> AsyncGenerator [Event , None ]:
463533 """Creates a queue tap and yields events as they are produced.
464534
@@ -506,9 +576,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506576
507577 # Wait for next event or task completion
508578 try :
509- event = await asyncio .wait_for (
579+ dequeued = await asyncio .wait_for (
510580 tapped_queue .dequeue_event (), timeout = 0.1
511581 )
582+ event , updated_task = cast ('Any' , dequeued )
583+ logger .debug (
584+ 'Subscriber[%s]\n Dequeued event %s\n Updated task %s\n ' ,
585+ self ._task_id ,
586+ event ,
587+ updated_task ,
588+ )
589+ if replace_status_update_with_task and isinstance (
590+ event , TaskStatusUpdateEvent
591+ ):
592+ logger .debug (
593+ 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s' ,
594+ self ._task_id ,
595+ updated_task ,
596+ )
597+ event = updated_task
512598 if self ._exception :
513599 raise self ._exception from None
514600 if isinstance (event , _RequestCompleted ):
@@ -522,6 +608,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522608 )
523609 return
524610 continue
611+ elif isinstance (event , _RequestStarted ):
612+ logger .debug (
613+ 'Subscriber[%s]: Request started' ,
614+ self ._task_id ,
615+ )
616+ continue
525617 except (asyncio .TimeoutError , TimeoutError ):
526618 if self ._is_finished .is_set ():
527619 if self ._exception :
@@ -545,7 +637,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545637 # Evaluate if this was the last subscriber on a finished task.
546638 await self ._maybe_cleanup ()
547639
548- async def cancel (self , call_context : ServerCallContext ) -> Task | Message :
640+ async def cancel (self , call_context : ServerCallContext ) -> Task :
549641 """Cancels the running active task.
550642
551643 Concurrency Guarantee:
@@ -558,11 +650,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558650 # TODO: Conflicts with call_context on the pending request.
559651 self ._task_manager ._call_context = call_context
560652
561- task = await self .get_task ()
653+ task = await self ._task_manager . get_task ()
562654 request_context = RequestContext (
563655 call_context = call_context ,
564656 task_id = self ._task_id ,
565- context_id = task .context_id ,
657+ context_id = task .context_id if task else None ,
566658 task = task ,
567659 )
568660
@@ -591,7 +683,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591683 )
592684
593685 await self ._is_finished .wait ()
594- return await self .get_task ()
686+ task = await self ._task_manager .get_task ()
687+ if not task :
688+ raise RuntimeError ('Task should have been created' )
689+ return task
595690
596691 async def _maybe_cleanup (self ) -> None :
597692 """Triggers cleanup if task is finished and has no subscribers.
0 commit comments