Skip to content

Commit bbf6277

Browse files
author
Henry Lee
committed
feat: allow overriding SSE messages endpoint
1 parent 3d7b311 commit bbf6277

2 files changed

Lines changed: 119 additions & 3 deletions

File tree

src/mcp/client/sse.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
2727
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
2828

2929

30+
def _resolve_endpoint_url(sse_url: str, endpoint_data: str, messages_url: str | None = None) -> str:
31+
if messages_url is None:
32+
return urljoin(sse_url, endpoint_data)
33+
34+
endpoint_url = urljoin(sse_url, messages_url)
35+
endpoint_query = urlparse(endpoint_data).query
36+
if endpoint_query:
37+
endpoint_parsed = urlparse(endpoint_url)
38+
query = "&".join(filter(None, [endpoint_parsed.query, endpoint_query]))
39+
endpoint_url = endpoint_parsed._replace(query=query).geturl()
40+
41+
return endpoint_url
42+
43+
3044
@asynccontextmanager
3145
async def sse_client(
3246
url: str,
@@ -36,6 +50,7 @@ async def sse_client(
3650
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3751
auth: httpx.Auth | None = None,
3852
on_session_created: Callable[[str], None] | None = None,
53+
messages_url: str | None = None,
3954
):
4055
"""Client transport for SSE.
4156
@@ -50,6 +65,9 @@ async def sse_client(
5065
httpx_client_factory: Factory function for creating the HTTPX client.
5166
auth: Optional HTTPX authentication handler.
5267
on_session_created: Optional callback invoked with the session ID when received.
68+
messages_url: Optional message endpoint URL to use instead of deriving it
69+
from the SSE endpoint event. Relative URLs are resolved against `url`,
70+
and any session query parameters from the endpoint event are preserved.
5371
"""
5472
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
5573
async with httpx_client_factory(
@@ -68,7 +86,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
6886
logger.debug(f"Received SSE event: {sse.event}")
6987
match sse.event:
7088
case "endpoint":
71-
endpoint_url = urljoin(url, sse.data)
89+
endpoint_url = _resolve_endpoint_url(url, sse.data, messages_url)
7290
logger.debug(f"Received endpoint URL: {endpoint_url}")
7391

7492
url_parsed = urlparse(url)

tests/shared/test_sse.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import multiprocessing
33
import socket
44
from collections.abc import AsyncGenerator, Generator
5-
from typing import Any
5+
from typing import Any, cast
66
from unittest.mock import AsyncMock, MagicMock, Mock, patch
77
from urllib.parse import urlparse
88

@@ -20,7 +20,7 @@
2020
import mcp.client.sse
2121
from mcp import types
2222
from mcp.client.session import ClientSession
23-
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
23+
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
2424
from mcp.server import Server, ServerRequestContext
2525
from mcp.server.sse import SseServerTransport
2626
from mcp.server.transport_security import TransportSecuritySettings
@@ -229,6 +229,50 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229229
assert _extract_session_id_from_endpoint(endpoint_url) == expected
230230

231231

232+
@pytest.mark.parametrize(
233+
("sse_url", "endpoint_data", "messages_url", "expected"),
234+
[
235+
(
236+
"https://example.com/api/v1/sse",
237+
"/v1/messages/?session_id=abc123",
238+
None,
239+
"https://example.com/v1/messages/?session_id=abc123",
240+
),
241+
(
242+
"https://example.com/api/v1/sse",
243+
"/v1/messages/?session_id=abc123",
244+
"https://example.com/api/v1/messages/",
245+
"https://example.com/api/v1/messages/?session_id=abc123",
246+
),
247+
(
248+
"https://example.com/api/v1/sse",
249+
"/v1/messages/?session_id=abc123",
250+
"/api/v1/messages/",
251+
"https://example.com/api/v1/messages/?session_id=abc123",
252+
),
253+
(
254+
"https://example.com/api/v1/sse",
255+
"/v1/messages/?session_id=abc123",
256+
"https://example.com/api/v1/messages/?tenant=blue",
257+
"https://example.com/api/v1/messages/?tenant=blue&session_id=abc123",
258+
),
259+
(
260+
"https://example.com/api/v1/sse",
261+
"/v1/messages/",
262+
"https://example.com/api/v1/messages/",
263+
"https://example.com/api/v1/messages/",
264+
),
265+
],
266+
)
267+
def test_resolve_endpoint_url_with_messages_url_override(
268+
sse_url: str,
269+
endpoint_data: str,
270+
messages_url: str | None,
271+
expected: str,
272+
) -> None:
273+
assert _resolve_endpoint_url(sse_url, endpoint_data, messages_url) == expected
274+
275+
232276
@pytest.mark.anyio
233277
async def test_sse_client_on_session_created_not_called_when_no_session_id(
234278
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
@@ -249,6 +293,60 @@ def mock_extract(url: str) -> None:
249293
callback_mock.assert_not_called()
250294

251295

296+
@pytest.mark.anyio
297+
async def test_sse_client_uses_messages_url_override() -> None:
298+
init_result = InitializeResult(
299+
protocol_version="2024-11-05",
300+
capabilities=ServerCapabilities(),
301+
server_info=Implementation(name="test", version="1.0"),
302+
)
303+
response = JSONRPCResponse(
304+
jsonrpc="2.0",
305+
id=0,
306+
result=init_result.model_dump(by_alias=True, exclude_none=True),
307+
)
308+
response_json = response.model_dump_json(by_alias=True, exclude_none=True)
309+
310+
async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
311+
yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123")
312+
yield ServerSentEvent(event="message", data=response_json)
313+
await anyio.sleep_forever()
314+
315+
mock_event_source = MagicMock()
316+
mock_event_source.aiter_sse.return_value = mock_aiter_sse()
317+
mock_event_source.response = MagicMock()
318+
mock_event_source.response.raise_for_status = MagicMock()
319+
320+
mock_aconnect_sse = MagicMock()
321+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
322+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
323+
324+
mock_client = MagicMock()
325+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
326+
mock_client.__aexit__ = AsyncMock(return_value=None)
327+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
328+
329+
def mock_httpx_client_factory(
330+
headers: dict[str, str] | None = None,
331+
timeout: httpx.Timeout | None = None,
332+
auth: httpx.Auth | None = None,
333+
) -> httpx.AsyncClient:
334+
_ = (headers, timeout, auth)
335+
return cast(httpx.AsyncClient, mock_client)
336+
337+
with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse):
338+
async with sse_client(
339+
"https://example.com/api/v1/sse",
340+
httpx_client_factory=mock_httpx_client_factory,
341+
messages_url="https://example.com/api/v1/messages/",
342+
) as streams:
343+
async with ClientSession(*streams) as session:
344+
await session.initialize()
345+
346+
mock_client.post.assert_awaited()
347+
assert mock_client.post.await_args.args[0] == "https://example.com/api/v1/messages/?session_id=abc123"
348+
349+
252350
@pytest.fixture
253351
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
254352
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:

0 commit comments

Comments
 (0)