Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ where
let mut transport = transport.into_transport();
let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
#[derive(Debug)]
enum SendTaskResult {
Request {
Expand Down Expand Up @@ -860,7 +861,7 @@ where
}
let send = transport.send(m);
let current_span = tracing::Span::current();
tokio::spawn(async move {
response_send_tasks.spawn(async move {
let send_result = send.await;
if let Err(error) = send_result {
tracing::error!(%error, "fail to response message");
Expand Down Expand Up @@ -1008,6 +1009,44 @@ where
}
}
};

// Drain in-flight handler responses before closing the transport.
// When stdin EOF or cancellation arrives, spawned handler tasks may still
// be finishing. We need to:
// 1. Wait for response sends that were already spawned in the main loop
// 2. Drain any remaining handler responses from the channel
let drain_timeout = match &quit_reason {
    QuitReason::Closed => Some(Duration::from_secs(5)),
    QuitReason::Cancelled => Some(Duration::from_secs(2)),
    // Any other quit reason (e.g. a hard transport error) skips draining.
    _ => None,
};
if let Some(timeout_duration) = drain_timeout {
    // Drop our sender so the channel closes once all handler task
    // clones finish sending their responses (or are dropped).
    drop(sink_proxy_tx);
    let drain_result = tokio::time::timeout(timeout_duration, async {
        // First, wait for any response sends already dispatched by the
        // main loop (these hold transport write futures).
        while let Some(result) = response_send_tasks.join_next().await {
            if let Err(error) = result {
                tracing::error!(%error, "response send task failed during drain");
            }
        }
        // Then drain any handler responses still in the channel
        // (handlers that finished after the loop broke). recv() yields
        // None once every sender clone is gone, so this loop terminates
        // naturally when all handlers have completed.
        while let Some(m) = sink_proxy_rx.recv().await {
            if let Err(error) = transport.send(m).await {
                tracing::error!(%error, "failed to send pending response during drain");
                break;
            }
        }
    })
    .await;
    if drain_result.is_err() {
        tracing::warn!("timed out draining in-flight responses");
    }
}

let sink_close_result = transport.close().await;
if let Err(e) = sink_close_result {
tracing::error!(%e, "fail to close sink");
Expand Down
158 changes: 158 additions & 0 deletions crates/rmcp/tests/test_inflight_response_drain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#![cfg(not(feature = "local"))]
// cargo test --test test_inflight_response_drain --features "client server"

use std::{
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll},
time::Duration,
};

use rmcp::{
ServerHandler, ServiceExt,
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo},
service::QuitReason,
tool, tool_handler, tool_router,
};
use tokio::io::{AsyncRead, ReadBuf};

/// A slow tool server that sleeps before returning a response, used to
/// exercise the in-flight response drain path during shutdown.
#[derive(Debug, Clone)]
struct SlowToolServer {
    // Macro-generated router that dispatches tool calls to `#[tool]` methods.
    tool_router: ToolRouter<Self>,
}

impl SlowToolServer {
    /// Construct the server with its macro-generated tool router.
    fn new() -> Self {
        let tool_router = Self::tool_router();
        Self { tool_router }
    }
}

/// Arguments accepted by the `slow_tool` call.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
struct SlowToolRequest {
    #[schemars(description = "how long to sleep in milliseconds")]
    sleep_ms: u64,
}

#[tool_router]
impl SlowToolServer {
    /// Sleep for the requested number of milliseconds, then report completion.
    #[tool(description = "A tool that sleeps then returns")]
    async fn slow_tool(
        &self,
        Parameters(SlowToolRequest { sleep_ms }): Parameters<SlowToolRequest>,
    ) -> String {
        let delay = Duration::from_millis(sleep_ms);
        tokio::time::sleep(delay).await;
        format!("done after {}ms", sleep_ms)
    }
}

#[tool_handler]
impl ServerHandler for SlowToolServer {
    /// Advertise tool-calling capability to connecting clients.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities::builder().enable_tools().build();
        ServerInfo::new(capabilities)
    }
}

/// Minimal client handler: supplies default client info and nothing else.
#[derive(Debug, Clone, Default)]
struct DummyClientHandler;

impl rmcp::ClientHandler for DummyClientHandler {
    /// Return the default `ClientInfo` (same value as `ClientInfo::default()`).
    fn get_info(&self) -> ClientInfo {
        Default::default()
    }
}

/// An `AsyncRead` wrapper that delegates to the inner reader until signalled,
/// then returns EOF (read 0 bytes).
struct ClosableReader<R> {
    // The real data source; polled while `eof_flag` is unset.
    inner: R,
    // Shared flag the test flips to force EOF on a subsequent poll.
    eof_flag: Arc<AtomicBool>,
}

impl<R: AsyncRead + Unpin> AsyncRead for ClosableReader<R> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Returning Ready(Ok(())) without writing into `buf` signals EOF.
        if self.eof_flag.load(Ordering::Acquire) {
            return Poll::Ready(Ok(()));
        }
        // NOTE(review): storing `true` into `eof_flag` does not wake a task
        // parked on the inner reader — no waker is kept here, so EOF is only
        // observed on the next poll, which in this test is triggered by other
        // server activity. Confirm this timing is acceptable; a stored waker
        // would be needed for prompt EOF delivery in general use.
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

/// When the server's input stream returns EOF while a tool handler is still
/// in-flight, the drain phase should flush pending responses before closing.
#[tokio::test]
async fn test_inflight_response_drain_on_eof() -> anyhow::Result<()> {
    // Two unidirectional channels:
    //   client_write → server_read (client sends requests to server)
    //   server_write → client_read (server sends responses to client)
    let (client_write, server_read) = tokio::io::duplex(4096);
    let (server_write, client_read) = tokio::io::duplex(4096);

    // Wrap the server's read side so we can signal EOF from the test.
    let eof_flag = Arc::new(AtomicBool::new(false));
    let closable_read = ClosableReader {
        inner: server_read,
        eof_flag: eof_flag.clone(),
    };

    let server_transport = (closable_read, server_write);
    let client_transport = (client_read, client_write);

    // Start server with slow tool handler.
    let server_handle = tokio::spawn(async move {
        let server = SlowToolServer::new();
        let running = server.serve(server_transport).await?;
        let reason = running.waiting().await?;
        // EOF on the read side must surface as a clean `Closed`, not an error.
        assert!(
            matches!(reason, QuitReason::Closed),
            "expected Closed quit reason, got {:?}",
            reason,
        );
        anyhow::Ok(())
    });

    // Start client
    let client = DummyClientHandler.serve(client_transport).await?;

    // Call the slow tool (200ms sleep). Concurrently, signal the server's
    // read side to return EOF after the request has been sent but before
    // the handler finishes.
    let tool_future = client.call_tool(
        CallToolRequestParams::new("slow_tool").with_arguments(
            serde_json::json!({ "sleep_ms": 200 })
                .as_object()
                .unwrap()
                .clone(),
        ),
    );

    let (tool_result, _) = tokio::join!(tool_future, async {
        // Wait for the request to be sent and received by the server,
        // then signal EOF on the server's read side.
        // NOTE(review): assumes the request round-trips within 50ms and the
        // handler takes longer than that — could be flaky on a heavily
        // loaded runner; confirm the margins hold in CI.
        tokio::time::sleep(Duration::from_millis(50)).await;
        eof_flag.store(true, Ordering::Release);
    });

    // The tool result should still arrive thanks to the drain phase.
    let result = tool_result?;
    let text = result
        .content
        .first()
        .and_then(|c| c.raw.as_text())
        .map(|t| t.text.as_str())
        .expect("expected text content in tool result");
    assert_eq!(text, "done after 200ms");

    server_handle.await??;
    Ok(())
}
Loading