Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tempfile::{tempdir, TempDir};
use tokio::io::AsyncReadExt;
Expand All @@ -34,7 +34,9 @@ use super::tool_execution::{ToolCallContext, ToolCallResult};
use super::types::SharedProvider;
use crate::agents::extension::{Envs, ProcessExit};
use crate::agents::extension_malware_check;
use crate::agents::mcp_client::{GooseMcpClientCapabilities, McpClient, McpClientTrait};
use crate::agents::mcp_client::{
GooseMcpClientCapabilities, McpClient, McpClientTrait, ToolCacheInvalidator,
};
use crate::builtin_extension::get_builtin_extension;
use crate::config::extensions::name_to_key;
use crate::config::search_path::SearchPaths;
Expand Down Expand Up @@ -235,6 +237,7 @@ struct ResolvedTool {
client: McpClientBox,
}

#[allow(clippy::too_many_arguments)]
async fn child_process_client(
mut command: Command,
timeout: &Option<u64>,
Expand All @@ -243,6 +246,7 @@ async fn child_process_client(
docker_container: Option<String>,
client_name: String,
capabilities: GooseMcpClientCapabilities,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
) -> ExtensionResult<McpClient> {
configure_subprocess(&mut command);

Expand Down Expand Up @@ -281,6 +285,7 @@ async fn child_process_client(
client_name,
capabilities,
working_dir.clone(),
tool_cache_invalidator,
)
.await;

Expand Down Expand Up @@ -413,6 +418,7 @@ async fn create_streamable_http_client(
client_name: String,
capabilities: GooseMcpClientCapabilities,
roots_dir: &std::path::Path,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
) -> ExtensionResult<Box<dyn McpClientTrait>> {
let mut default_headers = HeaderMap::new();

Expand Down Expand Up @@ -451,6 +457,7 @@ async fn create_streamable_http_client(
client_name.clone(),
capabilities.clone(),
roots_dir.to_path_buf(),
tool_cache_invalidator.clone(),
)
.await;

Expand Down Expand Up @@ -481,6 +488,7 @@ async fn create_streamable_http_client(
client_name,
capabilities,
roots_dir.to_path_buf(),
tool_cache_invalidator,
)
.await?,
))
Expand Down Expand Up @@ -601,6 +609,9 @@ impl ExtensionManager {
mcpui: self.capabilities.mcpui,
};

let invalidator: Option<Weak<dyn ToolCacheInvalidator>> =
Some(Arc::downgrade(self) as Weak<dyn ToolCacheInvalidator>);

create_streamable_http_client(
&resolved_uri,
*timeout,
Expand All @@ -610,6 +621,7 @@ impl ExtensionManager {
self.client_name.clone(),
capability,
&effective_working_dir,
invalidator,
)
.await?
}
Expand All @@ -621,6 +633,8 @@ impl ExtensionManager {
None
};
let normalized_name = name_to_key(name);
let invalidator: Option<Weak<dyn ToolCacheInvalidator>> =
Some(Arc::downgrade(self) as Weak<dyn ToolCacheInvalidator>);

if let Some(def) = PLATFORM_EXTENSIONS.get(normalized_name.as_str()) {
// Platform extension: create via in-process client factory
Expand Down Expand Up @@ -671,6 +685,7 @@ impl ExtensionManager {
Some(container_id.to_string()),
self.client_name.clone(),
capabilities,
invalidator.clone(),
)
.await?;
Box::new(client)
Expand All @@ -691,6 +706,7 @@ impl ExtensionManager {
self.client_name.clone(),
capabilities,
effective_working_dir.clone(),
invalidator,
)
.await?,
)
Expand Down Expand Up @@ -742,6 +758,8 @@ impl ExtensionManager {
let capabilities = GooseMcpClientCapabilities {
mcpui: self.capabilities.mcpui,
};
let invalidator: Option<Weak<dyn ToolCacheInvalidator>> =
Some(Arc::downgrade(self) as Weak<dyn ToolCacheInvalidator>);
let client = child_process_client(
command,
timeout,
Expand All @@ -750,6 +768,7 @@ impl ExtensionManager {
container.map(|c| c.id().to_string()),
self.client_name.clone(),
capabilities,
invalidator,
)
.await?;
Box::new(client)
Expand Down Expand Up @@ -778,6 +797,9 @@ impl ExtensionManager {
mcpui: self.capabilities.mcpui,
};

let invalidator: Option<Weak<dyn ToolCacheInvalidator>> =
Some(Arc::downgrade(self) as Weak<dyn ToolCacheInvalidator>);

let client = child_process_client(
command,
timeout,
Expand All @@ -786,6 +808,7 @@ impl ExtensionManager {
container.map(|c| c.id().to_string()),
self.client_name.clone(),
capabilities,
invalidator,
)
.await?;

Expand Down Expand Up @@ -1705,6 +1728,13 @@ impl ExtensionManager {
}
}

#[async_trait::async_trait]
impl ToolCacheInvalidator for ExtensionManager {
async fn invalidate_tools(&self) {
self.invalidate_tools_cache_and_bump_version().await;
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
145 changes: 143 additions & 2 deletions crates/goose/src/agents/mcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use rmcp::{
ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt,
};
use serde_json::Value;
use std::{path::PathBuf, sync::Arc, time::Duration};
use std::{path::PathBuf, sync::Arc, sync::Weak, time::Duration};
use tokio::sync::{
mpsc::{self, Sender},
Mutex,
Expand All @@ -37,6 +37,14 @@ pub type BoxError = Box<dyn std::error::Error + Sync + Send>;

pub type Error = rmcp::ServiceError;

/// Trait for invalidating the tool cache in the ExtensionManager.
/// Used by GooseClient to trigger cache invalidation when an MCP server
/// sends `notifications/tools/list_changed`.
#[async_trait::async_trait]
pub trait ToolCacheInvalidator: Send + Sync {
async fn invalidate_tools(&self);
}

#[async_trait::async_trait]
pub trait McpClientTrait: Send + Sync {
async fn list_tools(
Expand Down Expand Up @@ -113,6 +121,7 @@ pub struct GooseClient {
client_name: String,
capabilities: GooseMcpClientCapabilities,
working_dir: Arc<tokio::sync::RwLock<PathBuf>>,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
}

impl GooseClient {
Expand All @@ -122,6 +131,24 @@ impl GooseClient {
client_name: String,
capabilities: GooseMcpClientCapabilities,
working_dir: PathBuf,
) -> Self {
Self::new_with_invalidator(
handlers,
provider,
client_name,
capabilities,
working_dir,
None,
)
}

pub fn new_with_invalidator(
handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
provider: SharedProvider,
client_name: String,
capabilities: GooseMcpClientCapabilities,
working_dir: PathBuf,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
) -> Self {
GooseClient {
notification_handlers: handlers,
Expand All @@ -130,6 +157,7 @@ impl GooseClient {
client_name,
capabilities,
working_dir: Arc::new(tokio::sync::RwLock::new(working_dir)),
tool_cache_invalidator,
}
}

Expand Down Expand Up @@ -164,6 +192,20 @@ impl GooseClient {
.and_then(|(_, value)| value.as_str())
.map(|value| value.to_string())
}

/// Handles the tools/list_changed notification by upgrading the weak reference
/// to the cache invalidator and calling `invalidate_tools()`.
/// Returns `true` if the invalidator was successfully called, `false` if the
/// weak reference had been dropped.
async fn handle_tool_list_changed(&self) -> bool {
if let Some(ref invalidator) = self.tool_cache_invalidator {
if let Some(invalidator) = invalidator.upgrade() {
invalidator.invalidate_tools().await;
return true;
}
}
false
}
}

fn working_dir_roots(dir: &std::path::Path) -> ListRootsResult {
Expand Down Expand Up @@ -214,6 +256,14 @@ impl ClientHandler for GooseClient {
});
}

async fn on_tool_list_changed(
&self,
_context: rmcp::service::NotificationContext<rmcp::RoleClient>,
) {
tracing::info!("Received tools/list_changed notification from MCP server");
self.handle_tool_list_changed().await;
}

async fn create_message(
&self,
params: CreateMessageRequestParams,
Expand Down Expand Up @@ -396,6 +446,7 @@ impl McpClient {
client_name: String,
capabilities: GooseMcpClientCapabilities,
working_dir: PathBuf,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
) -> Result<Self, ClientInitializeError>
where
T: IntoTransport<RoleClient, E, A>,
Expand All @@ -409,10 +460,12 @@ impl McpClient {
client_name,
capabilities,
working_dir,
tool_cache_invalidator,
)
.await
}

#[allow(clippy::too_many_arguments)]
pub async fn connect_with_container<T, E, A>(
transport: T,
timeout: std::time::Duration,
Expand All @@ -421,6 +474,7 @@ impl McpClient {
client_name: String,
capabilities: GooseMcpClientCapabilities,
working_dir: PathBuf,
tool_cache_invalidator: Option<Weak<dyn ToolCacheInvalidator>>,
) -> Result<Self, ClientInitializeError>
where
T: IntoTransport<RoleClient, E, A>,
Expand All @@ -429,12 +483,13 @@ impl McpClient {
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));

let client = GooseClient::new(
let client = GooseClient::new_with_invalidator(
notification_subscribers.clone(),
provider,
client_name.clone(),
capabilities.clone(),
working_dir,
tool_cache_invalidator,
);
let client: rmcp::service::RunningService<rmcp::RoleClient, GooseClient> =
client.serve(transport).await?;
Expand Down Expand Up @@ -1009,4 +1064,90 @@ mod tests {
assert_eq!(result.roots[0].uri, "file:///tmp/test-project");
assert_eq!(result.roots[0].name.as_deref(), Some("working_directory"));
}

#[test]
fn test_on_tool_list_changed_calls_invalidator() {
struct MockInvalidator {
call_count: std::sync::atomic::AtomicUsize,
}

#[async_trait::async_trait]
impl ToolCacheInvalidator for MockInvalidator {
async fn invalidate_tools(&self) {
self.call_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}

let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let invalidator = Arc::new(MockInvalidator {
call_count: std::sync::atomic::AtomicUsize::new(0),
});

let capabilities = GooseMcpClientCapabilities { mcpui: false };
let client = GooseClient::new_with_invalidator(
Arc::new(Mutex::new(Vec::new())),
Arc::new(Mutex::new(None)),
"test".to_string(),
capabilities,
std::env::current_dir().unwrap_or_default(),
Some(Arc::downgrade(&invalidator) as Weak<dyn ToolCacheInvalidator>),
);

// Exercise the actual code path that handles tool_list_changed.
// This tests the weak-ref upgrade and invalidation call.
let result1 = client.handle_tool_list_changed().await;
let result2 = client.handle_tool_list_changed().await;

assert!(result1, "should return true when invalidator is alive");
assert!(result2, "should return true when invalidator is alive");
assert_eq!(
invalidator
.call_count
.load(std::sync::atomic::Ordering::SeqCst),
2,
"invalidate_tools should have been called twice"
);
});
}

#[test]
fn test_on_tool_list_changed_noop_when_invalidator_dropped() {
struct MockInvalidator {
call_count: std::sync::atomic::AtomicUsize,
}

#[async_trait::async_trait]
impl ToolCacheInvalidator for MockInvalidator {
async fn invalidate_tools(&self) {
self.call_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}

let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let invalidator = Arc::new(MockInvalidator {
call_count: std::sync::atomic::AtomicUsize::new(0),
});

let capabilities = GooseMcpClientCapabilities { mcpui: false };
let client = GooseClient::new_with_invalidator(
Arc::new(Mutex::new(Vec::new())),
Arc::new(Mutex::new(None)),
"test".to_string(),
capabilities,
std::env::current_dir().unwrap_or_default(),
Some(Arc::downgrade(&invalidator) as Weak<dyn ToolCacheInvalidator>),
);

// Drop the Arc so the weak ref becomes stale.
drop(invalidator);

// handle_tool_list_changed should return false (weak upgrade fails).
let result = client.handle_tool_list_changed().await;
assert!(!result, "should return false when invalidator is dropped");
});
}
}
Loading