|
11 | 11 | import anyio |
12 | 12 | import httpx |
13 | 13 | from anyio.abc import TaskGroup |
14 | | -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
15 | 14 | from httpx_sse import EventSource, ServerSentEvent, aconnect_sse |
16 | 15 | from pydantic import ValidationError |
17 | 16 |
|
18 | 17 | from mcp.client._transport import TransportStreams |
| 18 | +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams |
19 | 19 | from mcp.shared._httpx_utils import create_mcp_http_client |
20 | 20 | from mcp.shared.message import ClientMessageMetadata, SessionMessage |
21 | 21 | from mcp.types import ( |
|
38 | 38 |
|
39 | 39 | # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. |
40 | 40 | SessionMessageOrError = SessionMessage | Exception |
41 | | -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] |
42 | | -StreamReader = MemoryObjectReceiveStream[SessionMessage] |
| 41 | +StreamWriter = ContextSendStream[SessionMessageOrError] |
| 42 | +StreamReader = ContextReceiveStream[SessionMessage] |
43 | 43 |
|
44 | 44 | MCP_SESSION_ID = "mcp-session-id" |
45 | 45 | MCP_PROTOCOL_VERSION = "mcp-protocol-version" |
@@ -434,14 +434,15 @@ async def post_writer( |
434 | 434 | client: httpx.AsyncClient, |
435 | 435 | write_stream_reader: StreamReader, |
436 | 436 | read_stream_writer: StreamWriter, |
437 | | - write_stream: MemoryObjectSendStream[SessionMessage], |
| 437 | + write_stream: ContextSendStream[SessionMessage], |
438 | 438 | start_get_stream: Callable[[], None], |
439 | 439 | tg: TaskGroup, |
440 | 440 | ) -> None: |
441 | 441 | """Handle writing requests to the server.""" |
442 | 442 | try: |
443 | 443 | async with write_stream_reader: |
444 | | - async for session_message in write_stream_reader: |
| 444 | + |
| 445 | + async def _handle_message(session_message: SessionMessage) -> None: |
445 | 446 | message = session_message.message |
446 | 447 | metadata = ( |
447 | 448 | session_message.metadata |
@@ -478,6 +479,14 @@ async def handle_request_async(): |
478 | 479 | else: |
479 | 480 | await handle_request_async() |
480 | 481 |
|
| 482 | + async for session_message in write_stream_reader: |
| 483 | + sender_ctx = write_stream_reader.last_context |
| 484 | + if sender_ctx is not None: |
| 485 | + async with anyio.create_task_group() as tg_local: |
| 486 | + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) |
| 487 | + else: |
| 488 | + await _handle_message(session_message) # pragma: no cover |
| 489 | + |
481 | 490 | except Exception: # pragma: lax no cover |
482 | 491 | logger.exception("Error in post_writer") |
483 | 492 | finally: |
@@ -533,8 +542,8 @@ async def streamable_http_client( |
533 | 542 | Example: |
534 | 543 | See examples/snippets/clients/ for usage patterns. |
535 | 544 | """ |
536 | | - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) |
537 | | - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) |
| 545 | + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) |
| 546 | + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) |
538 | 547 |
|
539 | 548 | # Determine if we need to create and manage the client |
540 | 549 | client_provided = http_client is not None |
|
0 commit comments