From f13fab3a6431b8941f29a35b3517133652a18384 Mon Sep 17 00:00:00 2001 From: Dale Seo Date: Fri, 6 Feb 2026 17:25:44 -0500 Subject: [PATCH] feat: enforce SEP-1577 MUST requirements for sampling with tools --- crates/rmcp/src/model.rs | 87 ++++++++++++ crates/rmcp/src/model/content.rs | 61 --------- crates/rmcp/src/service/server.rs | 27 ++++ crates/rmcp/tests/test_sampling.rs | 213 ++++++++++++++++++++++++++--- 4 files changed, 310 insertions(+), 78 deletions(-) diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index db6e927e..b15ccab3 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -1580,6 +1580,85 @@ impl TaskAugmentedRequestParamsMeta for CreateMessageRequestParams { } } +impl CreateMessageRequestParams { + /// Validate the sampling request parameters per SEP-1577 spec requirements. + /// + /// Checks: + /// - ToolUse content is only allowed in assistant messages + /// - ToolResult content is only allowed in user messages + /// - Messages with tool result content MUST NOT contain other content types + /// - Every assistant ToolUse must be balanced with a corresponding user ToolResult + pub fn validate(&self) -> Result<(), String> { + for msg in &self.messages { + for content in msg.content.iter() { + // ToolUse only in assistant messages, ToolResult only in user messages + match content { + SamplingMessageContent::ToolUse(_) if msg.role != Role::Assistant => { + return Err("ToolUse content is only allowed in assistant messages".into()); + } + SamplingMessageContent::ToolResult(_) if msg.role != Role::User => { + return Err("ToolResult content is only allowed in user messages".into()); + } + _ => {} + } + } + + // Tool result messages MUST NOT contain other content types + let contents: Vec<_> = msg.content.iter().collect(); + let has_tool_result = contents + .iter() + .any(|c| matches!(c, SamplingMessageContent::ToolResult(_))); + if has_tool_result + && contents + .iter() + .any(|c| !matches!(c, SamplingMessageContent::ToolResult(_))) + { + return Err( + "SamplingMessage with tool result content MUST NOT contain other content types" + .into(), + ); + } + } + + // Every assistant ToolUse must be balanced with a user ToolResult + self.validate_tool_use_result_balance()?; + + Ok(()) + } + + fn validate_tool_use_result_balance(&self) -> Result<(), String> { + let mut pending_tool_use_ids: Vec = Vec::new(); + for msg in &self.messages { + if msg.role == Role::Assistant { + for content in msg.content.iter() { + if let SamplingMessageContent::ToolUse(tu) = content { + pending_tool_use_ids.push(tu.id.clone()); + } + } + } else if msg.role == Role::User { + for content in msg.content.iter() { + if let SamplingMessageContent::ToolResult(tr) = content { + if !pending_tool_use_ids.contains(&tr.tool_use_id) { + return Err(format!( + "ToolResult with toolUseId '{}' has no matching ToolUse", + tr.tool_use_id + )); + } + pending_tool_use_ids.retain(|id| id != &tr.tool_use_id); + } + } + } + } + if !pending_tool_use_ids.is_empty() { + return Err(format!( + "ToolUse with id(s) {:?} not balanced with ToolResult", + pending_tool_use_ids + )); + } + Ok(()) + } +} + /// Deprecated: Use [`CreateMessageRequestParams`] instead (SEP-1319 compliance). #[deprecated(since = "0.13.0", note = "Use CreateMessageRequestParams instead")] pub type CreateMessageRequestParam = CreateMessageRequestParams; @@ -2229,6 +2308,14 @@ impl CreateMessageResult { pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence"; pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens"; pub const STOP_REASON_TOOL_USE: &str = "toolUse"; + + /// Validate the result per SEP-1577: role must be "assistant". + pub fn validate(&self) -> Result<(), String> { + if self.message.role != Role::Assistant { + return Err("CreateMessageResult role must be 'assistant'".into()); + } + Ok(()) + } } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] diff --git a/crates/rmcp/src/model/content.rs b/crates/rmcp/src/model/content.rs index 297bc751..beb4d9f5 100644 --- a/crates/rmcp/src/model/content.rs +++ b/crates/rmcp/src/model/content.rs @@ -129,67 +129,6 @@ impl ToolResultContent { } } -/// Assistant message content types (SEP-1577). -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub enum AssistantMessageContent { - Text(RawTextContent), - Image(RawImageContent), - Audio(RawAudioContent), - ToolUse(ToolUseContent), -} - -/// User message content types (SEP-1577). -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub enum UserMessageContent { - Text(RawTextContent), - Image(RawImageContent), - Audio(RawAudioContent), - ToolResult(ToolResultContent), -} - -impl AssistantMessageContent { - /// Create a text content - pub fn text(text: impl Into) -> Self { - Self::Text(RawTextContent { - text: text.into(), - meta: None, - }) - } - - /// Create a tool use content - pub fn tool_use( - id: impl Into, - name: impl Into, - input: super::JsonObject, - ) -> Self { - Self::ToolUse(ToolUseContent::new(id, name, input)) - } -} - -impl UserMessageContent { - /// Create a text content - pub fn text(text: impl Into) -> Self { - Self::Text(RawTextContent { - text: text.into(), - meta: None, - }) - } - - /// Create a tool result content - pub fn tool_result(tool_use_id: impl Into, content: Vec) -> Self { - Self::ToolResult(ToolResultContent::new(tool_use_id, content)) - } - - /// Create an error tool result content - pub fn tool_result_error(tool_use_id: impl Into, content: Vec) -> Self { - Self::ToolResult(ToolResultContent::error(tool_use_id, content)) - } -} - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 1ba578b7..eeb880c0 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -384,10 +384,37 @@ macro_rules! method { } impl Peer { + /// Check if the client supports sampling tools capability. + pub fn supports_sampling_tools(&self) -> bool { + if let Some(client_info) = self.peer_info() { + client_info + .capabilities + .sampling + .as_ref() + .and_then(|s| s.tools.as_ref()) + .is_some() + } else { + false + } + } + pub async fn create_message( &self, params: CreateMessageRequestParams, ) -> Result { + // MUST throw error when tools/toolChoice provided without capability + if (params.tools.is_some() || params.tool_choice.is_some()) + && !self.supports_sampling_tools() + { + return Err(ServiceError::McpError(ErrorData::invalid_params( + "tools or toolChoice provided but client does not support sampling tools capability", + None, + ))); + } + // Validate message structure + params + .validate() + .map_err(|e| ServiceError::McpError(ErrorData::invalid_params(e, None)))?; let result = self .send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest { method: Default::default(), diff --git a/crates/rmcp/tests/test_sampling.rs b/crates/rmcp/tests/test_sampling.rs index 9bedfb04..e5191d3c 100644 --- a/crates/rmcp/tests/test_sampling.rs +++ b/crates/rmcp/tests/test_sampling.rs @@ -1,5 +1,3 @@ -//cargo test --test test_sampling --features "client server" - mod common; use anyhow::Result; @@ -103,21 +101,17 @@ async fn test_sampling_context_inclusion_enum() -> Result<()> { async fn test_sampling_integration_with_test_handlers() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); - // Start server let server_handle = tokio::spawn(async move { let server = TestServer::new().serve(server_transport).await?; server.waiting().await?; anyhow::Ok(()) }); - // Start client that honors sampling requests let handler = TestClientHandler::new(true, true); let client = handler.clone().serve(client_transport).await?; - // Wait for initialization tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - // Test sampling with context inclusion let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { method: Default::default(), params: CreateMessageRequestParams { @@ -157,7 +151,6 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { ) .await?; - // Verify the response if let ClientResult::CreateMessageResult(result) = result { assert_eq!(result.message.role, Role::Assistant); assert_eq!(result.model, "test-model"); @@ -192,21 +185,17 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { async fn test_sampling_no_context_inclusion() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); - // Start server let server_handle = tokio::spawn(async move { let server = TestServer::new().serve(server_transport).await?; server.waiting().await?; anyhow::Ok(()) }); - // Start client that honors sampling requests let handler = TestClientHandler::new(true, true); let client = handler.clone().serve(client_transport).await?; - // Wait for initialization tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - // Test sampling without context inclusion let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { method: Default::default(), params: CreateMessageRequestParams { @@ -239,7 +228,6 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { ) .await?; - // Verify the response if let ClientResult::CreateMessageResult(result) = result { assert_eq!(result.message.role, Role::Assistant); assert_eq!(result.model, "test-model"); @@ -270,21 +258,17 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { async fn test_sampling_error_invalid_message_sequence() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); - // Start server let server_handle = tokio::spawn(async move { let server = TestServer::new().serve(server_transport).await?; server.waiting().await?; anyhow::Ok(()) }); - // Start client let handler = TestClientHandler::new(true, true); let client = handler.clone().serve(client_transport).await?; - // Wait for initialization tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - // Test sampling with no user messages (should fail) let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { method: Default::default(), params: CreateMessageRequestParams { @@ -319,7 +303,6 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { ) .await; - // This should result in an error assert!(result.is_err()); client.cancel().await?; @@ -637,3 +620,199 @@ async fn test_content_conversion_unsupported_variants() { "Resource content is not supported in sampling messages" ); } + +#[tokio::test] +async fn test_validate_rejects_tool_use_in_user_message() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![SamplingMessage::new( + Role::User, + SamplingMessageContent::tool_use("call_1", "some_tool", Default::default()), + )], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + let err = params.validate().unwrap_err(); + assert!( + err.contains("ToolUse content is only allowed in assistant messages"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_validate_rejects_tool_result_in_assistant_message() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![SamplingMessage::new( + Role::Assistant, + SamplingMessageContent::tool_result("call_1", vec![Content::text("result")]), + )], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + let err = params.validate().unwrap_err(); + assert!( + err.contains("ToolResult content is only allowed in user messages"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_validate_rejects_mixed_content_with_tool_result() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![SamplingMessage::new_multiple( + Role::User, + vec![ + SamplingMessageContent::tool_result("call_1", vec![Content::text("result")]), + SamplingMessageContent::text("some extra text"), + ], + )], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + let err = params.validate().unwrap_err(); + assert!( + err.contains("MUST NOT contain other content types"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_validate_rejects_unbalanced_tool_use_result() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![ + SamplingMessage::user_text("Hello"), + SamplingMessage::assistant_tool_use("call_1", "some_tool", Default::default()), + ], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + let err = params.validate().unwrap_err(); + assert!( + err.contains("not balanced with ToolResult"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_validate_rejects_tool_result_without_matching_use() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![ + SamplingMessage::user_text("Hello"), + SamplingMessage::user_tool_result("nonexistent_call", vec![Content::text("result")]), + ], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + let err = params.validate().unwrap_err(); + assert!( + err.contains("has no matching ToolUse"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_validate_accepts_valid_tool_conversation() { + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![ + SamplingMessage::user_text("What's the weather?"), + SamplingMessage::assistant_tool_use( + "call_1", + "get_weather", + serde_json::json!({"location": "SF"}) + .as_object() + .unwrap() + .clone(), + ), + SamplingMessage::user_tool_result("call_1", vec![Content::text("72°F and sunny")]), + SamplingMessage::assistant_text("It's 72°F and sunny in SF."), + ], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + }; + + assert!(params.validate().is_ok()); +} + +#[tokio::test] +async fn test_create_message_result_validate_rejects_user_role() { + let result = CreateMessageResult { + message: SamplingMessage::user_text("This should not be a user message"), + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }; + + let err = result.validate().unwrap_err(); + assert!( + err.contains("role must be 'assistant'"), + "unexpected error: {err}" + ); +} + +#[tokio::test] +async fn test_create_message_result_validate_accepts_assistant_role() { + let result = CreateMessageResult { + message: SamplingMessage::assistant_text("Hello!"), + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }; + + assert!(result.validate().is_ok()); +}