diff --git a/itk/main.py b/itk/main.py index 22cfef2a4..7be7a5a20 100644 --- a/itk/main.py +++ b/itk/main.py @@ -12,7 +12,7 @@ from pyproto import instruction_pb2 -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, create_client from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -128,10 +128,7 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: ) try: - client = await ClientFactory.connect( - call.agent_card_uri, - client_config=config, - ) + client = await create_client(call.agent_card_uri, client_config=config) # Wrap nested instruction async with client: diff --git a/samples/cli.py b/samples/cli.py index 6a4597fa9..8515fd5a9 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -9,7 +9,7 @@ import grpc import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.client import A2ACardResolver, ClientConfig, create_client from a2a.types import Message, Part, Role, SendMessageRequest, TaskState @@ -79,7 +79,7 @@ async def main() -> None: print('\n✓ Agent Card Found:') print(f' Name: {card.name}') - client = await ClientFactory.connect(card, client_config=config) + client = await create_client(card, client_config=config) actual_transport = getattr(client, '_transport', client) print(f' Picked Transport: {actual_transport.__class__.__name__}') diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 188ab4c80..c23041f32 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -12,7 +12,11 @@ ClientCallContext, ClientConfig, ) -from a2a.client.client_factory import ClientFactory, minimal_agent_card +from a2a.client.client_factory import ( + ClientFactory, + create_client, + minimal_agent_card, +) from a2a.client.errors import ( A2AClientError, A2AClientTimeoutError, @@ -36,6 +40,7 @@ 'ClientFactory', 'CredentialService', 'InMemoryContextCredentialStore', + 'create_client', 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index c5d5e8aa4..a59189ade 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import httpx @@ -56,32 +56,35 @@ class ClientFactory: - """ClientFactory is used to generate the appropriate client for the agent. + """Factory for creating clients that communicate with A2A agents. - The factory is configured with a `ClientConfig` and optionally a list of - `Consumer`s to use for all generated `Client`s. The expected use is: - - .. code-block:: python + The factory is configured with a `ClientConfig` and optionally custom + transport producers registered via `register`. Example usage: factory = ClientFactory(config) - # Optionally register custom client implementations - factory.register('my_customer_transport', NewCustomTransportClient) - # Then with an agent card make a client with additional interceptors + # Optionally register custom transport implementations + factory.register('my_custom_transport', custom_transport_producer) + # Create a client from an AgentCard client = factory.create(card, interceptors) + # Or resolve an AgentCard from a URL and create a client + client = await factory.create_from_url('https://example.com') - Now the client can be used consistently regardless of the transport. This + The client can be used consistently regardless of the transport. This aligns the client configuration with the server's capabilities. """ def __init__( self, - config: ClientConfig, + config: ClientConfig | None = None, ): - client = config.httpx_client or httpx.AsyncClient() - client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT) - config.httpx_client = client + config = config or ClientConfig() + httpx_client = config.httpx_client or httpx.AsyncClient() + httpx_client.headers.setdefault( + VERSION_HEADER, PROTOCOL_VERSION_CURRENT + ) self._config = config + self._httpx_client = httpx_client self._registry: dict[str, TransportProducer] = {} self._register_defaults(config.supported_protocol_bindings) @@ -112,13 +115,13 @@ def jsonrpc_transport_producer( ) return CompatJsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return JsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -151,13 +154,13 @@ def rest_transport_producer( ) return CompatRestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return RestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -252,73 +255,45 @@ def _find_best_interface( return best_gt_1_0 or best_ge_0_3 or best_no_version - @classmethod - async def connect( # noqa: PLR0913 - cls, - agent: str | AgentCard, - client_config: ClientConfig | None = None, + async def create_from_url( + self, + url: str, interceptors: list[ClientCallInterceptor] | None = None, relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, - extra_transports: dict[str, TransportProducer] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> Client: - """Convenience method for constructing a client. - - Constructs a client that connects to the specified agent. Note that - creating multiple clients via this method is less efficient than - constructing an instance of ClientFactory and reusing that. - - .. code-block:: python + """Create a `Client` by resolving an `AgentCard` from a URL. - # This will search for an AgentCard at /.well-known/agent-card.json - my_agent_url = 'https://travel.agents.example.com' - client = await ClientFactory.connect(my_agent_url) + Resolves the agent card from the given URL using the factory's + configured httpx client, then creates a client via `create`. + If the agent card is already available, use `create` directly + instead. Args: - agent: The base URL of the agent, or the AgentCard to connect to. - client_config: The ClientConfig to use when connecting to the agent. - - interceptors: A list of interceptors to use for each request. These - are used for things like attaching credentials or http headers - to all outbound requests. - relative_card_path: If the agent field is a URL, this value is used as - the relative path when resolving the agent card. See - A2AAgentCardResolver.get_agent_card for more details. - resolver_http_kwargs: Dictionary of arguments to provide to the httpx - client when resolving the agent card. This value is provided to - A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. - extra_transports: Additional transport protocols to enable when - constructing the client. - signature_verifier: A callable used to verify the agent card's signatures. + url: The base URL of the agent. The agent card will be fetched + from `/.well-known/agent-card.json` by default. + interceptors: A list of interceptors to use for each request. + These are used for things like attaching credentials or http + headers to all outbound requests. + relative_card_path: The relative path when resolving the agent + card. See `A2ACardResolver.get_agent_card` for details. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. Returns: A `Client` object. """ - client_config = client_config or ClientConfig() - if isinstance(agent, str): - if not client_config.httpx_client: - async with httpx.AsyncClient() as client: - resolver = A2ACardResolver(client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - resolver = A2ACardResolver(client_config.httpx_client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - card = agent - factory = cls(client_config) - for label, generator in (extra_transports or {}).items(): - factory.register(label, generator) - return factory.create(card, interceptors) + resolver = A2ACardResolver(self._httpx_client, url) + card = await resolver.get_agent_card( + relative_card_path=relative_card_path, + http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return self.create(card, interceptors) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -389,6 +364,48 @@ def create( ) +async def create_client( # noqa: PLR0913 + agent: str | AgentCard, + client_config: ClientConfig | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + relative_card_path: str | None = None, + resolver_http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, +) -> Client: + """Create a `Client` for an agent from a URL or `AgentCard`. + + Convenience function that constructs a `ClientFactory` internally. + For reusing a factory across multiple agents or registering custom + transports, use `ClientFactory` directly instead. + + Args: + agent: The base URL of the agent, or an `AgentCard` to use + directly. + client_config: Optional `ClientConfig`. A default config is + created if not provided. + interceptors: A list of interceptors to use for each request. + relative_card_path: The relative path when resolving the agent + card. Only used when `agent` is a URL. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. + + Returns: + A `Client` object. + """ + factory = ClientFactory(client_config) + if isinstance(agent, str): + return await factory.create_from_url( + agent, + interceptors=interceptors, + relative_card_path=relative_card_path, + resolver_http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return factory.create(agent, interceptors) + + def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index a5366e0d3..b30d57d12 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -1,18 +1,16 @@ """Tests for the ClientFactory.""" -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import typing import httpx import pytest -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, ClientFactory, create_client from a2a.client.client_factory import TransportProducer from a2a.client.transports import ( JsonRpcTransport, RestTransport, - ClientTransport, ) from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( @@ -127,26 +125,27 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): factory.create(base_agent_card) -@pytest.mark.asyncio -async def test_client_factory_connect_with_agent_card( +def test_client_factory_create_with_default_config( base_agent_card: AgentCard, ): - """Verify that connect works correctly when provided with an AgentCard.""" - client = await ClientFactory.connect(base_agent_card) + """Verify that create works correctly with a default ClientConfig.""" + factory = ClientFactory() + client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_client_factory_connect_with_url(base_agent_card: AgentCard): - """Verify that connect works correctly when provided with a URL.""" +async def test_client_factory_create_from_url(base_agent_card: AgentCard): + """Verify that create_from_url resolves the card and creates a client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card ) agent_url = 'http://example.com' - client = await ClientFactory.connect(agent_url) + factory = ClientFactory() + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once() assert mock_resolver.call_args[0][1] == agent_url @@ -157,10 +156,10 @@ async def test_client_factory_connect_with_url(base_agent_card: AgentCard): @pytest.mark.asyncio -async def test_client_factory_connect_with_url_and_client_config( +async def test_client_factory_create_from_url_uses_factory_httpx_client( base_agent_card: AgentCard, ): - """Verify connect with a URL and a pre-configured httpx client.""" + """Verify create_from_url uses the factory's configured httpx client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -170,7 +169,8 @@ async def test_client_factory_connect_with_url_and_client_config( mock_httpx_client = httpx.AsyncClient() config = ClientConfig(httpx_client=mock_httpx_client) - client = await ClientFactory.connect(agent_url, client_config=config) + factory = ClientFactory(config) + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once_with(mock_httpx_client, agent_url) mock_resolver.return_value.get_agent_card.assert_awaited_once() @@ -180,10 +180,10 @@ async def test_client_factory_connect_with_url_and_client_config( @pytest.mark.asyncio -async def test_client_factory_connect_with_resolver_args( +async def test_client_factory_create_from_url_passes_resolver_args( base_agent_card: AgentCard, ): - """Verify connect passes resolver arguments correctly.""" + """Verify create_from_url passes resolver arguments correctly.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -193,12 +193,11 @@ async def test_client_factory_connect_with_resolver_args( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - # The resolver args are only passed if an httpx_client is provided in config config = ClientConfig(httpx_client=httpx.AsyncClient()) + factory = ClientFactory(config) - await ClientFactory.connect( + await factory.create_from_url( agent_url, - client_config=config, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) @@ -211,10 +210,10 @@ async def test_client_factory_connect_with_resolver_args( @pytest.mark.asyncio -async def test_client_factory_connect_resolver_args_without_client( +async def test_client_factory_create_from_url_with_default_config( base_agent_card: AgentCard, ): - """Verify resolver args are ignored if no httpx_client is provided.""" + """Verify create_from_url works with a default ClientConfig.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -224,12 +223,16 @@ async def test_client_factory_connect_resolver_args_without_client( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - await ClientFactory.connect( + factory = ClientFactory() + + await factory.create_from_url( agent_url, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) + # Factory always creates an httpx client, so resolver gets it + mock_resolver.assert_called_once() mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, @@ -237,16 +240,17 @@ async def test_client_factory_connect_resolver_args_without_client( ) -@pytest.mark.asyncio -async def test_client_factory_connect_with_extra_transports( +def test_client_factory_register_and_create_custom_transport( base_agent_card: AgentCard, ): - """Verify that connect can register and use extra transports.""" + """Verify that register() + create() uses custom transports.""" class CustomTransport: pass - def custom_transport_producer(*args, **kwargs): + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: return CustomTransport() base_agent_card.supported_interfaces.insert( @@ -255,27 +259,60 @@ def custom_transport_producer(*args, **kwargs): ) config = ClientConfig(supported_protocol_bindings=['custom']) - - client = await ClientFactory.connect( - base_agent_card, - client_config=config, - extra_transports=typing.cast( - dict[str, TransportProducer], {'custom': custom_transport_producer} - ), + factory = ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), ) + client = factory.create(base_agent_card) assert isinstance(client._transport, CustomTransport) # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_client_factory_connect_with_interceptors( +async def test_client_factory_create_from_url_uses_registered_transports( + base_agent_card: AgentCard, +): + """Verify that create_from_url() respects custom transports from register().""" + + class CustomTransport: + pass + + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: + return CustomTransport() + + base_agent_card.supported_interfaces.insert( + 0, + AgentInterface(protocol_binding='custom', url='custom://foo'), + ) + + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + config = ClientConfig(supported_protocol_bindings=['custom']) + factory = ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), + ) + + client = await factory.create_from_url('http://example.com') + assert isinstance(client._transport, CustomTransport) # type: ignore[attr-defined] + + +def test_client_factory_create_with_interceptors( base_agent_card: AgentCard, ): """Verify interceptors are passed through correctly.""" interceptor1 = MagicMock() with patch('a2a.client.client_factory.BaseClient') as mock_base_client: - await ClientFactory.connect( + factory = ClientFactory() + factory.create( base_agent_card, interceptors=[interceptor1], ) @@ -298,3 +335,44 @@ def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): assert isinstance(client._transport, TenantTransportDecorator) # type: ignore[attr-defined] assert client._transport._tenant == 'my-tenant' # type: ignore[attr-defined] assert isinstance(client._transport._base, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_agent_card(base_agent_card: AgentCard): + """Verify create_client works when given an AgentCard directly.""" + client = await create_client(base_agent_card) + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_url(base_agent_card: AgentCard): + """Verify create_client resolves a URL and creates a client.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + client = await create_client('http://example.com') + + mock_resolver.assert_called_once() + assert mock_resolver.call_args[0][1] == 'http://example.com' + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_url_and_config(base_agent_card: AgentCard): + """Verify create_client passes client_config to the factory.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + mock_httpx_client = httpx.AsyncClient() + config = ClientConfig(httpx_client=mock_httpx_client) + + await create_client('http://example.com', client_config=config) + + mock_resolver.assert_called_once_with( + mock_httpx_client, 'http://example.com' + ) diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py index 5a5e192cf..6630bddad 100644 --- a/tests/integration/cross_version/client_server/client_1_0.py +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -5,7 +5,7 @@ import sys from uuid import uuid4 -from a2a.client import ClientFactory, ClientConfig +from a2a.client import ClientConfig, create_client from a2a.utils import TransportProtocol from a2a.types import ( Message, @@ -80,7 +80,7 @@ async def test_send_message_sync(url, protocol_enum): config.supported_protocol_bindings = [protocol_enum] config.streaming = False - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) msg = Message( role=Role.ROLE_USER, message_id=f'sync-{uuid4()}', @@ -296,7 +296,7 @@ async def run_client(url: str, protocol: str): config.supported_protocol_bindings = [protocol_enum] config.streaming = True - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) # 1. Get Extended Agent Card server_name = await test_get_extended_agent_card(client)