22import multiprocessing
33import socket
44from collections .abc import AsyncGenerator , Generator
5- from typing import Any
5+ from typing import Any , cast
66from unittest .mock import AsyncMock , MagicMock , Mock , patch
77from urllib .parse import urlparse
88
2020import mcp .client .sse
2121from mcp import types
2222from mcp .client .session import ClientSession
23- from mcp .client .sse import _extract_session_id_from_endpoint , sse_client
23+ from mcp .client .sse import _extract_session_id_from_endpoint , _resolve_endpoint_url , sse_client
2424from mcp .server import Server , ServerRequestContext
2525from mcp .server .sse import SseServerTransport
2626from mcp .server .transport_security import TransportSecuritySettings
@@ -229,6 +229,50 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229229 assert _extract_session_id_from_endpoint (endpoint_url ) == expected
230230
231231
232+ @pytest .mark .parametrize (
233+ ("sse_url" , "endpoint_data" , "messages_url" , "expected" ),
234+ [
235+ (
236+ "https://example.com/api/v1/sse" ,
237+ "/v1/messages/?session_id=abc123" ,
238+ None ,
239+ "https://example.com/v1/messages/?session_id=abc123" ,
240+ ),
241+ (
242+ "https://example.com/api/v1/sse" ,
243+ "/v1/messages/?session_id=abc123" ,
244+ "https://example.com/api/v1/messages/" ,
245+ "https://example.com/api/v1/messages/?session_id=abc123" ,
246+ ),
247+ (
248+ "https://example.com/api/v1/sse" ,
249+ "/v1/messages/?session_id=abc123" ,
250+ "/api/v1/messages/" ,
251+ "https://example.com/api/v1/messages/?session_id=abc123" ,
252+ ),
253+ (
254+ "https://example.com/api/v1/sse" ,
255+ "/v1/messages/?session_id=abc123" ,
256+ "https://example.com/api/v1/messages/?tenant=blue" ,
257+ "https://example.com/api/v1/messages/?tenant=blue&session_id=abc123" ,
258+ ),
259+ (
260+ "https://example.com/api/v1/sse" ,
261+ "/v1/messages/" ,
262+ "https://example.com/api/v1/messages/" ,
263+ "https://example.com/api/v1/messages/" ,
264+ ),
265+ ],
266+ )
267+ def test_resolve_endpoint_url_with_messages_url_override (
268+ sse_url : str ,
269+ endpoint_data : str ,
270+ messages_url : str | None ,
271+ expected : str ,
272+ ) -> None :
273+ assert _resolve_endpoint_url (sse_url , endpoint_data , messages_url ) == expected
274+
275+
232276@pytest .mark .anyio
233277async def test_sse_client_on_session_created_not_called_when_no_session_id (
234278 server : None , server_url : str , monkeypatch : pytest .MonkeyPatch
@@ -249,6 +293,60 @@ def mock_extract(url: str) -> None:
249293 callback_mock .assert_not_called ()
250294
251295
296+ @pytest .mark .anyio
297+ async def test_sse_client_uses_messages_url_override () -> None :
298+ init_result = InitializeResult (
299+ protocol_version = "2024-11-05" ,
300+ capabilities = ServerCapabilities (),
301+ server_info = Implementation (name = "test" , version = "1.0" ),
302+ )
303+ response = JSONRPCResponse (
304+ jsonrpc = "2.0" ,
305+ id = 0 ,
306+ result = init_result .model_dump (by_alias = True , exclude_none = True ),
307+ )
308+ response_json = response .model_dump_json (by_alias = True , exclude_none = True )
309+
310+ async def mock_aiter_sse () -> AsyncGenerator [ServerSentEvent , None ]:
311+ yield ServerSentEvent (event = "endpoint" , data = "/v1/messages/?session_id=abc123" )
312+ yield ServerSentEvent (event = "message" , data = response_json )
313+ await anyio .sleep_forever ()
314+
315+ mock_event_source = MagicMock ()
316+ mock_event_source .aiter_sse .return_value = mock_aiter_sse ()
317+ mock_event_source .response = MagicMock ()
318+ mock_event_source .response .raise_for_status = MagicMock ()
319+
320+ mock_aconnect_sse = MagicMock ()
321+ mock_aconnect_sse .__aenter__ = AsyncMock (return_value = mock_event_source )
322+ mock_aconnect_sse .__aexit__ = AsyncMock (return_value = None )
323+
324+ mock_client = MagicMock ()
325+ mock_client .__aenter__ = AsyncMock (return_value = mock_client )
326+ mock_client .__aexit__ = AsyncMock (return_value = None )
327+ mock_client .post = AsyncMock (return_value = MagicMock (status_code = 200 , raise_for_status = MagicMock ()))
328+
329+ def mock_httpx_client_factory (
330+ headers : dict [str , str ] | None = None ,
331+ timeout : httpx .Timeout | None = None ,
332+ auth : httpx .Auth | None = None ,
333+ ) -> httpx .AsyncClient :
334+ _ = (headers , timeout , auth )
335+ return cast (httpx .AsyncClient , mock_client )
336+
337+ with patch ("mcp.client.sse.aconnect_sse" , return_value = mock_aconnect_sse ):
338+ async with sse_client (
339+ "https://example.com/api/v1/sse" ,
340+ httpx_client_factory = mock_httpx_client_factory ,
341+ messages_url = "https://example.com/api/v1/messages/" ,
342+ ) as streams :
343+ async with ClientSession (* streams ) as session :
344+ await session .initialize ()
345+
346+ mock_client .post .assert_awaited ()
347+ assert mock_client .post .await_args .args [0 ] == "https://example.com/api/v1/messages/?session_id=abc123"
348+
349+
252350@pytest .fixture
253351async def initialized_sse_client_session (server : None , server_url : str ) -> AsyncGenerator [ClientSession , None ]:
254352 async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
0 commit comments