diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bf03a8b8d..56e4372ad 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_serializer, field_validator class OAuthToken(BaseModel): @@ -129,6 +129,12 @@ class OAuthMetadata(BaseModel): code_challenge_methods_supported: list[str] | None = None client_id_metadata_document_supported: bool | None = None + @field_serializer("issuer") + @staticmethod + def _serialize_issuer(v: AnyHttpUrl) -> str: + """Strip trailing slash added by AnyHttpUrl for RFC 8414 §3.3 compliance.""" + return str(v).rstrip("/") + class ProtectedResourceMetadata(BaseModel): """RFC 9728 OAuth 2.0 Protected Resource Metadata. @@ -151,3 +157,15 @@ class ProtectedResourceMetadata(BaseModel): dpop_signing_alg_values_supported: list[str] | None = None # dpop_bound_access_tokens_required default is False, but ommited here for clarity dpop_bound_access_tokens_required: bool | None = None + + @field_serializer("resource") + @staticmethod + def _serialize_resource(v: AnyHttpUrl) -> str: + """Strip trailing slash added by AnyHttpUrl for RFC 9728 §3 compliance.""" + return str(v).rstrip("/") + + @field_serializer("authorization_servers") + @staticmethod + def _serialize_authorization_servers(v: list[AnyHttpUrl]) -> list[str]: + """Strip trailing slashes added by AnyHttpUrl for RFC 9728 §3 compliance.""" + return [str(s).rstrip("/") for s in v] diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2f531cc65..1f52ebfd7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1283,6 +1283,7 @@ async def mock_callback() -> tuple[str, str | None]: @pytest.mark.parametrize( ( "issuer_url", + "expected_issuer", "service_documentation_url", "authorization_endpoint", "token_endpoint", @@ -1290,9 +1291,8 @@ async def mock_callback() -> tuple[str, str | None]: "revocation_endpoint", ), ( - # Pydantic's AnyUrl incorrectly adds trailing slash to base URLs - # This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+) pytest.param( + "https://auth.example.com", "https://auth.example.com", "https://auth.example.com/docs", "https://auth.example.com/authorize", @@ -1300,12 +1300,10 @@ async def mock_callback() -> tuple[str, str | None]: "https://auth.example.com/register", "https://auth.example.com/revoke", id="simple-url", - marks=pytest.mark.xfail( - reason="Pydantic AnyUrl adds trailing slash to base URLs - fixed in Pydantic 2.12+" - ), ), pytest.param( "https://auth.example.com/", + "https://auth.example.com", "https://auth.example.com/docs", "https://auth.example.com/authorize", "https://auth.example.com/token", @@ -1314,6 +1312,7 @@ async def mock_callback() -> tuple[str, str | None]: id="with-trailing-slash", ), pytest.param( + "https://auth.example.com/v1/mcp", "https://auth.example.com/v1/mcp", "https://auth.example.com/v1/mcp/docs", "https://auth.example.com/v1/mcp/authorize", @@ -1326,6 +1325,7 @@ async def mock_callback() -> tuple[str, str | None]: ) def test_build_metadata( issuer_url: str, + expected_issuer: str, service_documentation_url: str, authorization_endpoint: str, token_endpoint: str, @@ -1341,7 +1341,7 @@ def test_build_metadata( assert metadata.model_dump(exclude_defaults=True, mode="json") == snapshot( { - "issuer": Is(issuer_url), + "issuer": Is(expected_issuer), "authorization_endpoint": Is(authorization_endpoint), "token_endpoint": Is(token_endpoint), "registration_endpoint": Is(registration_endpoint), diff --git a/tests/server/auth/test_protected_resource.py b/tests/server/auth/test_protected_resource.py index 413a80276..ca2a6e15f 100644 --- a/tests/server/auth/test_protected_resource.py +++ b/tests/server/auth/test_protected_resource.py @@ -96,8 +96,8 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC assert response.status_code == 200 assert response.json() == snapshot( { - "resource": "https://example.com/", - "authorization_servers": ["https://auth.example.com/"], + "resource": "https://example.com", + "authorization_servers": ["https://auth.example.com"], "scopes_supported": ["read"], "resource_name": "Root Resource", "bearer_methods_supported": ["header"], diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 5000c7b38..ebf459539 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -311,7 +311,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert response.status_code == 200 metadata = response.json() - assert metadata["issuer"] == "https://auth.example.com/" + assert metadata["issuer"] == "https://auth.example.com" assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register"