Skip to content

Commit 632ef64

Browse files
author
Henry Lee
committed
feat: allow overriding SSE messages endpoint
1 parent 3d7b311 commit 632ef64

2 files changed

Lines changed: 110 additions & 3 deletions

File tree

src/mcp/client/sse.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
2727
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
2828

2929

30+
def _resolve_endpoint_url(sse_url: str, endpoint_data: str, messages_url: str | None = None) -> str:
31+
if messages_url is None:
32+
return urljoin(sse_url, endpoint_data)
33+
34+
endpoint_url = urljoin(sse_url, messages_url)
35+
endpoint_query = urlparse(endpoint_data).query
36+
if endpoint_query:
37+
endpoint_parsed = urlparse(endpoint_url)
38+
query = "&".join(filter(None, [endpoint_parsed.query, endpoint_query]))
39+
endpoint_url = endpoint_parsed._replace(query=query).geturl()
40+
41+
return endpoint_url
42+
43+
3044
@asynccontextmanager
3145
async def sse_client(
3246
url: str,
@@ -36,6 +50,7 @@ async def sse_client(
3650
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3751
auth: httpx.Auth | None = None,
3852
on_session_created: Callable[[str], None] | None = None,
53+
messages_url: str | None = None,
3954
):
4055
"""Client transport for SSE.
4156
@@ -50,6 +65,9 @@ async def sse_client(
5065
httpx_client_factory: Factory function for creating the HTTPX client.
5166
auth: Optional HTTPX authentication handler.
5267
on_session_created: Optional callback invoked with the session ID when received.
68+
messages_url: Optional message endpoint URL to use instead of deriving it
69+
from the SSE endpoint event. Relative URLs are resolved against `url`,
70+
and any session query parameters from the endpoint event are preserved.
5371
"""
5472
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
5573
async with httpx_client_factory(
@@ -68,7 +86,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
6886
logger.debug(f"Received SSE event: {sse.event}")
6987
match sse.event:
7088
case "endpoint":
71-
endpoint_url = urljoin(url, sse.data)
89+
endpoint_url = _resolve_endpoint_url(url, sse.data, messages_url)
7290
logger.debug(f"Received endpoint URL: {endpoint_url}")
7391

7492
url_parsed = urlparse(url)

tests/shared/test_sse.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import multiprocessing
33
import socket
44
from collections.abc import AsyncGenerator, Generator
5-
from typing import Any
5+
from typing import Any, cast
66
from unittest.mock import AsyncMock, MagicMock, Mock, patch
77
from urllib.parse import urlparse
88

@@ -20,11 +20,12 @@
2020
import mcp.client.sse
2121
from mcp import types
2222
from mcp.client.session import ClientSession
23-
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
23+
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
2424
from mcp.server import Server, ServerRequestContext
2525
from mcp.server.sse import SseServerTransport
2626
from mcp.server.transport_security import TransportSecuritySettings
2727
from mcp.shared.exceptions import MCPError
28+
from mcp.shared.message import SessionMessage
2829
from mcp.types import (
2930
CallToolRequestParams,
3031
CallToolResult,
@@ -229,6 +230,50 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229230
assert _extract_session_id_from_endpoint(endpoint_url) == expected
230231

231232

233+
@pytest.mark.parametrize(
234+
("sse_url", "endpoint_data", "messages_url", "expected"),
235+
[
236+
(
237+
"https://example.com/api/v1/sse",
238+
"/v1/messages/?session_id=abc123",
239+
None,
240+
"https://example.com/v1/messages/?session_id=abc123",
241+
),
242+
(
243+
"https://example.com/api/v1/sse",
244+
"/v1/messages/?session_id=abc123",
245+
"https://example.com/api/v1/messages/",
246+
"https://example.com/api/v1/messages/?session_id=abc123",
247+
),
248+
(
249+
"https://example.com/api/v1/sse",
250+
"/v1/messages/?session_id=abc123",
251+
"/api/v1/messages/",
252+
"https://example.com/api/v1/messages/?session_id=abc123",
253+
),
254+
(
255+
"https://example.com/api/v1/sse",
256+
"/v1/messages/?session_id=abc123",
257+
"https://example.com/api/v1/messages/?tenant=blue",
258+
"https://example.com/api/v1/messages/?tenant=blue&session_id=abc123",
259+
),
260+
(
261+
"https://example.com/api/v1/sse",
262+
"/v1/messages/",
263+
"https://example.com/api/v1/messages/",
264+
"https://example.com/api/v1/messages/",
265+
),
266+
],
267+
)
268+
def test_resolve_endpoint_url_with_messages_url_override(
269+
sse_url: str,
270+
endpoint_data: str,
271+
messages_url: str | None,
272+
expected: str,
273+
) -> None:
274+
assert _resolve_endpoint_url(sse_url, endpoint_data, messages_url) == expected
275+
276+
232277
@pytest.mark.anyio
233278
async def test_sse_client_on_session_created_not_called_when_no_session_id(
234279
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
@@ -249,6 +294,50 @@ def mock_extract(url: str) -> None:
249294
callback_mock.assert_not_called()
250295

251296

297+
@pytest.mark.anyio
298+
async def test_sse_client_uses_messages_url_override() -> None:
299+
async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
300+
yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123")
301+
await anyio.sleep_forever()
302+
303+
mock_event_source = MagicMock()
304+
mock_event_source.aiter_sse.return_value = mock_aiter_sse()
305+
mock_event_source.response = MagicMock()
306+
mock_event_source.response.raise_for_status = MagicMock()
307+
308+
mock_aconnect_sse = MagicMock()
309+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
310+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
311+
312+
mock_client = MagicMock()
313+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
314+
mock_client.__aexit__ = AsyncMock(return_value=None)
315+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
316+
317+
def mock_httpx_client_factory(
318+
headers: dict[str, str] | None = None,
319+
timeout: httpx.Timeout | None = None,
320+
auth: httpx.Auth | None = None,
321+
) -> httpx.AsyncClient:
322+
_ = (headers, timeout, auth)
323+
return cast(httpx.AsyncClient, mock_client)
324+
325+
with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse):
326+
async with sse_client(
327+
"https://example.com/api/v1/sse",
328+
httpx_client_factory=mock_httpx_client_factory,
329+
messages_url="https://example.com/api/v1/messages/",
330+
) as (_, write_stream):
331+
message = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
332+
await write_stream.send(SessionMessage(message))
333+
with anyio.fail_after(1): # pragma: no branch
334+
while not mock_client.post.await_count:
335+
await anyio.sleep(0.01)
336+
337+
mock_client.post.assert_awaited()
338+
assert mock_client.post.await_args.args[0] == "https://example.com/api/v1/messages/?session_id=abc123"
339+
340+
252341
@pytest.fixture
253342
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
254343
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:

0 commit comments

Comments
 (0)