Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,15 @@ path = "tests/test_sse_concurrent_streams.rs"
name = "test_client_credentials"
required-features = ["auth"]
path = "tests/test_client_credentials.rs"

[[test]]
name = "test_streamable_http_stale_session"
required-features = [
"server",
"client",
"transport-streamable-http-server",
"transport-streamable-http-client",
"transport-streamable-http-client-reqwest"
]
path = "tests/test_streamable_http_stale_session.rs"

215 changes: 122 additions & 93 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,48 +600,51 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
.await;
let send_result = match response {
Err(StreamableHttpError::SessionExpired) => {
// The server discarded the session (HTTP 404). Perform a
// fresh handshake once and replay the original message.
tracing::info!(
"session expired (HTTP 404), attempting transparent re-initialization"
);
match Self::perform_reinitialization(
self.client.clone(),
saved_init_request.clone(),
config.uri.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await
{
Ok((new_session_id, new_protocol_headers)) => {
// Old streams hold the stale session ID; abort them
// so the new standalone SSE stream takes over.
streams.abort_all();
if !config.reinit_on_expired_session {
Err(StreamableHttpError::SessionExpired)
} else {
// The server discarded the session (HTTP 404). Perform a
// fresh handshake once and replay the original message.
tracing::info!(
"session expired (HTTP 404), attempting transparent re-initialization"
);
match Self::perform_reinitialization(
self.client.clone(),
saved_init_request.clone(),
config.uri.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await
{
Ok((new_session_id, new_protocol_headers)) => {
// Old streams hold the stale session ID; abort them
// so the new standalone SSE stream takes over.
streams.abort_all();

session_id = new_session_id;
protocol_headers = new_protocol_headers;
session_cleanup_info =
session_id.as_ref().map(|sid| SessionCleanupInfo {
client: self.client.clone(),
uri: config.uri.clone(),
session_id: sid.clone(),
auth_header: config.auth_header.clone(),
protocol_headers: protocol_headers.clone(),
});
session_id = new_session_id;
protocol_headers = new_protocol_headers;
session_cleanup_info =
session_id.as_ref().map(|sid| SessionCleanupInfo {
client: self.client.clone(),
uri: config.uri.clone(),
session_id: sid.clone(),
auth_header: config.auth_header.clone(),
protocol_headers: protocol_headers.clone(),
});

if let Some(new_sid) = &session_id {
let client = self.client.clone();
let uri = config.uri.clone();
let new_sid = new_sid.clone();
let auth_header = config.auth_header.clone();
let retry_config = self.config.retry_config.clone();
let sse_tx = sse_worker_tx.clone();
let task_ct = transport_task_ct.clone();
let config_uri = config.uri.clone();
let config_auth = config.auth_header.clone();
let spawn_headers = protocol_headers.clone();
streams.spawn(async move {
if let Some(new_sid) = &session_id {
let client = self.client.clone();
let uri = config.uri.clone();
let new_sid = new_sid.clone();
let auth_header = config.auth_header.clone();
let retry_config = self.config.retry_config.clone();
let sse_tx = sse_worker_tx.clone();
let task_ct = transport_task_ct.clone();
let config_uri = config.uri.clone();
let config_auth = config.auth_header.clone();
let spawn_headers = protocol_headers.clone();
streams.spawn(async move {
match client
.get_stream(
uri,
Expand Down Expand Up @@ -686,69 +689,71 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
}
});
}

let retry_response = self
.client
.post_message(
config.uri.clone(),
message,
session_id.clone(),
config.auth_header.clone(),
protocol_headers.clone(),
)
.await;
match retry_response {
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
tracing::trace!(
"client message accepted after re-init"
);
Ok(())
}
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
context.send_to_handler(msg).await?;
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
if let Some(sid) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: sid.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers.clone(),
},
self.config.retry_config.clone(),

let retry_response = self
.client
.post_message(
config.uri.clone(),
message,
session_id.clone(),
config.auth_header.clone(),
protocol_headers.clone(),
)
.await;
match retry_response {
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
tracing::trace!(
"client message accepted after re-init"
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
} else {
let sse_stream =
Ok(())
}
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
context.send_to_handler(msg).await?;
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
if let Some(sid) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: sid.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers
.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
} else {
let sse_stream =
SseAutoReconnectStream::never_reconnect(
stream,
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
}
tracing::trace!("got new sse stream after re-init");
Ok(())
}
tracing::trace!("got new sse stream after re-init");
Ok(())
}
}
Err(reinit_err) => Err(reinit_err),
}
Err(reinit_err) => Err(reinit_err),
}
} // else enable_reinit_on_expired_session
}
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
Expand Down Expand Up @@ -1051,6 +1056,16 @@ pub struct StreamableHttpClientTransportConfig {
pub auth_header: Option<String>,
/// Custom HTTP headers to include with every request
pub custom_headers: HashMap<HeaderName, HeaderValue>,
/// Enables transparent recovery when the server reports an expired session (`HTTP 404`).
///
/// When enabled, the transport performs one automatic recovery attempt:
/// 1. Replays the original `initialize` handshake to create a new session.
/// 2. Re-establishes streaming state for that session.
/// 3. Retries the in-flight request that failed with `SessionExpired`.
///
/// This recovery is best-effort and bounded to a single attempt. If recovery fails,
/// the original failure path is preserved and the error is returned to the caller.
pub reinit_on_expired_session: bool,
}

impl StreamableHttpClientTransportConfig {
Expand Down Expand Up @@ -1098,6 +1113,19 @@ impl StreamableHttpClientTransportConfig {
self.custom_headers = custom_headers;
self
}

/// Set whether the transport should attempt transparent re-initialization on session expiration
/// See [`Self::reinit_on_expired_session`] for details.
/// # Example
/// ```rust,no_run
/// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
/// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
/// .reinit_on_expired_session(true);
/// ```
pub fn reinit_on_expired_session(mut self, enable: bool) -> Self {
self.reinit_on_expired_session = enable;
self
}
}

impl Default for StreamableHttpClientTransportConfig {
Expand All @@ -1109,6 +1137,7 @@ impl Default for StreamableHttpClientTransportConfig {
allow_stateless: true,
auth_header: None,
custom_headers: HashMap::new(),
reinit_on_expired_session: true,
}
}
}
Loading
Loading