Skip to content

Commit a0bbcdf

Browse files
committed
fix(auth): make get_access_token per-request in stateful sessions
1 parent e8e6484 commit a0bbcdf

3 files changed

Lines changed: 129 additions & 2 deletions

File tree

src/mcp/server/auth/middleware/auth_context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import contextvars
22

3+
from contextvars import Token
4+
5+
from starlette.requests import Request
36
from starlette.types import ASGIApp, Receive, Scope, Send
47

58
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
@@ -20,6 +23,26 @@ def get_access_token() -> AccessToken | None:
2023
return auth_user.access_token if auth_user else None
2124

2225

26+
def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
27+
"""Set auth context for the current task from an incoming request.
28+
29+
This is primarily used by server transports where request handlers may run
30+
in background tasks that are not part of the original ASGI request task.
31+
"""
32+
if request is None:
33+
return None
34+
user = getattr(request, "user", None)
35+
if isinstance(user, AuthenticatedUser):
36+
return auth_context_var.set(user)
37+
return None
38+
39+
40+
def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
41+
if token is None:
42+
return
43+
auth_context_var.reset(token)
44+
45+
2346
class AuthContextMiddleware:
2447
"""Middleware that extracts the authenticated user from the request
2548
and sets it in a contextvar for easy access throughout the request lifecycle.

src/mcp/server/lowlevel/server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def main():
5353
from typing_extensions import TypeVar
5454

5555
from mcp import types
56-
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
56+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, _pop_auth_context, _push_auth_context_from_request
5757
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
5858
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5959
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
@@ -497,7 +497,11 @@ async def _handle_request(
497497
close_sse_stream=close_sse_stream_cb,
498498
close_standalone_sse_stream=close_standalone_sse_stream_cb,
499499
)
500-
response = await handler(ctx, req.params)
500+
auth_token = _push_auth_context_from_request(request_data)
501+
try:
502+
response = await handler(ctx, req.params)
503+
finally:
504+
_pop_auth_context(auth_token)
501505
except MCPError as err:
502506
response = err.error
503507
except anyio.get_cancelled_exc_class():
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import time
2+
3+
import httpx
4+
import pytest
5+
from starlette.applications import Starlette
6+
from starlette.middleware import Middleware
7+
from starlette.middleware.authentication import AuthenticationMiddleware
8+
from starlette.routing import Mount
9+
10+
from mcp import Client
11+
from mcp.client.streamable_http import streamable_http_client
12+
from mcp.server import Server, ServerRequestContext
13+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
14+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
15+
from mcp.server.auth.provider import AccessToken
16+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17+
from mcp.server.transport_security import TransportSecuritySettings
18+
from mcp.types import (
19+
CallToolRequestParams,
20+
CallToolResult,
21+
ListToolsResult,
22+
PaginatedRequestParams,
23+
TextContent,
24+
Tool,
25+
)
26+
27+
28+
class _EchoTokenVerifier:
29+
"""Accepts any bearer token and echoes it back as the verified AccessToken."""
30+
31+
async def verify_token(self, token: str) -> AccessToken | None:
32+
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)
33+
34+
35+
async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
36+
access = get_access_token()
37+
text = access.token if access else "<none>"
38+
return CallToolResult(content=[TextContent(type="text", text=text)])
39+
40+
41+
async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
42+
return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])
43+
44+
45+
class _MutableBearerAuth(httpx.Auth):
46+
def __init__(self, token: str) -> None:
47+
self.token = token
48+
49+
def auth_flow(self, request: httpx.Request):
50+
request.headers["Authorization"] = f"Bearer {self.token}"
51+
yield request
52+
53+
54+
@pytest.mark.anyio
55+
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
56+
host = "testserver"
57+
58+
server = Server(
59+
"auth-test-server",
60+
on_call_tool=_handle_whoami,
61+
on_list_tools=_handle_list_tools,
62+
)
63+
64+
security = TransportSecuritySettings(
65+
allowed_hosts=[host, f"{host}:*"],
66+
allowed_origins=[f"http://{host}:*"],
67+
)
68+
session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False)
69+
70+
asgi_app = Starlette(
71+
routes=[Mount("/mcp", app=session_manager.handle_request)],
72+
middleware=[
73+
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
74+
Middleware(AuthContextMiddleware),
75+
],
76+
lifespan=lambda app: session_manager.run(),
77+
)
78+
79+
auth = _MutableBearerAuth("token-A")
80+
async with asgi_app.router.lifespan_context(asgi_app):
81+
async with (
82+
httpx.ASGITransport(asgi_app) as transport,
83+
httpx.AsyncClient(
84+
transport=transport,
85+
base_url=f"http://{host}",
86+
auth=auth,
87+
timeout=httpx.Timeout(30, read=30),
88+
follow_redirects=True,
89+
) as http_client,
90+
):
91+
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
92+
async with Client(transport_ctx) as client:
93+
r1 = await client.call_tool("whoami", {})
94+
assert isinstance(r1.content[0], TextContent)
95+
assert r1.content[0].text == "token-A"
96+
97+
auth.token = "token-B"
98+
r2 = await client.call_tool("whoami", {})
99+
assert isinstance(r2.content[0], TextContent)
100+
assert r2.content[0].text == "token-B"

0 commit comments

Comments
 (0)