Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ async def _run_impl(
# combine the messages
session_messages: list[Message] = session_context.get_messages(include_input=True)

self._restore_pending_requests_from_session(provider_session)
output_events: list[WorkflowEvent[Any]] = []
async for event in self._run_core(
session_messages,
Expand All @@ -305,6 +306,7 @@ async def _run_impl(
output_events.append(event)

result = self._convert_workflow_events_to_agent_response(response_id, output_events)
self._store_pending_requests_in_session(provider_session)

# Set the response on the context so after_run providers (e.g. InMemoryHistoryProvider)
# can persist the response messages alongside input messages.
Expand All @@ -313,6 +315,20 @@ async def _run_impl(
await self._run_after_providers(session=provider_session, context=session_context)
return result

def _store_pending_requests_in_session(self, session: AgentSession | None) -> None:
"""Save pending request IDs into the session state for later restoration."""
if session is not None:
session.state["_pending_requests"] = list(self._pending_requests.keys())

def _restore_pending_requests_from_session(self, session: AgentSession | None) -> None:
"""Restore pending request IDs from the session state."""
if session is not None:
stored = session.state.get("_pending_requests")
if isinstance(stored, list):
for req_id in stored:
if req_id not in self._pending_requests:
self._pending_requests[req_id] = None # type: ignore[assignment]

async def _run_stream_impl(
self,
messages: AgentRunInputs,
Expand Down Expand Up @@ -372,6 +388,7 @@ async def _run_stream_impl(
# combine the messages

session_messages: list[Message] = session_context.get_messages(include_input=True)
self._restore_pending_requests_from_session(provider_session)
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages,
Expand All @@ -386,6 +403,8 @@ async def _run_stream_impl(
all_updates.append(update)
yield update

self._store_pending_requests_in_session(provider_session)

# Build the final response from collected updates so after_run providers
# (e.g. InMemoryHistoryProvider) can persist the response messages.
if all_updates:
Expand Down