diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index 92d91addbb..38532dd20a 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -110,30 +110,46 @@ try: - from a2a.types import ( - AgentCard, - TransportProtocol, - Message, - TaskIdParams, - TaskQueryParams, - ) - from a2a.client import ClientConfig, ClientFactory - - AgentCard = AgentCard - TransportProtocol = TransportProtocol - Message = Message - ClientConfig = ClientConfig - ClientFactory = ClientFactory - TaskIdParams = TaskIdParams - TaskQueryParams = TaskQueryParams -except (ImportError, AttributeError): - AgentCard = None - TransportProtocol = None - Message = None - ClientConfig = None - ClientFactory = None - TaskIdParams = None - TaskQueryParams = None + from a2a.utils.constants import TransportProtocol as _TpTest + + _A2A_SDK_VERSION: Optional[str] = "1.0" +except ImportError: + try: + from a2a.types import TransportProtocol as _TpTest + + _A2A_SDK_VERSION = "0.3" + except ImportError: + _A2A_SDK_VERSION = None + +if _A2A_SDK_VERSION == "1.0": + from a2a.types import ( + AgentCard, + Message, + ) + from a2a.client import ClientConfig, ClientFactory + from a2a.utils.constants import TransportProtocol + from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams +elif _A2A_SDK_VERSION == "0.3": + from a2a.types import ( + AgentCard, + TransportProtocol, + Message, + TaskIdParams, + TaskQueryParams, + ) + from a2a.client import ClientConfig, ClientFactory +else: + AgentCard = None + TransportProtocol = None + Message = None + ClientConfig = None + ClientFactory = None + TaskIdParams = None + TaskQueryParams = None + SendMessageRequest = None + GetTaskRequest = None + CancelTaskRequest = None + GetExtendedAgentCardRequest = None _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" @@ -1737,7 +1753,7 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any] return _method -def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]: +def _wrap_a2a_operation_v03(method_name: str, agent_card: str) -> Callable[..., list[Any]]: """Wraps an Agent Engine method, creating a callable for A2A API. Args: @@ -1854,6 +1870,112 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] return _method # type: ignore[return-value] +def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]: + """Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+). + + Args: + method_name: The name of the Agent Engine method to call. + agent_card: The agent card JSON string to use for the A2A API call. + Example: {'name': 'Sample Agent', 'description': ( 'A helpful + assistant agent that can answer questions.' ), + 'supportedInterfaces': [{ 'url': 'http://localhost:8080/a2a/rest/', + 'protocolBinding': 'HTTP+JSON', 'protocolVersion': '1.0', }], + 'version': '1.0.0', 'capabilities': { 'streaming': True, + 'pushNotifications': False, 'extendedAgentCard': True, }, + 'defaultInputModes': ['text'], 'defaultOutputModes': ['text'], + 'skills': [{ 'id': 'question_answer', 'name': 'Q&A Agent', + 'description': ( 'A helpful assistant agent that can answer + questions.' ), 'tags': ['Question-Answer'], 'examples': [ 'Who is + leading 2025 F1 Standings?', 'Where can i find an active volcano?', + ], 'inputModes': ['text'], 'outputModes': ['text'], }]} + + Returns: + A callable object that executes the method on the Agent Engine via + the A2A API. + """ + + async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] + if not self.api_client: + raise ValueError("api_client is not initialized.") + if not self.api_resource: + raise ValueError("api_resource is not initialized.") + + a2a_agent_card = AgentCard() + json_format.ParseDict( + json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True + ) + + if a2a_agent_card.supported_interfaces: + interface = a2a_agent_card.supported_interfaces[0] + if interface.protocol_binding != TransportProtocol.HTTP_JSON: + raise ValueError( + "Only HTTP+JSON is supported for preferred transport on agent card" + ) + else: + raise ValueError("Agent card does not define any supported interfaces.") + + # base_url = self.api_client._api_client._http_options.base_url.rstrip("/") + # api_version = self.api_client._api_client._http_options.api_version + # a2a_agent_card.supported_interfaces[0].url = ( + # f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + # ) + + config = ClientConfig( + supported_protocol_bindings=[ + TransportProtocol.HTTP_JSON, + ], + use_client_preference=True, + httpx_client=httpx.AsyncClient( + headers={ + "Authorization": ( + f"Bearer {self.api_client._api_client._credentials.token}" + ) + }, + timeout=( + self.api_client._api_client._http_options.timeout / 1000.0 + if self.api_client._api_client._http_options.timeout + else None + ), + ), + ) + factory = ClientFactory(config) + client = factory.create(a2a_agent_card) + + context = kwargs.pop("context", None) + if context is not None: + from a2a.client.client import ClientCallContext + + if not isinstance(context, ClientCallContext): + actual_context = ClientCallContext() + if hasattr(context, "state"): + actual_context.state = context.state + elif isinstance(context, dict): + actual_context.state = context + context = actual_context + + req = kwargs["request"] + if method_name == "on_message_send": + response = client.send_message(req, context=context) + chunks = [] + async for chunk in response: + chunks.append(chunk) + return chunks + elif method_name == "on_get_task": + return await client.get_task(req, context=context) + elif method_name == "on_cancel_task": + return await client.cancel_task(req, context=context) + elif method_name == "on_get_extended_agent_card": + return await client.get_extended_agent_card(req, context=context) + else: + raise ValueError(f"Unknown method name: {method_name}") + + return _method # type: ignore[return-value] + + +if _A2A_SDK_VERSION != "1.0": + _wrap_a2a_operation = _wrap_a2a_operation_v03 + + def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]: """Converts the body of the HTTP Response message to JSON format.