diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index b745d0b7..bb84d33d 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -223,6 +223,21 @@ class BearerTokenAuthMiddleware(BaseHTTPMiddleware): :param validate_token: Your token lookup. See :data:`TokenValidator`. :param unauthenticated_response: Optional override for the 401 response body. Default is ``{"error": "unauthenticated"}``. + :param header_name: Which HTTP header carries the credential. + Default ``"authorization"`` (the spec-canonical bearer header). + Adopters with legacy clients sending tokens via a custom header + (e.g. ``"x-adcp-auth"``) override this. Header lookup is + case-insensitive (Starlette normalizes). + :param bearer_prefix_required: When ``True`` (default), the + middleware strips a ``"Bearer "`` prefix and rejects headers + without it. When ``False``, the raw header value is passed + verbatim to ``validate_token`` — appropriate for non-OAuth + custom-header schemes (``X-Api-Key: ``, + ``x-adcp-auth: ``, etc.). Adopters changing + ``header_name`` to a non-standard value usually want this set + to ``False``. **Security note:** setting this to ``False`` + removes the prefix pre-filter; ``validate_token`` must be + defensive about unexpected input shapes and unbounded lengths. """ def __init__( @@ -231,10 +246,17 @@ def __init__( *, validate_token: TokenValidator, unauthenticated_response: dict[str, Any] | None = None, + header_name: str = "authorization", + bearer_prefix_required: bool = True, ) -> None: super().__init__(app) self._validate_token = validate_token self._unauth_body = unauthenticated_response or {"error": "unauthenticated"} + # Lower-cased once at construction so the per-request lookup + # avoids the normalization. Starlette's Headers does + # case-insensitive matching, so this is belt-and-suspenders. + self._header_name = header_name.lower() + self._bearer_prefix_required = bearer_prefix_required async def dispatch(self, request: Request, call_next: Any) -> Any: method, tool = await self._peek_jsonrpc(request) @@ -249,7 +271,15 @@ async def dispatch(self, request: Request, call_next: Any) -> Any: metadata_token = current_principal_metadata.set(None) return await call_next(request) - bearer = _parse_bearer_header(request.headers.get("authorization", "")) + raw_header = request.headers.get(self._header_name, "") + if self._bearer_prefix_required: + bearer = _parse_bearer_header(raw_header) + else: + # Custom-header schemes (X-Api-Key, x-adcp-auth, etc.) — + # pass the raw value through unchanged. Strip whitespace + # since copy-paste tokens often pick up trailing newlines. + stripped = raw_header.strip() + bearer = stripped or None if not bearer: return self._unauthenticated() diff --git a/tests/test_auth_middleware.py b/tests/test_auth_middleware.py index af83be6b..0d1c624f 100644 --- a/tests/test_auth_middleware.py +++ b/tests/test_auth_middleware.py @@ -569,3 +569,189 @@ def test_auth_context_factory_with_no_principal() -> None: # already boots the FastMCP initialize/tools-call flow end-to-end. The # tests in this file stay focused on the middleware class itself so # failures localise to the auth logic, not the transport plumbing. + + +# ----- custom header / non-Bearer schemes --------------------------------- + + +def _build_app_custom_header( + validator: Any, *, header_name: str, bearer_prefix_required: bool +) -> Starlette: + app = Starlette(routes=[Route("/", _echo_handler, methods=["POST"])]) + app.add_middleware( + BearerTokenAuthMiddleware, + validate_token=validator, + header_name=header_name, + bearer_prefix_required=bearer_prefix_required, + ) + return app + + +@pytest.mark.asyncio +async def test_custom_header_x_adcp_auth_no_bearer_prefix() -> None: + """Salesagent-shaped scheme: ``x-adcp-auth: ``. + + The legacy salesagent server uses this header layout — no + ``Authorization`` header, no ``Bearer`` prefix. The middleware + must accept the raw token verbatim when ``bearer_prefix_required`` + is False. + """ + received_tokens: list[str] = [] + + def validator(token: str) -> Principal | None: + received_tokens.append(token) + return Principal(caller_identity="alice", tenant_id="acme") + + app = _build_app_custom_header( + validator, header_name="x-adcp-auth", bearer_prefix_required=False + ) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + resp = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"x-adcp-auth": "tok_alice_raw"}, + ) + assert resp.status_code == 200 + assert received_tokens == ["tok_alice_raw"] # passed through verbatim, no Bearer prefix + + +@pytest.mark.asyncio +async def test_custom_header_strips_whitespace() -> None: + """Trailing newlines / spaces (common in copy-pasted tokens) are stripped.""" + received: list[str] = [] + + def validator(token: str) -> Principal | None: + received.append(token) + return Principal(caller_identity="alice") + + app = _build_app_custom_header(validator, header_name="x-api-key", bearer_prefix_required=False) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + resp = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"x-api-key": " tok_alice "}, + ) + assert resp.status_code == 200 + assert received == ["tok_alice"] + + +@pytest.mark.asyncio +async def test_custom_header_rejects_missing_credential() -> None: + """Custom-header mode still 401s when the configured header isn't present.""" + + def validator(token: str) -> Principal | None: + return Principal(caller_identity="alice") + + app = _build_app_custom_header( + validator, header_name="x-adcp-auth", bearer_prefix_required=False + ) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + # Send an Authorization header — but the middleware is configured + # to look at x-adcp-auth, so this must be rejected. + resp = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"authorization": "Bearer tok_alice"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_custom_header_with_bearer_prefix_still_required() -> None: + """When ``bearer_prefix_required=True`` (the default), even a custom + header must carry the ``Bearer`` prefix. Useful for adopters using a + non-``Authorization`` header but keeping the OAuth2 envelope (e.g. + proxies that strip ``Authorization``).""" + + def validator(token: str) -> Principal | None: + return Principal(caller_identity="alice") + + app = _build_app_custom_header( + validator, header_name="x-proxied-auth", bearer_prefix_required=True + ) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + # Without Bearer prefix → 401 + resp1 = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"x-proxied-auth": "tok_alice"}, + ) + assert resp1.status_code == 401 + + # With Bearer prefix → 200 + resp2 = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"x-proxied-auth": "Bearer tok_alice"}, + ) + assert resp2.status_code == 200 + + +@pytest.mark.asyncio +async def test_custom_header_is_exclusive_no_fallback_to_authorization() -> None: + """Custom header_name is exclusive — Authorization is never consulted. + + When both ``Authorization: Bearer X`` and ``x-adcp-auth: Y`` are + present and the middleware is configured for ``x-adcp-auth``, only + ``Y`` reaches the validator. There is no fallback to the standard + header. Closes the "I expected fallback to Authorization" footgun for + adopters who set header_name accidentally. + """ + received: list[str] = [] + + def validator(token: str) -> Principal | None: + received.append(token) + return Principal(caller_identity="alice") + + app = _build_app_custom_header( + validator, header_name="x-adcp-auth", bearer_prefix_required=False + ) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + resp = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={ + "Authorization": "Bearer tok_x", + "x-adcp-auth": "tok_y", + }, + ) + assert resp.status_code == 200 + assert received == ["tok_y"] # x-adcp-auth wins; Authorization is ignored + + +@pytest.mark.asyncio +async def test_default_header_unchanged_for_existing_adopters() -> None: + """The defaults (``Authorization`` header + Bearer prefix) match the + pre-existing behavior. Existing adopters not setting the new params + see no behavioral change.""" + + def validator(token: str) -> Principal | None: + return Principal(caller_identity="alice") + + # Use the original _build_app (no custom kwargs) — same as before. + app = _build_app(validator) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + resp = await client.post( + "/", + json={"method": "tools/call", "params": {"name": "get_products"}}, + headers={"authorization": "Bearer tok_alice"}, + ) + assert resp.status_code == 200