diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9c613398..0f7446e4 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -265,3 +265,15 @@ path = "tests/test_sse_concurrent_streams.rs" name = "test_client_credentials" required-features = ["auth"] path = "tests/test_client_credentials.rs" + +[[test]] +name = "test_streamable_http_stale_session" +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "transport-streamable-http-client", + "transport-streamable-http-client-reqwest" +] +path = "tests/test_streamable_http_stale_session.rs" + diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index bbb98bf3..9a27b493 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -600,48 +600,51 @@ impl Worker for StreamableHttpClientWorker { .await; let send_result = match response { Err(StreamableHttpError::SessionExpired) => { - // The server discarded the session (HTTP 404). Perform a - // fresh handshake once and replay the original message. - tracing::info!( - "session expired (HTTP 404), attempting transparent re-initialization" - ); - match Self::perform_reinitialization( - self.client.clone(), - saved_init_request.clone(), - config.uri.clone(), - config.auth_header.clone(), - config.custom_headers.clone(), - ) - .await - { - Ok((new_session_id, new_protocol_headers)) => { - // Old streams hold the stale session ID; abort them - // so the new standalone SSE stream takes over. - streams.abort_all(); + if !config.reinit_on_expired_session { + Err(StreamableHttpError::SessionExpired) + } else { + // The server discarded the session (HTTP 404). Perform a + // fresh handshake once and replay the original message. + tracing::info!( + "session expired (HTTP 404), attempting transparent re-initialization" + ); + match Self::perform_reinitialization( + self.client.clone(), + saved_init_request.clone(), + config.uri.clone(), + config.auth_header.clone(), + config.custom_headers.clone(), + ) + .await + { + Ok((new_session_id, new_protocol_headers)) => { + // Old streams hold the stale session ID; abort them + // so the new standalone SSE stream takes over. + streams.abort_all(); - session_id = new_session_id; - protocol_headers = new_protocol_headers; - session_cleanup_info = - session_id.as_ref().map(|sid| SessionCleanupInfo { - client: self.client.clone(), - uri: config.uri.clone(), - session_id: sid.clone(), - auth_header: config.auth_header.clone(), - protocol_headers: protocol_headers.clone(), - }); + session_id = new_session_id; + protocol_headers = new_protocol_headers; + session_cleanup_info = + session_id.as_ref().map(|sid| SessionCleanupInfo { + client: self.client.clone(), + uri: config.uri.clone(), + session_id: sid.clone(), + auth_header: config.auth_header.clone(), + protocol_headers: protocol_headers.clone(), + }); - if let Some(new_sid) = &session_id { - let client = self.client.clone(); - let uri = config.uri.clone(); - let new_sid = new_sid.clone(); - let auth_header = config.auth_header.clone(); - let retry_config = self.config.retry_config.clone(); - let sse_tx = sse_worker_tx.clone(); - let task_ct = transport_task_ct.clone(); - let config_uri = config.uri.clone(); - let config_auth = config.auth_header.clone(); - let spawn_headers = protocol_headers.clone(); - streams.spawn(async move { + if let Some(new_sid) = &session_id { + let client = self.client.clone(); + let uri = config.uri.clone(); + let new_sid = new_sid.clone(); + let auth_header = config.auth_header.clone(); + let retry_config = self.config.retry_config.clone(); + let sse_tx = sse_worker_tx.clone(); + let task_ct = transport_task_ct.clone(); + let config_uri = config.uri.clone(); + let config_auth = config.auth_header.clone(); + let spawn_headers = protocol_headers.clone(); + streams.spawn(async move { match client .get_stream( uri, @@ -686,69 +689,71 @@ impl Worker for StreamableHttpClientWorker { } } }); - } - - let retry_response = self - .client - .post_message( - config.uri.clone(), - message, - session_id.clone(), - config.auth_header.clone(), - protocol_headers.clone(), - ) - .await; - match retry_response { - Err(e) => Err(e), - Ok(StreamableHttpPostResponse::Accepted) => { - tracing::trace!( - "client message accepted after re-init" - ); - Ok(()) - } - Ok(StreamableHttpPostResponse::Json(msg, ..)) => { - context.send_to_handler(msg).await?; - Ok(()) } - Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { - if let Some(sid) = &session_id { - let sse_stream = SseAutoReconnectStream::new( - stream, - StreamableHttpClientReconnect { - client: self.client.clone(), - session_id: sid.clone(), - uri: config.uri.clone(), - auth_header: config.auth_header.clone(), - custom_headers: protocol_headers.clone(), - }, - self.config.retry_config.clone(), + + let retry_response = self + .client + .post_message( + config.uri.clone(), + message, + session_id.clone(), + config.auth_header.clone(), + protocol_headers.clone(), + ) + .await; + match retry_response { + Err(e) => Err(e), + Ok(StreamableHttpPostResponse::Accepted) => { + tracing::trace!( + "client message accepted after re-init" ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); - } else { - let sse_stream = + Ok(()) + } + Ok(StreamableHttpPostResponse::Json(msg, ..)) => { + context.send_to_handler(msg).await?; + Ok(()) + } + Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { + if let Some(sid) = &session_id { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: sid.clone(), + uri: config.uri.clone(), + auth_header: config.auth_header.clone(), + custom_headers: protocol_headers + .clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } else { + let sse_stream = SseAutoReconnectStream::never_reconnect( stream, StreamableHttpError::::UnexpectedEndOfStream, ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } + tracing::trace!("got new sse stream after re-init"); + Ok(()) } - tracing::trace!("got new sse stream after re-init"); - Ok(()) } } + Err(reinit_err) => Err(reinit_err), } - Err(reinit_err) => Err(reinit_err), - } + } // else enable_reinit_on_expired_session } Err(e) => Err(e), Ok(StreamableHttpPostResponse::Accepted) => { @@ -1051,6 +1056,16 @@ pub struct StreamableHttpClientTransportConfig { pub auth_header: Option, /// Custom HTTP headers to include with every request pub custom_headers: HashMap, + /// Enables transparent recovery when the server reports an expired session (`HTTP 404`). + /// + /// When enabled, the transport performs one automatic recovery attempt: + /// 1. Replays the original `initialize` handshake to create a new session. + /// 2. Re-establishes streaming state for that session. + /// 3. Retries the in-flight request that failed with `SessionExpired`. + /// + /// This recovery is best-effort and bounded to a single attempt. If recovery fails, + /// the original failure path is preserved and the error is returned to the caller. + pub reinit_on_expired_session: bool, } impl StreamableHttpClientTransportConfig { @@ -1098,6 +1113,19 @@ impl StreamableHttpClientTransportConfig { self.custom_headers = custom_headers; self } + + /// Set whether the transport should attempt transparent re-initialization on session expiration + /// See [`Self::reinit_on_expired_session`] for details. + /// # Example + /// ```rust,no_run + /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000") + /// .reinit_on_expired_session(true); + /// ``` + pub fn reinit_on_expired_session(mut self, enable: bool) -> Self { + self.reinit_on_expired_session = enable; + self + } } impl Default for StreamableHttpClientTransportConfig { @@ -1109,6 +1137,7 @@ impl Default for StreamableHttpClientTransportConfig { allow_stateless: true, auth_header: None, custom_headers: HashMap::new(), + reinit_on_expired_session: true, } } } diff --git a/crates/rmcp/tests/test_streamable_http_stale_session.rs b/crates/rmcp/tests/test_streamable_http_stale_session.rs index e33bf91c..b385cc52 100644 --- a/crates/rmcp/tests/test_streamable_http_stale_session.rs +++ b/crates/rmcp/tests/test_streamable_http_stale_session.rs @@ -8,7 +8,7 @@ use std::{collections::HashMap, sync::Arc}; use rmcp::{ - ServiceExt, + ServiceError, ServiceExt, model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, transport::{ StreamableHttpClientTransport, @@ -126,7 +126,8 @@ async fn test_transparent_reinitialization_on_session_expiry() -> anyhow::Result // Connect a full client transport (this performs initialize + notifications/initialized) let transport = StreamableHttpClientTransport::from_config( - StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")) + .reinit_on_expired_session(true), ); let client = ().serve(transport).await?; @@ -171,3 +172,69 @@ async fn test_transparent_reinitialization_on_session_expiry() -> anyhow::Result Ok(()) } + +/// Verify that when `reinit_on_expired_session` is false and the server loses the session, +/// the client receives a `SessionExpired` transport error instead of retrying. +#[tokio::test] +async fn test_session_expired_error_when_reinit_disabled() -> anyhow::Result<()> { + let ct = CancellationToken::new(); + let session_manager = Arc::new(LocalSessionManager::default()); + + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")) + .reinit_on_expired_session(false), + ); + let client = ().serve(transport).await?; + + // Verify the session is established + let _resources = client.list_all_resources().await?; + + // Force session expiry by removing all sessions from the server-side manager + { + let mut sessions = session_manager.sessions.write().await; + sessions.clear(); + } + + // This call should fail with a SessionExpired transport error + let result = client.list_all_resources().await; + match result { + Err(ServiceError::TransportSend(ref dyn_err)) => { + let err_msg = format!("{dyn_err}"); + assert!( + err_msg.contains("Session expired"), + "expected 'Session expired' in error message, got: {err_msg}" + ); + } + other => panic!("expected TransportSend(SessionExpired), got: {other:?}"), + } + + let _ = client.cancel().await; + ct.cancel(); + server_handle.await?; + + Ok(()) +}