|
| 1 | +import time |
| 2 | + |
| 3 | +import httpx |
| 4 | +import pytest |
| 5 | +from starlette.applications import Starlette |
| 6 | +from starlette.middleware import Middleware |
| 7 | +from starlette.middleware.authentication import AuthenticationMiddleware |
| 8 | +from starlette.routing import Mount |
| 9 | + |
| 10 | +from mcp import Client |
| 11 | +from mcp.client.streamable_http import streamable_http_client |
| 12 | +from mcp.server import Server, ServerRequestContext |
| 13 | +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token |
| 14 | +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend |
| 15 | +from mcp.server.auth.provider import AccessToken |
| 16 | +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
| 17 | +from mcp.server.transport_security import TransportSecuritySettings |
| 18 | +from mcp.types import ( |
| 19 | + CallToolRequestParams, |
| 20 | + CallToolResult, |
| 21 | + ListToolsResult, |
| 22 | + PaginatedRequestParams, |
| 23 | + TextContent, |
| 24 | + Tool, |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +class _EchoTokenVerifier: |
| 29 | + """Accepts any bearer token and echoes it back as the verified AccessToken.""" |
| 30 | + |
| 31 | + async def verify_token(self, token: str) -> AccessToken | None: |
| 32 | + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) |
| 33 | + |
| 34 | + |
| 35 | +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: |
| 36 | + access = get_access_token() |
| 37 | + text = access.token if access else "<none>" |
| 38 | + return CallToolResult(content=[TextContent(type="text", text=text)]) |
| 39 | + |
| 40 | + |
| 41 | +async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: |
| 42 | + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) |
| 43 | + |
| 44 | + |
| 45 | +class _MutableBearerAuth(httpx.Auth): |
| 46 | + def __init__(self, token: str) -> None: |
| 47 | + self.token = token |
| 48 | + |
| 49 | + def auth_flow(self, request: httpx.Request): |
| 50 | + request.headers["Authorization"] = f"Bearer {self.token}" |
| 51 | + yield request |
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.anyio |
| 55 | +async def test_get_access_token_reflects_current_request_in_stateful_session() -> None: |
| 56 | + host = "testserver" |
| 57 | + |
| 58 | + server = Server( |
| 59 | + "auth-test-server", |
| 60 | + on_call_tool=_handle_whoami, |
| 61 | + on_list_tools=_handle_list_tools, |
| 62 | + ) |
| 63 | + |
| 64 | + security = TransportSecuritySettings( |
| 65 | + allowed_hosts=[host, f"{host}:*"], |
| 66 | + allowed_origins=[f"http://{host}:*"], |
| 67 | + ) |
| 68 | + session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False) |
| 69 | + |
| 70 | + asgi_app = Starlette( |
| 71 | + routes=[Mount("/mcp", app=session_manager.handle_request)], |
| 72 | + middleware=[ |
| 73 | + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), |
| 74 | + Middleware(AuthContextMiddleware), |
| 75 | + ], |
| 76 | + lifespan=lambda app: session_manager.run(), |
| 77 | + ) |
| 78 | + |
| 79 | + auth = _MutableBearerAuth("token-A") |
| 80 | + async with asgi_app.router.lifespan_context(asgi_app): |
| 81 | + async with ( |
| 82 | + httpx.ASGITransport(asgi_app) as transport, |
| 83 | + httpx.AsyncClient( |
| 84 | + transport=transport, |
| 85 | + base_url=f"http://{host}", |
| 86 | + auth=auth, |
| 87 | + timeout=httpx.Timeout(30, read=30), |
| 88 | + follow_redirects=True, |
| 89 | + ) as http_client, |
| 90 | + ): |
| 91 | + transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) |
| 92 | + async with Client(transport_ctx) as client: |
| 93 | + r1 = await client.call_tool("whoami", {}) |
| 94 | + assert isinstance(r1.content[0], TextContent) |
| 95 | + assert r1.content[0].text == "token-A" |
| 96 | + |
| 97 | + auth.token = "token-B" |
| 98 | + r2 = await client.call_tool("whoami", {}) |
| 99 | + assert isinstance(r2.content[0], TextContent) |
| 100 | + assert r2.content[0].text == "token-B" |
0 commit comments