Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import importlib
import json
import os
import asyncio
import re
import sys
from typing import Optional
Expand Down Expand Up @@ -308,6 +309,34 @@ async def run_async(self, *args, **kwargs):
}
)

async def run_live(self, *args, **kwargs):
from google.adk.events import event

yield event.Event(
**{
"author": "currency_exchange_agent",
"content": {
"parts": [
{
"thought_signature": b"test_signature",
"function_call": {
"args": {
"currency_date": "2025-04-03",
"currency_from": "USD",
"currency_to": "SEK",
},
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
"name": "get_exchange_rate",
},
}
],
"role": "model",
},
"id": "9aaItGK9",
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
}
)


@pytest.mark.usefixtures("google_auth_mock")
class TestAdkApp:
Expand Down Expand Up @@ -904,6 +933,62 @@ def test_span_content_capture_enabled_with_tracing(
app.set_up()
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true"

@pytest.mark.asyncio
async def test_async_bidi_stream_query(
self,
default_instrumentor_builder_mock: mock.Mock,
get_project_id_mock: mock.Mock,
):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
request_queue = asyncio.Queue()
request_dict = {
"user_id": _TEST_USER_ID,
"live_request": {
"input": "What is the exchange rate from USD to SEK?",
},
}

await request_queue.put(request_dict)
await request_queue.put(None) # Sentinel to end the stream.
events = []
async for event in app.bidi_stream_query(request_queue):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
async def test_async_bidi_stream_query_with_state(
self,
default_instrumentor_builder_mock: mock.Mock,
get_project_id_mock: mock.Mock,
):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
request_queue = asyncio.Queue()
request_dict = {
"user_id": _TEST_USER_ID,
"state": {"test_key": "test_val"},
"live_request": {
"input": "What is the exchange rate from USD to SEK?",
},
}

await request_queue.put(request_dict)
await request_queue.put(None) # Sentinel to end the stream.

with mock.patch.object(
app, "async_create_session", wraps=app.async_create_session
) as mock_create_session:
async for _ in app.bidi_stream_query(request_queue):
pass
mock_create_session.assert_called_once_with(
user_id=_TEST_USER_ID, state={"test_key": "test_val"}
)


def test_dump_event_for_json():
from google.adk.events import event
Expand Down
88 changes: 88 additions & 0 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,93 @@ async def async_search_memory(self, *, user_id: str, query: str):
query=query,
)


async def bidi_stream_query(
self,
request_queue: Any,
) -> AsyncIterable[Any]:
"""Bidi streaming query the ADK application.

Args:
request_queue:
The queue of requests to stream responses for, with the type of
asyncio.Queue[Any].

Raises:
TypeError: If the request_queue is not an asyncio.Queue instance.
ValueError: If the first request does not have a user_id.
ValidationError: If failed to convert to LiveRequest.

Yields:
The stream responses of querying the ADK application.
"""
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.live_request_queue import LiveRequestQueue
from vertexai.agent_engines import _utils

# Manual type check needed as Pydantic doesn't support asyncio.Queue.
if not isinstance(request_queue, asyncio.Queue):
raise TypeError("request_queue must be an asyncio.Queue instance.")

first_request = await request_queue.get()
user_id = first_request.get("user_id")
if not user_id:
raise ValueError("The first request must have a user_id.")

session_id = first_request.get("session_id")
run_config = first_request.get("run_config")
first_live_request = first_request.get("live_request")

if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
state = first_request.get("state")
session = await self.async_create_session(user_id=user_id, state=state)
session_id = session["id"] if isinstance(session, dict) else session.id
run_config = _validate_run_config(run_config)

live_request_queue = LiveRequestQueue()

if first_live_request and isinstance(first_live_request, Dict):
live_request_queue.send(LiveRequest.model_validate(first_live_request))

# Forwards live requests to the agent.
async def _forward_requests():
while True:
request = await request_queue.get()
live_request = LiveRequest.model_validate(request)
live_request_queue.send(live_request)

# Forwards events to the client.
async def _forward_events():
if run_config:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
run_config=run_config,
)
else:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
)
async for event in events_async:
yield _utils.dump_event_for_json(event)

requests_task = asyncio.create_task(_forward_requests())

try:
async for event in _forward_events():
yield event
finally:
requests_task.cancel()
try:
await requests_task
except asyncio.CancelledError:
pass

def register_operations(self) -> Dict[str, List[str]]:
"""Registers the operations of the ADK application."""
return {
Expand All @@ -1776,6 +1863,7 @@ def register_operations(self) -> Dict[str, List[str]]:
"async_stream_query",
"streaming_agent_run_with_events",
],
"bidi_stream": ["bidi_stream_query"],
}

def _telemetry_enabled(self) -> Optional[bool]:
Expand Down
Loading