Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/adcp/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,21 @@ class BearerTokenAuthMiddleware(BaseHTTPMiddleware):
:param validate_token: Your token lookup. See :data:`TokenValidator`.
:param unauthenticated_response: Optional override for the 401
response body. Default is ``{"error": "unauthenticated"}``.
:param header_name: Which HTTP header carries the credential.
Default ``"authorization"`` (the spec-canonical bearer header).
Adopters with legacy clients sending tokens via a custom header
(e.g. ``"x-adcp-auth"``) override this. Header lookup is
case-insensitive (Starlette normalizes).
:param bearer_prefix_required: When ``True`` (default), the
middleware strips a ``"Bearer "`` prefix and rejects headers
without it. When ``False``, the raw header value is passed
verbatim to ``validate_token`` — appropriate for non-OAuth
custom-header schemes (``X-Api-Key: <token>``,
``x-adcp-auth: <token>``, etc.). Adopters changing
``header_name`` to a non-standard value usually want this set
to ``False``. **Security note:** setting this to ``False``
removes the prefix pre-filter; ``validate_token`` must be
defensive about unexpected input shapes and unbounded lengths.
"""

def __init__(
Expand All @@ -231,10 +246,17 @@ def __init__(
*,
validate_token: TokenValidator,
unauthenticated_response: dict[str, Any] | None = None,
header_name: str = "authorization",
bearer_prefix_required: bool = True,
) -> None:
super().__init__(app)
self._validate_token = validate_token
self._unauth_body = unauthenticated_response or {"error": "unauthenticated"}
# Lower-cased once at construction so the per-request lookup
# avoids the normalization. Starlette's Headers does
# case-insensitive matching, so this is belt-and-suspenders.
self._header_name = header_name.lower()
self._bearer_prefix_required = bearer_prefix_required

async def dispatch(self, request: Request, call_next: Any) -> Any:
method, tool = await self._peek_jsonrpc(request)
Expand All @@ -249,7 +271,15 @@ async def dispatch(self, request: Request, call_next: Any) -> Any:
metadata_token = current_principal_metadata.set(None)
return await call_next(request)

bearer = _parse_bearer_header(request.headers.get("authorization", ""))
raw_header = request.headers.get(self._header_name, "")
if self._bearer_prefix_required:
bearer = _parse_bearer_header(raw_header)
else:
# Custom-header schemes (X-Api-Key, x-adcp-auth, etc.) —
# pass the raw value through unchanged. Strip whitespace
# since copy-paste tokens often pick up trailing newlines.
stripped = raw_header.strip()
bearer = stripped or None
if not bearer:
return self._unauthenticated()

Expand Down
186 changes: 186 additions & 0 deletions tests/test_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,189 @@ def test_auth_context_factory_with_no_principal() -> None:
# already boots the FastMCP initialize/tools-call flow end-to-end. The
# tests in this file stay focused on the middleware class itself so
# failures localise to the auth logic, not the transport plumbing.


# ----- custom header / non-Bearer schemes ---------------------------------


def _build_app_custom_header(
validator: Any, *, header_name: str, bearer_prefix_required: bool
) -> Starlette:
app = Starlette(routes=[Route("/", _echo_handler, methods=["POST"])])
app.add_middleware(
BearerTokenAuthMiddleware,
validate_token=validator,
header_name=header_name,
bearer_prefix_required=bearer_prefix_required,
)
return app


@pytest.mark.asyncio
async def test_custom_header_x_adcp_auth_no_bearer_prefix() -> None:
"""Salesagent-shaped scheme: ``x-adcp-auth: <raw-token>``.

The legacy salesagent server uses this header layout — no
``Authorization`` header, no ``Bearer`` prefix. The middleware
must accept the raw token verbatim when ``bearer_prefix_required``
is False.
"""
received_tokens: list[str] = []

def validator(token: str) -> Principal | None:
received_tokens.append(token)
return Principal(caller_identity="alice", tenant_id="acme")

app = _build_app_custom_header(
validator, header_name="x-adcp-auth", bearer_prefix_required=False
)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"x-adcp-auth": "tok_alice_raw"},
)
assert resp.status_code == 200
assert received_tokens == ["tok_alice_raw"] # passed through verbatim, no Bearer prefix


@pytest.mark.asyncio
async def test_custom_header_strips_whitespace() -> None:
"""Trailing newlines / spaces (common in copy-pasted tokens) are stripped."""
received: list[str] = []

def validator(token: str) -> Principal | None:
received.append(token)
return Principal(caller_identity="alice")

app = _build_app_custom_header(validator, header_name="x-api-key", bearer_prefix_required=False)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"x-api-key": " tok_alice "},
)
assert resp.status_code == 200
assert received == ["tok_alice"]


@pytest.mark.asyncio
async def test_custom_header_rejects_missing_credential() -> None:
"""Custom-header mode still 401s when the configured header isn't present."""

def validator(token: str) -> Principal | None:
return Principal(caller_identity="alice")

app = _build_app_custom_header(
validator, header_name="x-adcp-auth", bearer_prefix_required=False
)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
# Send an Authorization header — but the middleware is configured
# to look at x-adcp-auth, so this must be rejected.
resp = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"authorization": "Bearer tok_alice"},
)
assert resp.status_code == 401


@pytest.mark.asyncio
async def test_custom_header_with_bearer_prefix_still_required() -> None:
"""When ``bearer_prefix_required=True`` (the default), even a custom
header must carry the ``Bearer`` prefix. Useful for adopters using a
non-``Authorization`` header but keeping the OAuth2 envelope (e.g.
proxies that strip ``Authorization``)."""

def validator(token: str) -> Principal | None:
return Principal(caller_identity="alice")

app = _build_app_custom_header(
validator, header_name="x-proxied-auth", bearer_prefix_required=True
)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
# Without Bearer prefix → 401
resp1 = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"x-proxied-auth": "tok_alice"},
)
assert resp1.status_code == 401

# With Bearer prefix → 200
resp2 = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"x-proxied-auth": "Bearer tok_alice"},
)
assert resp2.status_code == 200


@pytest.mark.asyncio
async def test_custom_header_is_exclusive_no_fallback_to_authorization() -> None:
"""Custom header_name is exclusive — Authorization is never consulted.

When both ``Authorization: Bearer X`` and ``x-adcp-auth: Y`` are
present and the middleware is configured for ``x-adcp-auth``, only
``Y`` reaches the validator. There is no fallback to the standard
header. Closes the "I expected fallback to Authorization" footgun for
adopters who set header_name accidentally.
"""
received: list[str] = []

def validator(token: str) -> Principal | None:
received.append(token)
return Principal(caller_identity="alice")

app = _build_app_custom_header(
validator, header_name="x-adcp-auth", bearer_prefix_required=False
)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={
"Authorization": "Bearer tok_x",
"x-adcp-auth": "tok_y",
},
)
assert resp.status_code == 200
assert received == ["tok_y"] # x-adcp-auth wins; Authorization is ignored


@pytest.mark.asyncio
async def test_default_header_unchanged_for_existing_adopters() -> None:
"""The defaults (``Authorization`` header + Bearer prefix) match the
pre-existing behavior. Existing adopters not setting the new params
see no behavioral change."""

def validator(token: str) -> Principal | None:
return Principal(caller_identity="alice")

# Use the original _build_app (no custom kwargs) — same as before.
app = _build_app(validator)
async with LifespanManager(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/",
json={"method": "tools/call", "params": {"name": "get_products"}},
headers={"authorization": "Bearer tok_alice"},
)
assert resp.status_code == 200
Loading