Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,85 @@ impl TaskAugmentedRequestParamsMeta for CreateMessageRequestParams {
}
}

impl CreateMessageRequestParams {
/// Validate the sampling request parameters per SEP-1577 spec requirements.
///
/// Checks:
/// - ToolUse content is only allowed in assistant messages
/// - ToolResult content is only allowed in user messages
/// - Messages with tool result content MUST NOT contain other content types
/// - Every assistant ToolUse must be balanced with a corresponding user ToolResult
pub fn validate(&self) -> Result<(), String> {
for msg in &self.messages {
for content in msg.content.iter() {
// ToolUse only in assistant messages, ToolResult only in user messages
match content {
SamplingMessageContent::ToolUse(_) if msg.role != Role::Assistant => {
return Err("ToolUse content is only allowed in assistant messages".into());
}
SamplingMessageContent::ToolResult(_) if msg.role != Role::User => {
return Err("ToolResult content is only allowed in user messages".into());
}
_ => {}
}
}

// Tool result messages MUST NOT contain other content types
let contents: Vec<_> = msg.content.iter().collect();
let has_tool_result = contents
.iter()
.any(|c| matches!(c, SamplingMessageContent::ToolResult(_)));
if has_tool_result
&& contents
.iter()
.any(|c| !matches!(c, SamplingMessageContent::ToolResult(_)))
{
return Err(
"SamplingMessage with tool result content MUST NOT contain other content types"
.into(),
);
}
}

// Every assistant ToolUse must be balanced with a user ToolResult
self.validate_tool_use_result_balance()?;

Ok(())
}

fn validate_tool_use_result_balance(&self) -> Result<(), String> {
let mut pending_tool_use_ids: Vec<String> = Vec::new();
for msg in &self.messages {
if msg.role == Role::Assistant {
for content in msg.content.iter() {
if let SamplingMessageContent::ToolUse(tu) = content {
pending_tool_use_ids.push(tu.id.clone());
}
}
} else if msg.role == Role::User {
for content in msg.content.iter() {
if let SamplingMessageContent::ToolResult(tr) = content {
if !pending_tool_use_ids.contains(&tr.tool_use_id) {
return Err(format!(
"ToolResult with toolUseId '{}' has no matching ToolUse",
tr.tool_use_id
));
}
pending_tool_use_ids.retain(|id| id != &tr.tool_use_id);
}
}
}
}
if !pending_tool_use_ids.is_empty() {
return Err(format!(
"ToolUse with id(s) {:?} not balanced with ToolResult",
pending_tool_use_ids
));
}
Ok(())
}
}

/// Deprecated: Use [`CreateMessageRequestParams`] instead (SEP-1319 compliance).
#[deprecated(since = "0.13.0", note = "Use CreateMessageRequestParams instead")]
pub type CreateMessageRequestParam = CreateMessageRequestParams;
Expand Down Expand Up @@ -2229,6 +2308,14 @@ impl CreateMessageResult {
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
pub const STOP_REASON_TOOL_USE: &str = "toolUse";

/// Validate the result per SEP-1577: role must be "assistant".
pub fn validate(&self) -> Result<(), String> {
if self.message.role != Role::Assistant {
return Err("CreateMessageResult role must be 'assistant'".into());
}
Ok(())
}
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
Expand Down
61 changes: 0 additions & 61 deletions crates/rmcp/src/model/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,67 +129,6 @@ impl ToolResultContent {
}
}

/// Assistant message content types (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum AssistantMessageContent {
Text(RawTextContent),
Image(RawImageContent),
Audio(RawAudioContent),
ToolUse(ToolUseContent),
}

/// User message content types (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum UserMessageContent {
Text(RawTextContent),
Image(RawImageContent),
Audio(RawAudioContent),
ToolResult(ToolResultContent),
}

impl AssistantMessageContent {
/// Create a text content
pub fn text(text: impl Into<String>) -> Self {
Self::Text(RawTextContent {
text: text.into(),
meta: None,
})
}

/// Create a tool use content
pub fn tool_use(
id: impl Into<String>,
name: impl Into<String>,
input: super::JsonObject,
) -> Self {
Self::ToolUse(ToolUseContent::new(id, name, input))
}
}

impl UserMessageContent {
/// Create a text content
pub fn text(text: impl Into<String>) -> Self {
Self::Text(RawTextContent {
text: text.into(),
meta: None,
})
}

/// Create a tool result content
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
}

/// Create an error tool result content
pub fn tool_result_error(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
Self::ToolResult(ToolResultContent::error(tool_use_id, content))
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
Expand Down
27 changes: 27 additions & 0 deletions crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,37 @@ macro_rules! method {
}

impl Peer<RoleServer> {
/// Check if the client supports sampling tools capability.
pub fn supports_sampling_tools(&self) -> bool {
if let Some(client_info) = self.peer_info() {
client_info
.capabilities
.sampling
.as_ref()
.and_then(|s| s.tools.as_ref())
.is_some()
} else {
false
}
}

pub async fn create_message(
&self,
params: CreateMessageRequestParams,
) -> Result<CreateMessageResult, ServiceError> {
// MUST throw error when tools/toolChoice provided without capability
if (params.tools.is_some() || params.tool_choice.is_some())
&& !self.supports_sampling_tools()
{
return Err(ServiceError::McpError(ErrorData::invalid_params(
"tools or toolChoice provided but client does not support sampling tools capability",
None,
)));
}
// Validate message structure
params
.validate()
.map_err(|e| ServiceError::McpError(ErrorData::invalid_params(e, None)))?;
let result = self
.send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest {
method: Default::default(),
Expand Down
Loading