@@ -57,108 +57,101 @@ async def sse_client(
5757 write_stream : MemoryObjectSendStream [SessionMessage ]
5858 write_stream_reader : MemoryObjectReceiveStream [SessionMessage ]
5959
60- read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
61- write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
62-
63- async with anyio .create_task_group () as tg :
64- try :
65- logger .debug (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
66- async with httpx_client_factory (
67- headers = headers , auth = auth , timeout = httpx .Timeout (timeout , read = sse_read_timeout )
68- ) as client :
69- async with aconnect_sse (
70- client ,
71- "GET" ,
72- url ,
73- ) as event_source :
74- event_source .response .raise_for_status ()
75- logger .debug ("SSE connection established" )
76-
77- async def sse_reader (task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ):
78- try :
79- async for sse in event_source .aiter_sse (): # pragma: no branch
80- logger .debug (f"Received SSE event: { sse .event } " )
81- match sse .event :
82- case "endpoint" :
83- endpoint_url = urljoin (url , sse .data )
84- logger .debug (f"Received endpoint URL: { endpoint_url } " )
85-
86- url_parsed = urlparse (url )
87- endpoint_parsed = urlparse (endpoint_url )
88- if ( # pragma: no cover
89- url_parsed .netloc != endpoint_parsed .netloc
90- or url_parsed .scheme != endpoint_parsed .scheme
91- ):
92- error_msg = ( # pragma: no cover
93- f"Endpoint origin does not match connection origin: { endpoint_url } "
94- )
95- logger .error (error_msg ) # pragma: no cover
96- raise ValueError (error_msg ) # pragma: no cover
97-
98- if on_session_created :
99- session_id = _extract_session_id_from_endpoint (endpoint_url )
100- if session_id :
101- on_session_created (session_id )
102-
103- task_status .started (endpoint_url )
104-
105- case "message" :
106- # Skip empty data (keep-alive pings)
107- if not sse .data :
108- continue
109- try :
110- message = types .jsonrpc_message_adapter .validate_json (
111- sse .data , by_name = False
112- )
113- logger .debug (f"Received server message: { message } " )
114- except Exception as exc : # pragma: no cover
115- logger .exception ("Error parsing server message" ) # pragma: no cover
116- await read_stream_writer .send (exc ) # pragma: no cover
117- continue # pragma: no cover
118-
119- session_message = SessionMessage (message )
120- await read_stream_writer .send (session_message )
121- case _: # pragma: no cover
122- logger .warning (f"Unknown SSE event: { sse .event } " ) # pragma: no cover
123- except SSEError as sse_exc : # pragma: lax no cover
124- logger .exception ("Encountered SSE exception" )
125- raise sse_exc
126- except Exception as exc : # pragma: lax no cover
127- logger .exception ("Error in sse_reader" )
128- await read_stream_writer .send (exc )
129- finally :
130- await read_stream_writer .aclose ()
131-
132- async def post_writer (endpoint_url : str ):
133- try :
134- async with write_stream_reader :
135- async for session_message in write_stream_reader :
136- logger .debug (f"Sending client message: { session_message } " )
137- response = await client .post (
138- endpoint_url ,
139- json = session_message .message .model_dump (
140- by_alias = True ,
141- mode = "json" ,
142- exclude_unset = True ,
143- ),
60+ logger .debug (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
61+ async with httpx_client_factory (
62+ headers = headers , auth = auth , timeout = httpx .Timeout (timeout , read = sse_read_timeout )
63+ ) as client :
64+ async with aconnect_sse (client , "GET" , url ) as event_source :
65+ event_source .response .raise_for_status ()
66+ logger .debug ("SSE connection established" )
67+
68+ read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
69+ write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
70+
71+ async def sse_reader (task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ):
72+ try :
73+ async for sse in event_source .aiter_sse (): # pragma: no branch
74+ logger .debug (f"Received SSE event: { sse .event } " )
75+ match sse .event :
76+ case "endpoint" :
77+ endpoint_url = urljoin (url , sse .data )
78+ logger .debug (f"Received endpoint URL: { endpoint_url } " )
79+
80+ url_parsed = urlparse (url )
81+ endpoint_parsed = urlparse (endpoint_url )
82+ if ( # pragma: no cover
83+ url_parsed .netloc != endpoint_parsed .netloc
84+ or url_parsed .scheme != endpoint_parsed .scheme
85+ ):
86+ error_msg = ( # pragma: no cover
87+ f"Endpoint origin does not match connection origin: { endpoint_url } "
14488 )
145- response .raise_for_status ()
146- logger .debug (f"Client message sent successfully: { response .status_code } " )
147- except Exception : # pragma: lax no cover
148- logger .exception ("Error in post_writer" )
149- finally :
150- await write_stream .aclose ()
151-
152- endpoint_url = await tg .start (sse_reader )
153- logger .debug (f"Starting post writer with endpoint URL: { endpoint_url } " )
154- tg .start_soon (post_writer , endpoint_url )
155-
156- try :
157- yield read_stream , write_stream
158- finally :
159- tg .cancel_scope .cancel ()
160- finally :
161- await read_stream_writer .aclose ()
162- await write_stream .aclose ()
163- await read_stream .aclose ()
164- await write_stream_reader .aclose ()
89+ logger .error (error_msg ) # pragma: no cover
90+ raise ValueError (error_msg ) # pragma: no cover
91+
92+ if on_session_created :
93+ session_id = _extract_session_id_from_endpoint (endpoint_url )
94+ if session_id :
95+ on_session_created (session_id )
96+
97+ task_status .started (endpoint_url )
98+
99+ case "message" :
100+ # Skip empty data (keep-alive pings)
101+ if not sse .data :
102+ continue
103+ try :
104+ message = types .jsonrpc_message_adapter .validate_json (sse .data , by_name = False )
105+ logger .debug (f"Received server message: { message } " )
106+ except Exception as exc : # pragma: no cover
107+ logger .exception ("Error parsing server message" ) # pragma: no cover
108+ await read_stream_writer .send (exc ) # pragma: no cover
109+ continue # pragma: no cover
110+
111+ session_message = SessionMessage (message )
112+ await read_stream_writer .send (session_message )
113+ case _: # pragma: no cover
114+ logger .warning (f"Unknown SSE event: { sse .event } " ) # pragma: no cover
115+ except SSEError as sse_exc : # pragma: lax no cover
116+ logger .exception ("Encountered SSE exception" )
117+ raise sse_exc
118+ except Exception as exc : # pragma: lax no cover
119+ logger .exception ("Error in sse_reader" )
120+ await read_stream_writer .send (exc )
121+ finally :
122+ await read_stream_writer .aclose ()
123+
124+ async def post_writer (endpoint_url : str ):
125+ try :
126+ async with write_stream_reader , write_stream :
127+ async for session_message in write_stream_reader :
128+ logger .debug (f"Sending client message: { session_message } " )
129+ response = await client .post (
130+ endpoint_url ,
131+ json = session_message .message .model_dump (
132+ by_alias = True ,
133+ mode = "json" ,
134+ exclude_unset = True ,
135+ ),
136+ )
137+ response .raise_for_status ()
138+ logger .debug (f"Client message sent successfully: { response .status_code } " )
139+ except Exception : # pragma: lax no cover
140+ logger .exception ("Error in post_writer" )
141+
142+ # On Python 3.14, coverage.py reports a phantom branch arc on this
143+ # line (->yield) when nested two async-with levels deep. The branch
144+ # is the unreachable "did __aexit__ suppress?" arm for memory streams.
145+ async with ( # pragma: no branch
146+ read_stream_writer ,
147+ read_stream ,
148+ write_stream ,
149+ write_stream_reader ,
150+ anyio .create_task_group () as tg ,
151+ ):
152+ endpoint_url = await tg .start (sse_reader )
153+ logger .debug (f"Starting post writer with endpoint URL: { endpoint_url } " )
154+ tg .start_soon (post_writer , endpoint_url )
155+
156+ yield read_stream , write_stream
157+ tg .cancel_scope .cancel ()
0 commit comments