Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 123 additions & 24 deletions crates/rmcp/src/transport/async_rw.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::{marker::PhantomData, sync::Arc};

// use crate::schema::*;
use futures::{SinkExt, StreamExt};
use futures::SinkExt;
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
sync::Mutex,
};
use tokio_util::{
bytes::{Buf, BufMut, BytesMut},
codec::{Decoder, Encoder, FramedRead, FramedWrite},
codec::{Decoder, Encoder, FramedWrite},
};

use super::{IntoTransport, Transport};
Expand Down Expand Up @@ -47,8 +46,10 @@ where
pub type TransportWriter<Role, W> = FramedWrite<W, JsonRpcMessageCodec<TxJsonRpcMessage<Role>>>;

pub struct AsyncRwTransport<Role: ServiceRole, R: AsyncRead, W: AsyncWrite> {
read: FramedRead<R, JsonRpcMessageCodec<RxJsonRpcMessage<Role>>>,
read: BufReader<R>,
line_buf: Vec<u8>,
write: Arc<Mutex<Option<TransportWriter<Role, W>>>>,
_role: PhantomData<fn() -> Role>,
}

impl<Role: ServiceRole, R, W> AsyncRwTransport<Role, R, W>
Expand All @@ -57,15 +58,17 @@ where
W: Send + AsyncWrite + Unpin + 'static,
{
pub fn new(read: R, write: W) -> Self {
let read = FramedRead::new(
read,
JsonRpcMessageCodec::<RxJsonRpcMessage<Role>>::default(),
);
let read = BufReader::new(read);
let write = Arc::new(Mutex::new(Some(FramedWrite::new(
write,
JsonRpcMessageCodec::<TxJsonRpcMessage<Role>>::default(),
))));
Self { read, write }
Self {
read,
line_buf: Vec::new(),
write,
_role: PhantomData,
}
}
}

Expand Down Expand Up @@ -116,15 +119,42 @@ where
}
}

fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<Role>>> {
let next = self.read.next();
async {
next.await.and_then(|e| {
e.inspect_err(|e| {
async fn receive(&mut self) -> Option<RxJsonRpcMessage<Role>> {
loop {
self.line_buf.clear();
match self.read.read_until(b'\n', &mut self.line_buf).await {
Ok(0) => return None,
Ok(_) => {}
Err(e) => {
tracing::error!("Error reading from stream: {}", e);
})
.ok()
})
return None;
}
}
let line = without_carriage_return(
self.line_buf.strip_suffix(b"\n").unwrap_or(&self.line_buf),
);
if line.is_empty() {
continue;
}
match try_parse_with_compatibility::<RxJsonRpcMessage<Role>>(line, "receive") {
Ok(Some(msg)) => return Some(msg),
Ok(None) => continue,
Err(JsonRpcMessageCodecError::Serde(e)) => {
tracing::debug!("Parse error on incoming message: {e}");
let mut write = self.write.lock().await;
let framed = write.as_mut()?;
let inner = framed.get_mut();
if inner.write_all(PARSE_ERROR_RESPONSE).await.is_err()
|| inner.flush().await.is_err()
{
return None;
}
}
Err(e) => {
tracing::error!("Error reading from stream: {}", e);
return None;
}
}
}
}

Expand Down Expand Up @@ -172,13 +202,18 @@ impl<T> JsonRpcMessageCodec<T> {
}

fn without_carriage_return(s: &[u8]) -> &[u8] {
if let Some(&b'\r') = s.last() {
&s[..s.len() - 1]
} else {
s
}
s.strip_suffix(b"\r").unwrap_or(s)
}

/// UTF-8 byte order mark. RFC 8259 §8.1 allows JSON parsers to ignore a leading BOM.
const UTF8_BOM: &[u8; 3] = b"\xEF\xBB\xBF";

// JSON-RPC 2.0 §5.1: https://www.jsonrpc.org/specification#error_object
// Hardcoded bytes because `RequestId` has no `Null` variant — we can't
// build an `id: null` JsonRpcError through the typed codec.
const PARSE_ERROR_RESPONSE: &[u8] =
b"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"Parse error\"},\"id\":null}\n";
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there really no way to build something we can serialize to get this value instead of hardcoding it?


/// Check if a method is a standard MCP method (request, response, or notification).
/// This includes both requests and notifications defined in the MCP specification.
///
Expand Down Expand Up @@ -247,6 +282,7 @@ fn try_parse_with_compatibility<T: serde::de::DeserializeOwned>(
line: &[u8],
context: &str,
) -> Result<Option<T>, JsonRpcMessageCodecError> {
let line = line.strip_prefix(UTF8_BOM.as_slice()).unwrap_or(line);
if let Ok(line_str) = std::str::from_utf8(line) {
match serde_json::from_slice(line) {
Ok(item) => Ok(Some(item)),
Expand Down Expand Up @@ -406,7 +442,8 @@ impl<T: Serialize> Encoder<T> for JsonRpcMessageCodec<T> {

#[cfg(test)]
mod test {
use futures::{Sink, Stream};
use futures::{Sink, Stream, StreamExt};
use tokio_util::codec::FramedRead;

use super::*;
fn from_async_read<T: DeserializeOwned, R: AsyncRead>(reader: R) -> impl Stream<Item = T> {
Expand Down Expand Up @@ -555,4 +592,66 @@ mod test {

println!("Standard notifications are preserved, non-standard are handled gracefully");
}

#[tokio::test]
async fn test_decode_strips_utf8_bom() {
use futures::StreamExt;
use tokio::io::BufReader;

// Valid JSON-RPC message preceded by a UTF-8 BOM (EF BB BF). Some Windows
// tooling and editors prepend this; the codec should ignore it per RFC 8259 §8.1.
let mut data = Vec::new();
data.extend_from_slice(UTF8_BOM);
data.extend_from_slice(br#"{"jsonrpc":"2.0","method":"ping","id":1}"#);
data.push(b'\n');

let mut cursor = BufReader::new(&data[..]);
let mut stream = from_async_read::<serde_json::Value, _>(&mut cursor);

let item = stream
.next()
.await
.expect("should decode BOM-prefixed line");
assert_eq!(
item,
serde_json::json!({"jsonrpc": "2.0", "method": "ping", "id": 1})
);
}

#[cfg(feature = "server")]
#[tokio::test]
async fn receive_recovers_from_parse_error() {
use tokio::io::AsyncReadExt;

use crate::{RoleServer, transport::Transport};

// Two paired streams: `server_io` is wrapped by the transport; the test
// drives `client_io` to act as the peer.
let (server_io, client_io) = tokio::io::duplex(4096);
let (server_r, server_w) = tokio::io::split(server_io);
let (mut client_r, mut client_w) = tokio::io::split(client_io);

let mut transport = AsyncRwTransport::<RoleServer, _, _>::new(server_r, server_w);

client_w
.write_all(
b"not json\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}\n",
)
.await
.unwrap();

let received = transport
.receive()
.await
.expect("transport should recover and yield the next valid message");

let mut reply = vec![0u8; PARSE_ERROR_RESPONSE.len()];
client_r.read_exact(&mut reply).await.unwrap();

assert_eq!(reply, PARSE_ERROR_RESPONSE);
assert_eq!(
serde_json::to_value(&received).unwrap()["method"],
"notifications/initialized",
);
}
}
Loading