From d4f3c0b4eeeb71ee2b3bf9b449ae9b4b6ab6fe97 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 14:54:38 +0000 Subject: [PATCH 1/7] test --- GEMINI.md | 2 +- pyproject.toml | 2 +- src/a2a/compat/v0_3/grpc_handler.py | 14 +- src/a2a/compat/v0_3/jsonrpc_adapter.py | 17 +- src/a2a/compat/v0_3/rest_adapter.py | 11 +- src/a2a/server/request_handlers/__init__.py | 4 + .../server/request_handlers/grpc_handler.py | 58 ++++--- src/a2a/server/routes/__init__.py | 10 +- src/a2a/server/routes/common.py | 85 ++++++++++ src/a2a/server/routes/jsonrpc_dispatcher.py | 79 ++------- src/a2a/server/routes/jsonrpc_routes.py | 12 +- src/a2a/server/routes/rest_dispatcher.py | 17 +- src/a2a/server/routes/rest_routes.py | 8 +- tests/extensions/__init__.py | 0 tests/server/routes/__init__.py | 0 tests/server/routes/test_common.py | 156 ++++++++++++++++++ .../server/routes/test_jsonrpc_dispatcher.py | 39 +---- tests/server/routes/test_rest_dispatcher.py | 1 - 18 files changed, 340 insertions(+), 175 deletions(-) create mode 100644 src/a2a/server/routes/common.py create mode 100644 tests/extensions/__init__.py create mode 100644 tests/server/routes/__init__.py create mode 100644 tests/server/routes/test_common.py diff --git a/GEMINI.md b/GEMINI.md index aaab0bf66..59ef64713 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -8,7 +8,7 @@ - **Language**: Python 3.10+ - **Package Manager**: `uv` -- **Lead Transports**: FastAPI (REST/JSON-RPC), gRPC +- **Lead Transports**: Starlette (REST/JSON-RPC), gRPC - **Data Layer**: SQLAlchemy (SQL), Pydantic (Logic/Legacy), Protobuf (Modern Messaging) - **Key Directories**: - `/src`: Core implementation logic. diff --git a/pyproject.toml b/pyproject.toml index ac2083b16..7dc53ef8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ ] [project.optional-dependencies] -http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"] +http-server = ["sse-starlette", "starlette"] encryption = ["cryptography>=43.0.0"] grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio-status>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index eb72cf76b..c9db99557 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -23,8 +23,8 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.grpc_handler import ( _ERROR_CODE_MAP, - CallContextBuilder, - DefaultCallContextBuilder, + DefaultGrpcServerCallContextBuilder, + GrpcServerCallContextBuilder, ) from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import AgentCard @@ -44,7 +44,7 @@ def __init__( self, agent_card: AgentCard, request_handler: RequestHandler, - context_builder: CallContextBuilder | None = None, + context_builder: GrpcServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, ): @@ -61,7 +61,9 @@ def __init__( """ self.agent_card = agent_card self.handler03 = RequestHandler03(request_handler=request_handler) - self.context_builder = context_builder or DefaultCallContextBuilder() + self._context_builder = ( + context_builder or DefaultGrpcServerCallContextBuilder() + ) self.card_modifier = card_modifier async def _handle_unary( @@ -72,7 +74,7 @@ async def _handle_unary( ) -> TResponse: """Centralized error handling and context management for unary calls.""" try: - server_context = self.context_builder.build(context) + server_context = self._context_builder.build(context) result = await handler_func(server_context) self._set_extension_metadata(context, server_context) except A2AError as e: @@ -88,7 +90,7 @@ async def _handle_stream( ) -> AsyncIterable[TResponse]: """Centralized error handling and context management for streaming calls.""" try: - server_context = self.context_builder.build(context) + server_context = self._context_builder.build(context) async for item in handler_func(server_context): yield item self._set_extension_metadata(context, server_context) diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index d9d698411..d01a7e11c 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -11,7 +11,6 @@ from starlette.requests import Request from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.server.routes import CallContextBuilder from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True @@ -38,6 +37,10 @@ from a2a.server.jsonrpc_models import ( JSONRPCError as CoreJSONRPCError, ) +from a2a.server.routes.common import ( + DefaultServerCallContextBuilder, + ServerCallContextBuilder, +) from a2a.utils import constants from a2a.utils.errors import ExtendedAgentCardNotConfiguredError from a2a.utils.helpers import maybe_await, validate_version @@ -67,7 +70,7 @@ def __init__( # noqa: PLR0913 agent_card: 'AgentCard', http_handler: 'RequestHandler', extended_agent_card: 'AgentCard | None' = None, - context_builder: 'CallContextBuilder | None' = None, + context_builder: 'ServerCallContextBuilder | None' = None, card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None, extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): @@ -78,7 +81,9 @@ def __init__( # noqa: PLR0913 self.handler = RequestHandler03( request_handler=http_handler, ) - self._context_builder = context_builder + self._context_builder = ( + context_builder or DefaultServerCallContextBuilder() + ) def supports_method(self, method: str) -> bool: """Returns True if the v0.3 adapter supports the given method name.""" @@ -126,11 +131,7 @@ async def handle_request( CoreInvalidRequestError(data=str(e)), ) - call_context = ( - self._context_builder.build(request) - if self._context_builder - else ServerCallContext() - ) + call_context = self._context_builder.build(request) call_context.tenant = ( getattr(specific_request.params, 'tenant', '') if hasattr(specific_request, 'params') diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index 76b1ce4d1..27aba2aad 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -34,7 +34,10 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler from a2a.server.context import ServerCallContext -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.server.routes.common import ( + DefaultServerCallContextBuilder, + ServerCallContextBuilder, +) from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, @@ -60,7 +63,7 @@ def __init__( # noqa: PLR0913 agent_card: 'AgentCard', http_handler: 'RequestHandler', extended_agent_card: 'AgentCard | None' = None, - context_builder: 'CallContextBuilder | None' = None, + context_builder: 'ServerCallContextBuilder | None' = None, card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None, extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): @@ -71,7 +74,9 @@ def __init__( # noqa: PLR0913 self.handler = REST03Handler( agent_card=agent_card, request_handler=http_handler ) - self._context_builder = context_builder or DefaultCallContextBuilder() + self._context_builder = ( + context_builder or DefaultServerCallContextBuilder() + ) @rest_error_handler async def _handle_request( diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index f239af3e6..194e81a45 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -19,7 +19,9 @@ try: from a2a.server.request_handlers.grpc_handler import ( + DefaultGrpcServerCallContextBuilder, GrpcHandler, # type: ignore + GrpcServerCallContextBuilder, ) except ImportError as e: _original_error = e @@ -39,8 +41,10 @@ def __init__(self, *args, **kwargs): __all__ = [ + 'DefaultGrpcServerCallContextBuilder', 'DefaultRequestHandler', 'GrpcHandler', + 'GrpcServerCallContextBuilder', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index c354e097e..60aa41d22 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -24,7 +24,7 @@ import a2a.types.a2a_pb2_grpc as a2a_grpc from a2a import types -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, get_requested_extensions, @@ -41,15 +41,32 @@ logger = logging.getLogger(__name__) -# For now we use a trivial wrapper on the grpc context object - -class CallContextBuilder(ABC): - """A class for building ServerCallContexts using the Starlette Request.""" +class GrpcServerCallContextBuilder(ABC): + """Interface for building ServerCallContext from gRPC context.""" @abstractmethod def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: - """Builds a ServerCallContext from a gRPC Request.""" + """Builds a ServerCallContext from a gRPC ServicerContext.""" + + +class DefaultGrpcServerCallContextBuilder(GrpcServerCallContextBuilder): + """Default implementation of GrpcServerCallContextBuilder.""" + + def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext from a gRPC ServicerContext.""" + state = {'grpc_context': context} + return ServerCallContext( + user=self.build_user(context), + state=state, + requested_extensions=get_requested_extensions( + _get_metadata_value(context, HTTP_EXTENSION_HEADER) + ), + ) + + def build_user(self, context: grpc.aio.ServicerContext) -> User: + """Builds a User from a gRPC ServicerContext.""" + return UnauthenticatedUser() def _get_metadata_value( @@ -67,22 +84,6 @@ def _get_metadata_value( ] -class DefaultCallContextBuilder(CallContextBuilder): - """A default implementation of CallContextBuilder.""" - - def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: - """Builds the ServerCallContext.""" - user = UnauthenticatedUser() - state = {'grpc_context': context} - return ServerCallContext( - user=user, - state=state, - requested_extensions=get_requested_extensions( - _get_metadata_value(context, HTTP_EXTENSION_HEADER) - ), - ) - - _ERROR_CODE_MAP = { types.InvalidRequestError: grpc.StatusCode.INVALID_ARGUMENT, types.MethodNotFoundError: grpc.StatusCode.NOT_FOUND, @@ -110,7 +111,7 @@ def __init__( self, agent_card: AgentCard, request_handler: RequestHandler, - context_builder: CallContextBuilder | None = None, + context_builder: GrpcServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, ): @@ -120,14 +121,17 @@ def __init__( agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. - context_builder: The CallContextBuilder object. If none the - DefaultCallContextBuilder is used. + context_builder: The GrpcContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None the + DefaultGrpcContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. """ self.agent_card = agent_card self.request_handler = request_handler - self.context_builder = context_builder or DefaultCallContextBuilder() + self._context_builder = ( + context_builder or DefaultGrpcServerCallContextBuilder() + ) self.card_modifier = card_modifier async def _handle_unary( @@ -451,6 +455,6 @@ def _build_call_context( context: grpc.aio.ServicerContext, request: message.Message, ) -> ServerCallContext: - server_context = self.context_builder.build(context) + server_context = self._context_builder.build(context) server_context.tenant = getattr(request, 'tenant', '') return server_context diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index bb6ae0ba1..007e2722f 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -1,17 +1,17 @@ """A2A Routes.""" from a2a.server.routes.agent_card_routes import create_agent_card_routes -from a2a.server.routes.jsonrpc_dispatcher import ( - CallContextBuilder, - DefaultCallContextBuilder, +from a2a.server.routes.common import ( + DefaultServerCallContextBuilder, + ServerCallContextBuilder, ) from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes __all__ = [ - 'CallContextBuilder', - 'DefaultCallContextBuilder', + 'DefaultServerCallContextBuilder', + 'ServerCallContextBuilder', 'create_agent_card_routes', 'create_jsonrpc_routes', 'create_rest_routes', diff --git a/src/a2a/server/routes/common.py b/src/a2a/server/routes/common.py new file mode 100644 index 000000000..18b6865c5 --- /dev/null +++ b/src/a2a/server/routes/common.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from starlette.authentication import BaseUser + from starlette.requests import Request +else: + try: + from starlette.authentication import BaseUser + from starlette.requests import Request + except ImportError: + Request = Any + BaseUser = Any + +from a2a.auth.user import UnauthenticatedUser, User +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) +from a2a.server.context import ServerCallContext + + +class StarletteUser(User): + """Adapts a Starlette BaseUser to the A2A User interface.""" + + def __init__(self, user: BaseUser): + self._user = user + + @property + def is_authenticated(self) -> bool: + """Returns whether the current user is authenticated.""" + return self._user.is_authenticated + + @property + def user_name(self) -> str: + """Returns the user name of the current user.""" + return self._user.display_name + + +class ServerCallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, request: Request) -> ServerCallContext: + """Builds a ServerCallContext from a Starlette Request.""" + + +class DefaultServerCallContextBuilder(ServerCallContextBuilder): + """A default implementation of ServerCallContextBuilder.""" + + def build(self, request: Request) -> ServerCallContext: + """Builds a ServerCallContext from a Starlette Request. + + Args: + request: The incoming Starlette Request object. + + Returns: + A ServerCallContext instance populated with user and state + information from the request. + """ + state = {} + if 'auth' in request.scope: + state['auth'] = request.auth + state['headers'] = dict(request.headers) + return ServerCallContext( + user=self.build_user(request), + state=state, + requested_extensions=get_requested_extensions( + request.headers.getlist(HTTP_EXTENSION_HEADER) + ), + ) + + def build_user(self, request: Request) -> User: + """Builds a User from a Starlette Request. + + Args: + request: The incoming Starlette Request object. + + Returns: + A User instance populated with user information from the request. + """ + if 'user' in request.scope: + return StarletteUser(request.user) + return UnauthenticatedUser() diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 6bd326c8e..2739f53db 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -4,19 +4,15 @@ import logging import traceback -from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any from google.protobuf.json_format import MessageToDict, ParseDict from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response -from a2a.auth.user import UnauthenticatedUser -from a2a.auth.user import User as A2AUser from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, - get_requested_extensions, ) from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -31,6 +27,10 @@ from a2a.server.request_handlers.response_helpers import ( build_error_response, ) +from a2a.server.routes.common import ( + DefaultServerCallContextBuilder, + ServerCallContextBuilder, +) from a2a.types import A2ARequest from a2a.types.a2a_pb2 import ( AgentCard, @@ -63,10 +63,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from fastapi import FastAPI from sse_starlette.sse import EventSourceResponse - from starlette.applications import Starlette - from starlette.authentication import BaseUser from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -81,11 +78,8 @@ _package_starlette_installed = True else: - FastAPI = Any try: from sse_starlette.sse import EventSourceResponse - from starlette.applications import Starlette - from starlette.authentication import BaseUser from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -104,8 +98,6 @@ # Provide placeholder types for runtime type hinting when dependencies are not installed. # These will not be used if the code path that needs them is guarded by _http_server_installed. EventSourceResponse = Any - Starlette = Any - BaseUser = Any HTTPException = Any Request = Any JSONResponse = Any @@ -113,59 +105,6 @@ HTTP_413_CONTENT_TOO_LARGE = Any -class StarletteUserProxy(A2AUser): - """Adapts the Starlette User class to the A2A user representation.""" - - def __init__(self, user: BaseUser): - self._user = user - - @property - def is_authenticated(self) -> bool: - """Returns whether the current user is authenticated.""" - return self._user.is_authenticated - - @property - def user_name(self) -> str: - """Returns the user name of the current user.""" - return self._user.display_name - - -class CallContextBuilder(ABC): - """A class for building ServerCallContexts using the Starlette Request.""" - - @abstractmethod - def build(self, request: Request) -> ServerCallContext: - """Builds a ServerCallContext from a Starlette Request.""" - - -class DefaultCallContextBuilder(CallContextBuilder): - """A default implementation of CallContextBuilder.""" - - def build(self, request: Request) -> ServerCallContext: - """Builds a ServerCallContext from a Starlette Request. - - Args: - request: The incoming Starlette Request object. - - Returns: - A ServerCallContext instance populated with user and state - information from the request. - """ - user: A2AUser = UnauthenticatedUser() - state = {} - if 'user' in request.scope: - user = StarletteUserProxy(request.user) - state['auth'] = request.auth - state['headers'] = dict(request.headers) - return ServerCallContext( - user=user, - state=state, - requested_extensions=get_requested_extensions( - request.headers.getlist(HTTP_EXTENSION_HEADER) - ), - ) - - @trace_class(kind=SpanKind.SERVER) class JsonRpcDispatcher: """Base class for A2A JSONRPC applications. @@ -197,7 +136,7 @@ def __init__( # noqa: PLR0913 agent_card: AgentCard, request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, + context_builder: ServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, extended_card_modifier: Callable[ @@ -214,9 +153,9 @@ def __init__( # noqa: PLR0913 requests via http. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the + context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the - DefaultCallContextBuilder is used. + DefaultServerCallContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. extended_card_modifier: An optional callback to dynamically modify @@ -236,7 +175,9 @@ def __init__( # noqa: PLR0913 self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self._context_builder = context_builder or DefaultCallContextBuilder() + self._context_builder = ( + context_builder or DefaultServerCallContextBuilder() + ) self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter: JSONRPC03Adapter | None = None diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index a71a02b2d..f19625379 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -19,10 +19,8 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes.jsonrpc_dispatcher import ( - CallContextBuilder, - JsonRpcDispatcher, -) +from a2a.server.routes.common import ServerCallContextBuilder +from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher from a2a.types.a2a_pb2 import AgentCard @@ -31,7 +29,7 @@ def create_jsonrpc_routes( # noqa: PLR0913 request_handler: RequestHandler, rpc_url: str, extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, + context_builder: ServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, extended_card_modifier: Callable[ @@ -53,9 +51,9 @@ def create_jsonrpc_routes( # noqa: PLR0913 rpc_url: The URL prefix for the RPC endpoints. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the + context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the - DefaultCallContextBuilder is used. + DefaultServerCallContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. extended_card_modifier: An optional callback to dynamically modify diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 768315086..1f91dd573 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -8,7 +8,10 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.server.routes.common import ( + DefaultServerCallContextBuilder, + ServerCallContextBuilder, +) from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( AgentCard, @@ -68,7 +71,7 @@ def __init__( # noqa: PLR0913 agent_card: AgentCard, request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, + context_builder: ServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, extended_card_modifier: Callable[ @@ -83,9 +86,9 @@ def __init__( # noqa: PLR0913 request_handler: The underlying `RequestHandler` instance to delegate requests to. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the request_handler. If None, no - ServerCallContext is passed. + context_builder: The ServerCallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None the + DefaultServerCallContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. extended_card_modifier: An optional callback to dynamically modify @@ -103,7 +106,9 @@ def __init__( # noqa: PLR0913 self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self._context_builder = context_builder or DefaultCallContextBuilder() + self._context_builder = ( + context_builder or DefaultServerCallContextBuilder() + ) self.request_handler = request_handler def _build_call_context(self, request: Request) -> ServerCallContext: diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 5d0cfcfc8..89ba63b8e 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -6,7 +6,7 @@ from a2a.compat.v0_3.rest_adapter import REST03Adapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes import CallContextBuilder +from a2a.server.routes.common import ServerCallContextBuilder from a2a.server.routes.rest_dispatcher import RestDispatcher from a2a.types.a2a_pb2 import ( AgentCard, @@ -46,7 +46,7 @@ def create_rest_routes( # noqa: PLR0913 agent_card: AgentCard, request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, + context_builder: ServerCallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None = None, extended_card_modifier: Callable[ @@ -64,9 +64,9 @@ def create_rest_routes( # noqa: PLR0913 requests via http. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the + context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the - DefaultCallContextBuilder is used. + DefaultServerCallContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. extended_card_modifier: An optional callback to dynamically modify diff --git a/tests/extensions/__init__.py b/tests/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/routes/__init__.py b/tests/server/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/routes/test_common.py b/tests/server/routes/test_common.py new file mode 100644 index 000000000..3c4a08d2b --- /dev/null +++ b/tests/server/routes/test_common.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest +from starlette.datastructures import Headers + +try: + from starlette.authentication import BaseUser as StarletteBaseUser +except ImportError: + StarletteBaseUser = MagicMock() # type: ignore + +from a2a.auth.user import UnauthenticatedUser +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.server.context import ServerCallContext +from a2a.server.routes.common import ( + StarletteUser, + DefaultServerCallContextBuilder, +) + + +# --- StarletteUser Tests --- + + +class TestStarletteUser: + def test_is_authenticated_true(self): + starlette_user = MagicMock(spec=StarletteBaseUser) + starlette_user.is_authenticated = True + proxy = StarletteUser(starlette_user) + assert proxy.is_authenticated is True + + def test_is_authenticated_false(self): + starlette_user = MagicMock(spec=StarletteBaseUser) + starlette_user.is_authenticated = False + proxy = StarletteUser(starlette_user) + assert proxy.is_authenticated is False + + def test_user_name(self): + starlette_user = MagicMock(spec=StarletteBaseUser) + starlette_user.display_name = 'Test User' + proxy = StarletteUser(starlette_user) + assert proxy.user_name == 'Test User' + + def test_user_name_raises_attribute_error(self): + starlette_user = MagicMock(spec=StarletteBaseUser) + del starlette_user.display_name + proxy = StarletteUser(starlette_user) + with pytest.raises(AttributeError, match='display_name'): + _ = proxy.user_name + + +# --- default_user_builder Tests --- + + +def _make_mock_request(scope=None, headers=None): + request = MagicMock() + request.scope = scope or {} + request.headers = Headers(headers or {}) + return request + + +class TestDefaultContextBuilder: + def test_returns_unauthenticated_user_when_no_user_in_scope(self): + request = _make_mock_request(scope={}) + user = DefaultServerCallContextBuilder().build_user(request) + assert isinstance(user, UnauthenticatedUser) + assert user.is_authenticated is False + assert user.user_name == '' + + def test_returns_proxy_when_user_in_scope(self): + starlette_user = MagicMock() + starlette_user.is_authenticated = True + starlette_user.display_name = 'Alice' + request = _make_mock_request(scope={'user': starlette_user}) + request.user = starlette_user + + user = DefaultServerCallContextBuilder().build_user(request) + assert isinstance(user, StarletteUser) + assert user.is_authenticated is True + assert user.user_name == 'Alice' + + def test_returns_unauthenticated_proxy_when_user_not_authenticated(self): + starlette_user = MagicMock() + starlette_user.is_authenticated = False + starlette_user.display_name = '' + request = _make_mock_request(scope={'user': starlette_user}) + request.user = starlette_user + + user = DefaultServerCallContextBuilder().build_user(request) + assert isinstance(user, StarletteUser) + assert user.is_authenticated is False + + +# --- build_server_call_context Tests --- + + +class TestBuildServerCallContext: + def test_basic_context_with_default_user_builder(self): + request = _make_mock_request( + scope={}, headers={'content-type': 'application/json'} + ) + ctx = DefaultServerCallContextBuilder().build(request) + + assert isinstance(ctx, ServerCallContext) + assert isinstance(ctx.user, UnauthenticatedUser) + assert 'headers' in ctx.state + assert ctx.state['headers']['content-type'] == 'application/json' + assert 'auth' not in ctx.state + + def test_auth_populated_when_in_scope(self): + auth_credentials = MagicMock() + request = _make_mock_request(scope={'auth': auth_credentials}) + request.auth = auth_credentials + + ctx = DefaultServerCallContextBuilder().build(request) + assert ctx.state['auth'] is auth_credentials + + def test_auth_not_populated_when_not_in_scope(self): + request = _make_mock_request(scope={}) + ctx = DefaultServerCallContextBuilder().build(request) + assert 'auth' not in ctx.state + + def test_headers_captured_in_state(self): + request = _make_mock_request( + headers={'x-custom': 'value', 'authorization': 'Bearer tok'} + ) + ctx = DefaultServerCallContextBuilder().build(request) + assert ctx.state['headers']['x-custom'] == 'value' + assert ctx.state['headers']['authorization'] == 'Bearer tok' + + def test_requested_extensions_single(self): + request = _make_mock_request(headers={HTTP_EXTENSION_HEADER: 'foo'}) + ctx = DefaultServerCallContextBuilder().build(request) + assert ctx.requested_extensions == {'foo'} + + def test_requested_extensions_comma_separated(self): + request = _make_mock_request( + headers={HTTP_EXTENSION_HEADER: 'foo, bar'} + ) + ctx = DefaultServerCallContextBuilder().build(request) + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + request = _make_mock_request() + ctx = DefaultServerCallContextBuilder().build(request) + assert ctx.requested_extensions == set() + + def test_custom_user_builder(self): + custom_user = MagicMock(spec=UnauthenticatedUser) + custom_user.is_authenticated = True + + class MyContextBuilder(DefaultServerCallContextBuilder): + def build_user(self, req): + return custom_user + + request = _make_mock_request() + ctx = MyContextBuilder().build(request) + assert ctx.user is custom_user diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 1242bee23..31a550de3 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -21,49 +21,14 @@ Role, ) from a2a.server.routes import jsonrpc_dispatcher -from a2a.server.routes.jsonrpc_dispatcher import ( - CallContextBuilder, - DefaultCallContextBuilder, - JsonRpcDispatcher, - StarletteUserProxy, -) + +from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes from a2a.server.routes.agent_card_routes import create_agent_card_routes from a2a.server.jsonrpc_models import JSONRPCError from a2a.utils.errors import A2AError -# --- StarletteUserProxy Tests --- - - -class TestStarletteUserProxy: - def test_starlette_user_proxy_is_authenticated_true(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = True - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is True - - def test_starlette_user_proxy_is_authenticated_false(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.is_authenticated = False - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.is_authenticated is False - - def test_starlette_user_proxy_user_name(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - starlette_user_mock.display_name = 'Test User DisplayName' - proxy = StarletteUserProxy(starlette_user_mock) - assert proxy.user_name == 'Test User DisplayName' - - def test_starlette_user_proxy_user_name_raises_attribute_error(self): - starlette_user_mock = MagicMock(spec=StarletteBaseUser) - del starlette_user_mock.display_name - - proxy = StarletteUserProxy(starlette_user_mock) - with pytest.raises(AttributeError, match='display_name'): - _ = proxy.user_name - - # --- JsonRpcDispatcher Tests --- diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index b4233d0cd..be5870cc4 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -11,7 +11,6 @@ from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes import rest_dispatcher from a2a.server.routes.rest_dispatcher import ( - DefaultCallContextBuilder, RestDispatcher, ) from a2a.types.a2a_pb2 import ( From 007be0603003212db17dfe046e3521313835295c Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 14:56:53 +0000 Subject: [PATCH 2/7] remove unused stuff --- src/a2a/server/routes/rest_routes.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 89ba63b8e..20a899ca4 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -14,25 +14,15 @@ if TYPE_CHECKING: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response from starlette.routing import BaseRoute, Mount, Route _package_starlette_installed = True else: try: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response from starlette.routing import BaseRoute, Mount, Route _package_starlette_installed = True except ImportError: - EventSourceResponse = Any - Request = Any - JSONResponse = Any - Response = Any Route = Any Mount = Any BaseRoute = Any From 8772059d9f64191745aaf83e2c4562003e40e2aa Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 15:07:27 +0000 Subject: [PATCH 3/7] test --- src/a2a/server/routes/test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/a2a/server/routes/test.txt diff --git a/src/a2a/server/routes/test.txt b/src/a2a/server/routes/test.txt new file mode 100644 index 000000000..e69de29bb From 2f3f1a3e1ad81d86f38f20ee2826fad3e4ca309b Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 15:07:40 +0000 Subject: [PATCH 4/7] remove test --- src/a2a/server/routes/test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/a2a/server/routes/test.txt diff --git a/src/a2a/server/routes/test.txt b/src/a2a/server/routes/test.txt deleted file mode 100644 index e69de29bb..000000000 From 85b6f4826e3e5f6f512e9a62a5b531ee7c8a8709 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 15:09:13 +0000 Subject: [PATCH 5/7] fix --- uv.lock | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/uv.lock b/uv.lock index 5d7d3b6fb..919ed3148 100644 --- a/uv.lock +++ b/uv.lock @@ -27,7 +27,6 @@ dependencies = [ all = [ { name = "alembic" }, { name = "cryptography" }, - { name = "fastapi" }, { name = "google-cloud-aiplatform" }, { name = "grpcio" }, { name = "grpcio-reflection" }, @@ -53,7 +52,6 @@ grpc = [ { name = "grpcio-tools" }, ] http-server = [ - { name = "fastapi" }, { name = "sse-starlette" }, { name = "starlette" }, ] @@ -109,8 +107,6 @@ requires-dist = [ { name = "cryptography", marker = "extra == 'all'", specifier = ">=43.0.0" }, { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, { name = "culsans", marker = "python_full_version < '3.13'", specifier = ">=0.11.0" }, - { name = "fastapi", marker = "extra == 'all'", specifier = ">=0.115.2" }, - { name = "fastapi", marker = "extra == 'http-server'", specifier = ">=0.115.2" }, { name = "google-api-core", specifier = ">=1.26.0" }, { name = "google-cloud-aiplatform", marker = "extra == 'all'", specifier = ">=1.140.0" }, { name = "google-cloud-aiplatform", marker = "extra == 'vertex'", specifier = ">=1.140.0" }, @@ -223,15 +219,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, ] -[[package]] -name = "annotated-doc" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -818,22 +805,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] -[[package]] -name = "fastapi" -version = "0.135.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "annotated-doc" }, - { name = "pydantic" }, - { name = "starlette" }, - { name = "typing-extensions" }, - { name = "typing-inspection" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e7/7b/f8e0211e9380f7195ba3f3d40c292594fd81ba8ec4629e3854c353aaca45/fastapi-0.135.1.tar.gz", hash = "sha256:d04115b508d936d254cea545b7312ecaa58a7b3a0f84952535b4c9afae7668cd", size = 394962, upload-time = "2026-03-01T18:18:29.369Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/72/42e900510195b23a56bde950d26a51f8b723846bfcaa0286e90287f0422b/fastapi-0.135.1-py3-none-any.whl", hash = "sha256:46e2fc5745924b7c840f71ddd277382af29ce1cdb7d5eab5bf697e3fb9999c9e", size = 116999, upload-time = "2026-03-01T18:18:30.831Z" }, -] - [[package]] name = "filelock" version = "3.25.2" From b331c1aee52c29aec12a4bc1585531ab7e5cb5fb Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 15:17:58 +0000 Subject: [PATCH 6/7] chore: update project dependencies in pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7dc53ef8a..724749865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ style = "pep440" [dependency-groups] dev = [ + "fastapi>=0.115.2", "mypy>=1.15.0", "PyJWT>=2.0.0", "pytest>=8.3.5", From 8c1d03256a04179ce437e3c281558b0659ac2214 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 3 Apr 2026 15:18:54 +0000 Subject: [PATCH 7/7] change --- uv.lock | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/uv.lock b/uv.lock index 919ed3148..dc87a7b6d 100644 --- a/uv.lock +++ b/uv.lock @@ -81,6 +81,7 @@ vertex = [ [package.dev-dependencies] dev = [ { name = "a2a-sdk", extra = ["all"] }, + { name = "fastapi" }, { name = "mypy" }, { name = "pre-commit" }, { name = "pyjwt" }, @@ -150,6 +151,7 @@ provides-extras = ["all", "db-cli", "encryption", "grpc", "http-server", "mysql" [package.metadata.requires-dev] dev = [ { name = "a2a-sdk", extras = ["all"], editable = "." }, + { name = "fastapi", specifier = ">=0.115.2" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "pre-commit" }, { name = "pyjwt", specifier = ">=2.0.0" }, @@ -219,6 +221,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -805,6 +816,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] +[[package]] +name = "fastapi" +version = "0.135.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/e6/7adb4c5fa231e82c35b8f5741a9f2d055f520c29af5546fd70d3e8e1cd2e/fastapi-0.135.3.tar.gz", hash = "sha256:bd6d7caf1a2bdd8d676843cdcd2287729572a1ef524fc4d65c17ae002a1be654", size = 396524, upload-time = "2026-04-01T16:23:58.188Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/a4/5caa2de7f917a04ada20018eccf60d6cc6145b0199d55ca3711b0fc08312/fastapi-0.135.3-py3-none-any.whl", hash = "sha256:9b0f590c813acd13d0ab43dd8494138eb58e484bfac405db1f3187cfc5810d98", size = 117734, upload-time = "2026-04-01T16:23:59.328Z" }, +] + [[package]] name = "filelock" version = "3.25.2"