Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 147 additions & 25 deletions vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,46 @@


try:
from a2a.types import (
AgentCard,
TransportProtocol,
Message,
TaskIdParams,
TaskQueryParams,
)
from a2a.client import ClientConfig, ClientFactory

AgentCard = AgentCard
TransportProtocol = TransportProtocol
Message = Message
ClientConfig = ClientConfig
ClientFactory = ClientFactory
TaskIdParams = TaskIdParams
TaskQueryParams = TaskQueryParams
except (ImportError, AttributeError):
AgentCard = None
TransportProtocol = None
Message = None
ClientConfig = None
ClientFactory = None
TaskIdParams = None
TaskQueryParams = None
from a2a.utils.constants import TransportProtocol as _TpTest

_A2A_SDK_VERSION: Optional[str] = "1.0"
except ImportError:
try:
from a2a.types import TransportProtocol as _TpTest

_A2A_SDK_VERSION = "0.3"
except ImportError:
_A2A_SDK_VERSION = None

if _A2A_SDK_VERSION == "1.0":
from a2a.types import (
AgentCard,
Message,
)
from a2a.client import ClientConfig, ClientFactory
from a2a.utils.constants import TransportProtocol
from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams
elif _A2A_SDK_VERSION == "0.3":
from a2a.types import (
AgentCard,
TransportProtocol,
Message,
TaskIdParams,
TaskQueryParams,
)
from a2a.client import ClientConfig, ClientFactory
else:
AgentCard = None
TransportProtocol = None
Message = None
ClientConfig = None
ClientFactory = None
TaskIdParams = None
TaskQueryParams = None
SendMessageRequest = None
GetTaskRequest = None
CancelTaskRequest = None
GetExtendedAgentCardRequest = None

_ACTIONS_KEY = "actions"
_ACTION_APPEND = "append"
Expand Down Expand Up @@ -1737,7 +1753,7 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
return _method


def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
def _wrap_a2a_operation_v03(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
"""Wraps an Agent Engine method, creating a callable for A2A API.

Args:
Expand Down Expand Up @@ -1854,6 +1870,112 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
return _method # type: ignore[return-value]


def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
"""Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+).

Args:
method_name: The name of the Agent Engine method to call.
agent_card: The agent card JSON string to use for the A2A API call.
Example: {'name': 'Sample Agent', 'description': ( 'A helpful
assistant agent that can answer questions.' ),
'supportedInterfaces': [{ 'url': 'http://localhost:8080/a2a/rest/',
'protocolBinding': 'HTTP+JSON', 'protocolVersion': '1.0', }],
'version': '1.0.0', 'capabilities': { 'streaming': True,
'pushNotifications': False, 'extendedAgentCard': True, },
'defaultInputModes': ['text'], 'defaultOutputModes': ['text'],
'skills': [{ 'id': 'question_answer', 'name': 'Q&A Agent',
'description': ( 'A helpful assistant agent that can answer
questions.' ), 'tags': ['Question-Answer'], 'examples': [ 'Who is
leading 2025 F1 Standings?', 'Where can i find an active volcano?',
], 'inputModes': ['text'], 'outputModes': ['text'], }]}

Returns:
A callable object that executes the method on the Agent Engine via
the A2A API.
"""

async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
if not self.api_client:
raise ValueError("api_client is not initialized.")
if not self.api_resource:
raise ValueError("api_resource is not initialized.")

a2a_agent_card = AgentCard()
json_format.ParseDict(
json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True
)

if a2a_agent_card.supported_interfaces:
interface = a2a_agent_card.supported_interfaces[0]
if interface.protocol_binding != TransportProtocol.HTTP_JSON:
raise ValueError(
"Only HTTP+JSON is supported for preferred transport on agent card"
)
else:
raise ValueError("Agent card does not define any supported interfaces.")

# base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
# api_version = self.api_client._api_client._http_options.api_version
# a2a_agent_card.supported_interfaces[0].url = (
# f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
# )

config = ClientConfig(
supported_protocol_bindings=[
TransportProtocol.HTTP_JSON,
],
use_client_preference=True,
httpx_client=httpx.AsyncClient(
headers={
"Authorization": (
f"Bearer {self.api_client._api_client._credentials.token}"
)
},
timeout=(
self.api_client._api_client._http_options.timeout / 1000.0
if self.api_client._api_client._http_options.timeout
else None
),
),
)
factory = ClientFactory(config)
client = factory.create(a2a_agent_card)

context = kwargs.pop("context", None)
if context is not None:
from a2a.client.client import ClientCallContext

if not isinstance(context, ClientCallContext):
actual_context = ClientCallContext()
if hasattr(context, "state"):
actual_context.state = context.state
elif isinstance(context, dict):
actual_context.state = context
context = actual_context

req = kwargs["request"]
if method_name == "on_message_send":
response = client.send_message(req, context=context)
chunks = []
async for chunk in response:
chunks.append(chunk)
return chunks
elif method_name == "on_get_task":
return await client.get_task(req, context=context)
elif method_name == "on_cancel_task":
return await client.cancel_task(req, context=context)
elif method_name == "on_get_extended_agent_card":
return await client.get_extended_agent_card(req, context=context)
else:
raise ValueError(f"Unknown method name: {method_name}")

return _method # type: ignore[return-value]


if _A2A_SDK_VERSION != "1.0":
_wrap_a2a_operation = _wrap_a2a_operation_v03


def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
"""Converts the body of the HTTP Response message to JSON format.

Expand Down
Loading