@@ -50,6 +50,7 @@ async def handle_sse(request):
5050from starlette .types import Receive , Scope , Send
5151
5252from mcp import types
53+ from mcp .server .auth .middleware .bearer_auth import AuthenticatedUser , AuthorizationContext , authorization_context
5354from mcp .server .transport_security import (
5455 TransportSecurityMiddleware ,
5556 TransportSecuritySettings ,
@@ -73,6 +74,9 @@ class SseServerTransport:
7374
7475 _endpoint : str
7576 _read_stream_writers : dict [UUID , ContextSendStream [SessionMessage | Exception ]]
77+ # Identity of the credential that created each session; requests for a
78+ # session must present the same credential.
79+ _session_owners : dict [UUID , AuthorizationContext ]
7680 _security : TransportSecurityMiddleware
7781
7882 def __init__ (self , endpoint : str , security_settings : TransportSecuritySettings | None = None ) -> None :
@@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
112116
113117 self ._endpoint = endpoint
114118 self ._read_stream_writers = {}
119+ self ._session_owners = {}
115120 self ._security = TransportSecurityMiddleware (security_settings )
116121 logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
117122
118123 @asynccontextmanager
119124 async def connect_sse (self , scope : Scope , receive : Receive , send : Send ):
120- if scope ["type" ] != "http" : # pragma: no cover
125+ if scope ["type" ] != "http" :
121126 logger .error ("connect_sse received non-HTTP request" )
122127 raise ValueError ("connect_sse can only handle HTTP requests" )
123128
124129 # Validate request headers for DNS rebinding protection
125130 request = Request (scope , receive )
126131 error_response = await self ._security .validate_request (request , is_post = False )
127- if error_response : # pragma: no cover
132+ if error_response :
128133 await error_response (scope , receive , send )
129134 raise ValueError ("Request validation failed" )
130135
@@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
134139 write_stream , write_stream_reader = create_context_streams [SessionMessage ](0 )
135140
136141 session_id = uuid4 ()
142+ user = scope .get ("user" )
143+ if isinstance (user , AuthenticatedUser ):
144+ self ._session_owners [session_id ] = authorization_context (user )
137145 self ._read_stream_writers [session_id ] = read_stream_writer
138146 logger .debug (f"Created new session with ID: { session_id } " )
139147
@@ -169,35 +177,38 @@ async def sse_writer():
169177 }
170178 )
171179
172- async with anyio . create_task_group () as tg :
173-
174- async def response_wrapper ( scope : Scope , receive : Receive , send : Send ):
175- """The EventSourceResponse returning signals a client close / disconnect.
176- In this case we close our side of the streams to signal the client that
177- the connection has been closed.
178- """
179- await EventSourceResponse ( content = sse_stream_reader , data_sender_callable = sse_writer )(
180- scope , receive , send
181- )
182- await sse_stream_reader . aclose ( )
183- await read_stream_writer .aclose ()
184- await write_stream_reader .aclose ()
185- self . _read_stream_writers . pop ( session_id , None )
186- logging .debug (f"Client session disconnected { session_id } " )
180+ try :
181+ async with anyio . create_task_group () as tg :
182+
183+ async def response_wrapper ( scope : Scope , receive : Receive , send : Send ):
184+ """The EventSourceResponse returning signals a client close / disconnect.
185+ In this case we close our side of the streams to signal the client that
186+ the connection has been closed.
187+ """
188+ await EventSourceResponse ( content = sse_stream_reader , data_sender_callable = sse_writer )(
189+ scope , receive , send
190+ )
191+ await read_stream_writer .aclose ()
192+ await write_stream_reader .aclose ()
193+ await sse_stream_reader . aclose ( )
194+ logging .debug (f"Client session disconnected { session_id } " )
187195
188- logger .debug ("Starting SSE response task" )
189- tg .start_soon (response_wrapper , scope , receive , send )
196+ logger .debug ("Starting SSE response task" )
197+ tg .start_soon (response_wrapper , scope , receive , send )
190198
191- logger .debug ("Yielding read and write streams" )
192- yield (read_stream , write_stream )
199+ logger .debug ("Yielding read and write streams" )
200+ yield (read_stream , write_stream )
201+ finally :
202+ self ._read_stream_writers .pop (session_id , None )
203+ self ._session_owners .pop (session_id , None )
193204
194205 async def handle_post_message (self , scope : Scope , receive : Receive , send : Send ) -> None :
195206 logger .debug ("Handling POST message" )
196207 request = Request (scope , receive )
197208
198209 # Validate request headers for DNS rebinding protection
199210 error_response = await self ._security .validate_request (request , is_post = True )
200- if error_response : # pragma: no cover
211+ if error_response :
201212 return await error_response (scope , receive , send )
202213
203214 session_id_param = request .query_params .get ("session_id" )
@@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
220231 response = Response ("Could not find session" , status_code = 404 )
221232 return await response (scope , receive , send )
222233
234+ user = scope .get ("user" )
235+ requestor = authorization_context (user ) if isinstance (user , AuthenticatedUser ) else None
236+ if requestor != self ._session_owners .get (session_id ):
237+ # A session can only be used with the credential that created it.
238+ # Respond exactly as if the session did not exist.
239+ logger .warning ("Rejecting message for session %s: credential does not match" , session_id )
240+ response = Response ("Could not find session" , status_code = 404 )
241+ return await response (scope , receive , send )
242+
223243 body = await request .body ()
224244 logger .debug (f"Received JSON: { body } " )
225245
226246 try :
227247 message = types .jsonrpc_message_adapter .validate_json (body , by_name = False )
228248 logger .debug (f"Validated client message: { message } " )
229- except ValidationError as err : # pragma: no cover
249+ except ValidationError as err :
230250 logger .exception ("Failed to parse message" )
231251 response = Response ("Could not parse message" , status_code = 400 )
232252 await response (scope , receive , send )
0 commit comments