Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_serializer, field_validator


class OAuthToken(BaseModel):
Expand Down Expand Up @@ -129,6 +129,12 @@ class OAuthMetadata(BaseModel):
code_challenge_methods_supported: list[str] | None = None
client_id_metadata_document_supported: bool | None = None

@field_serializer("issuer")
@staticmethod
def _serialize_issuer(v: AnyHttpUrl) -> str:
"""Strip trailing slash added by AnyHttpUrl for RFC 8414 §3.3 compliance."""
return str(v).rstrip("/")
Comment on lines +132 to +136
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just use str instead of AnyHttpUrl in the field?



class ProtectedResourceMetadata(BaseModel):
"""RFC 9728 OAuth 2.0 Protected Resource Metadata.
Expand All @@ -151,3 +157,15 @@ class ProtectedResourceMetadata(BaseModel):
dpop_signing_alg_values_supported: list[str] | None = None
# dpop_bound_access_tokens_required default is False, but ommited here for clarity
dpop_bound_access_tokens_required: bool | None = None

@field_serializer("resource")
@staticmethod
def _serialize_resource(v: AnyHttpUrl) -> str:
"""Strip trailing slash added by AnyHttpUrl for RFC 9728 §3 compliance."""
return str(v).rstrip("/")

@field_serializer("authorization_servers")
@staticmethod
def _serialize_authorization_servers(v: list[AnyHttpUrl]) -> list[str]:
"""Strip trailing slashes added by AnyHttpUrl for RFC 9728 §3 compliance."""
return [str(s).rstrip("/") for s in v]
12 changes: 6 additions & 6 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,29 +1283,27 @@ async def mock_callback() -> tuple[str, str | None]:
@pytest.mark.parametrize(
(
"issuer_url",
"expected_issuer",
"service_documentation_url",
"authorization_endpoint",
"token_endpoint",
"registration_endpoint",
"revocation_endpoint",
),
(
# Pydantic's AnyUrl incorrectly adds trailing slash to base URLs
# This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+)
pytest.param(
"https://auth.example.com",
"https://auth.example.com",
"https://auth.example.com/docs",
"https://auth.example.com/authorize",
"https://auth.example.com/token",
"https://auth.example.com/register",
"https://auth.example.com/revoke",
id="simple-url",
marks=pytest.mark.xfail(
reason="Pydantic AnyUrl adds trailing slash to base URLs - fixed in Pydantic 2.12+"
),
),
pytest.param(
"https://auth.example.com/",
"https://auth.example.com",
"https://auth.example.com/docs",
"https://auth.example.com/authorize",
"https://auth.example.com/token",
Expand All @@ -1314,6 +1312,7 @@ async def mock_callback() -> tuple[str, str | None]:
id="with-trailing-slash",
),
pytest.param(
"https://auth.example.com/v1/mcp",
"https://auth.example.com/v1/mcp",
"https://auth.example.com/v1/mcp/docs",
"https://auth.example.com/v1/mcp/authorize",
Expand All @@ -1326,6 +1325,7 @@ async def mock_callback() -> tuple[str, str | None]:
)
def test_build_metadata(
issuer_url: str,
expected_issuer: str,
service_documentation_url: str,
authorization_endpoint: str,
token_endpoint: str,
Expand All @@ -1341,7 +1341,7 @@ def test_build_metadata(

assert metadata.model_dump(exclude_defaults=True, mode="json") == snapshot(
{
"issuer": Is(issuer_url),
"issuer": Is(expected_issuer),
"authorization_endpoint": Is(authorization_endpoint),
"token_endpoint": Is(token_endpoint),
"registration_endpoint": Is(registration_endpoint),
Expand Down
4 changes: 2 additions & 2 deletions tests/server/auth/test_protected_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC
assert response.status_code == 200
assert response.json() == snapshot(
{
"resource": "https://example.com/",
"authorization_servers": ["https://auth.example.com/"],
"resource": "https://example.com",
"authorization_servers": ["https://auth.example.com"],
"scopes_supported": ["read"],
"resource_name": "Root Resource",
"bearer_methods_supported": ["header"],
Expand Down
2 changes: 1 addition & 1 deletion tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
assert response.status_code == 200

metadata = response.json()
assert metadata["issuer"] == "https://auth.example.com/"
assert metadata["issuer"] == "https://auth.example.com"
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
assert metadata["token_endpoint"] == "https://auth.example.com/token"
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
Expand Down