diff --git a/tests/test_ve_identity_mcp_toolset.py b/tests/test_ve_identity_mcp_toolset.py index 5b1f49c9..d45582e7 100644 --- a/tests/test_ve_identity_mcp_toolset.py +++ b/tests/test_ve_identity_mcp_toolset.py @@ -36,8 +36,8 @@ def test_init_with_api_key_auth(self, mock_identity_client): assert toolset._auth_config == config assert toolset._connection_params == connection_params - assert toolset._tool_filter is None - assert toolset._tool_name_prefix is None + assert toolset.tool_filter is None + assert toolset.tool_name_prefix is None @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") def test_init_with_oauth2_auth(self, mock_identity_client): @@ -64,7 +64,7 @@ def test_init_with_tool_filter_list(self, mock_identity_client): tool_filter=tool_filter, ) - assert toolset._tool_filter == tool_filter + assert toolset.tool_filter == tool_filter @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") def test_init_with_tool_name_prefix(self, mock_identity_client): @@ -79,7 +79,7 @@ def test_init_with_tool_name_prefix(self, mock_identity_client): tool_name_prefix=prefix, ) - assert toolset._tool_name_prefix == prefix + assert toolset.tool_name_prefix == prefix def test_init_with_none_connection_params(self): """Test that initialization fails with None connection_params.""" @@ -208,3 +208,311 @@ async def test_close_handles_exception(self, mock_identity_client): # Should not raise, just log the error await toolset.close() + + +class TestVeIdentityMcpToolsetGetToolsWithFilter: + """Tests for VeIdentityMcpToolset.get_tools with tool_filter applied.""" + + @pytest.mark.asyncio + @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") + @patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager") + @patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool") + async def test_get_tools_with_list_filter( + self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client + ): + """Test that get_tools correctly filters tools using list filter.""" + from google.adk.agents.readonly_context import ReadonlyContext + + # Setup mocks + mock_session_manager = Mock() + mock_session_manager_class.return_value = mock_session_manager + mock_session = AsyncMock() + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create mock tools returned by MCP server + mock_tool1 = Mock() + mock_tool1.name = "read_file" + mock_tool2 = Mock() + mock_tool2.name = "write_file" + mock_tool3 = Mock() + mock_tool3.name = "delete_file" + + mock_list_tools_result = Mock() + mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3] + mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) + + # Create mock VeIdentityMcpTool instances + mock_mcp_tool1 = Mock() + mock_mcp_tool1.name = "read_file" + mock_mcp_tool2 = Mock() + mock_mcp_tool2.name = "write_file" + mock_mcp_tool3 = Mock() + mock_mcp_tool3.name = "delete_file" + + mock_mcp_tool_class.side_effect = [ + mock_mcp_tool1, + mock_mcp_tool2, + mock_mcp_tool3, + ] + + # Create toolset with filter for only read_file and write_file + connection_params = Mock() + config = api_key_auth("test-provider") + tool_filter = ["read_file", "write_file"] + + toolset = VeIdentityMcpToolset( + auth_config=config, + connection_params=connection_params, + tool_filter=tool_filter, + ) + + # Mock the credential + toolset._get_credential = AsyncMock(return_value=Mock()) + + # Create readonly context + mock_invocation_ctx = Mock() + mock_invocation_ctx.session = Mock() + mock_invocation_ctx.session.state = {} + mock_invocation_ctx.user_id = "test_user" + mock_invocation_ctx.agent = Mock() + mock_invocation_ctx.agent.name = "test_agent" + readonly_context = ReadonlyContext( + invocation_context=mock_invocation_ctx, + ) + + # Get tools + tools = await toolset.get_tools(readonly_context) + + # Verify only filtered tools are returned + assert len(tools) == 2 + assert mock_mcp_tool1 in tools + assert mock_mcp_tool2 in tools + assert mock_mcp_tool3 not in tools + + @pytest.mark.asyncio + @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") + @patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager") + @patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool") + async def test_get_tools_with_predicate_filter( + self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client + ): + """Test that get_tools correctly filters tools using predicate filter.""" + from google.adk.agents.readonly_context import ReadonlyContext + + # Setup mocks + mock_session_manager = Mock() + mock_session_manager_class.return_value = mock_session_manager + mock_session = AsyncMock() + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create mock tools returned by MCP server + mock_tool1 = Mock() + mock_tool1.name = "test_read" + mock_tool2 = Mock() + mock_tool2.name = "prod_write" + mock_tool3 = Mock() + mock_tool3.name = "test_delete" + + mock_list_tools_result = Mock() + mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3] + mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) + + # Create mock VeIdentityMcpTool instances + mock_mcp_tool1 = Mock() + mock_mcp_tool1.name = "test_read" + mock_mcp_tool2 = Mock() + mock_mcp_tool2.name = "prod_write" + mock_mcp_tool3 = Mock() + mock_mcp_tool3.name = "test_delete" + + mock_mcp_tool_class.side_effect = [ + mock_mcp_tool1, + mock_mcp_tool2, + mock_mcp_tool3, + ] + + # Create toolset with predicate filter (only tools starting with "test_") + connection_params = Mock() + config = api_key_auth("test-provider") + + def test_only_filter(tool, context): + return tool.name.startswith("test_") + + toolset = VeIdentityMcpToolset( + auth_config=config, + connection_params=connection_params, + tool_filter=test_only_filter, + ) + + # Mock the credential + toolset._get_credential = AsyncMock(return_value=Mock()) + + # Create readonly context + mock_invocation_ctx = Mock() + mock_invocation_ctx.session = Mock() + mock_invocation_ctx.session.state = {} + mock_invocation_ctx.user_id = "test_user" + mock_invocation_ctx.agent = Mock() + mock_invocation_ctx.agent.name = "test_agent" + readonly_context = ReadonlyContext( + invocation_context=mock_invocation_ctx, + ) + + # Get tools + tools = await toolset.get_tools(readonly_context) + + # Verify only tools matching predicate are returned + assert len(tools) == 2 + assert mock_mcp_tool1 in tools + assert mock_mcp_tool3 in tools + assert mock_mcp_tool2 not in tools + + +class TestVeIdentityMcpToolsetGetToolsWithPrefix: + """Tests for VeIdentityMcpToolset.get_tools_with_prefix method.""" + + @pytest.mark.asyncio + @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") + @patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager") + @patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool") + async def test_get_tools_with_prefix_applies_correctly( + self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client + ): + """Test that get_tools_with_prefix correctly adds prefix to tool names.""" + from google.adk.agents.readonly_context import ReadonlyContext + + # Setup mocks + mock_session_manager = Mock() + mock_session_manager_class.return_value = mock_session_manager + mock_session = AsyncMock() + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create mock tools returned by MCP server + mock_tool1 = Mock() + mock_tool1.name = "read_file" + mock_tool2 = Mock() + mock_tool2.name = "write_file" + + mock_list_tools_result = Mock() + mock_list_tools_result.tools = [mock_tool1, mock_tool2] + mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) + + # Create mock VeIdentityMcpTool instances with proper name attributes + mock_mcp_tool1 = Mock() + mock_mcp_tool1.name = "read_file" + mock_mcp_tool2 = Mock() + mock_mcp_tool2.name = "write_file" + + mock_mcp_tool_class.side_effect = [mock_mcp_tool1, mock_mcp_tool2] + + # Create toolset with prefix + connection_params = Mock() + config = api_key_auth("test-provider") + prefix = "github_" + + toolset = VeIdentityMcpToolset( + auth_config=config, + connection_params=connection_params, + tool_name_prefix=prefix, + ) + + # Mock the credential + toolset._get_credential = AsyncMock(return_value=Mock()) + + # Create readonly context + mock_invocation_ctx = Mock() + mock_invocation_ctx.session = Mock() + mock_invocation_ctx.session.state = {} + mock_invocation_ctx.user_id = "test_user" + mock_invocation_ctx.agent = Mock() + mock_invocation_ctx.agent.name = "test_agent" + readonly_context = ReadonlyContext( + invocation_context=mock_invocation_ctx, + ) + + # Get tools with prefix via parent class method + tools = await toolset.get_tools_with_prefix(readonly_context) + + # Verify tools have prefix applied (parent class adds underscore separator) + assert len(tools) == 2 + assert tools[0].name == "github__read_file" + assert tools[1].name == "github__write_file" + + @pytest.mark.asyncio + @patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient") + @patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager") + @patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool") + async def test_get_tools_with_prefix_and_filter_combined( + self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client + ): + """Test that get_tools_with_prefix works correctly with tool_filter combined.""" + from google.adk.agents.readonly_context import ReadonlyContext + + # Setup mocks + mock_session_manager = Mock() + mock_session_manager_class.return_value = mock_session_manager + mock_session = AsyncMock() + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create mock tools returned by MCP server + mock_tool1 = Mock() + mock_tool1.name = "read_file" + mock_tool2 = Mock() + mock_tool2.name = "write_file" + mock_tool3 = Mock() + mock_tool3.name = "delete_file" + + mock_list_tools_result = Mock() + mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3] + mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) + + # Create mock VeIdentityMcpTool instances + mock_mcp_tool1 = Mock() + mock_mcp_tool1.name = "read_file" + mock_mcp_tool2 = Mock() + mock_mcp_tool2.name = "write_file" + mock_mcp_tool3 = Mock() + mock_mcp_tool3.name = "delete_file" + + mock_mcp_tool_class.side_effect = [ + mock_mcp_tool1, + mock_mcp_tool2, + mock_mcp_tool3, + ] + + # Create toolset with both prefix and filter + connection_params = Mock() + config = api_key_auth("test-provider") + prefix = "github_" + tool_filter = ["read_file", "write_file"] + + toolset = VeIdentityMcpToolset( + auth_config=config, + connection_params=connection_params, + tool_name_prefix=prefix, + tool_filter=tool_filter, + ) + + # Mock the credential + toolset._get_credential = AsyncMock(return_value=Mock()) + + # Create readonly context + mock_invocation_ctx = Mock() + mock_invocation_ctx.session = Mock() + mock_invocation_ctx.session.state = {} + mock_invocation_ctx.user_id = "test_user" + mock_invocation_ctx.agent = Mock() + mock_invocation_ctx.agent.name = "test_agent" + readonly_context = ReadonlyContext( + invocation_context=mock_invocation_ctx, + ) + + # Get tools with prefix via parent class method + tools = await toolset.get_tools_with_prefix(readonly_context) + + # Verify only filtered tools have prefix applied (parent class adds underscore separator) + assert len(tools) == 2 + tool_names = [t.name for t in tools] + assert "github__read_file" in tool_names + assert "github__write_file" in tool_names + assert "github__delete_file" not in tool_names diff --git a/veadk/integrations/ve_identity/mcp_toolset.py b/veadk/integrations/ve_identity/mcp_toolset.py index 95ea643f..9d562916 100644 --- a/veadk/integrations/ve_identity/mcp_toolset.py +++ b/veadk/integrations/ve_identity/mcp_toolset.py @@ -139,16 +139,16 @@ def __init__( if not connection_params: raise ValueError("Missing connection params in VeIdentityMcpToolset.") - # Initialize mixins first + # Initialize mixins and parent class with all parameters super().__init__( auth_config=auth_config, + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, ) # Store Identity specific configuration self._auth_config = auth_config self._connection_params = connection_params - self._tool_filter = tool_filter - self._tool_name_prefix = tool_name_prefix self._errlog = errlog # Create MCP session manager @@ -196,31 +196,10 @@ async def get_tools( auth_config=self._auth_config, ) - # Apply tool name prefix if specified - if self._tool_name_prefix: - mcp_tool._name = f"{self._tool_name_prefix}{mcp_tool.name}" - if self._is_tool_selected(mcp_tool, readonly_context): tools.append(mcp_tool) return tools - def _is_tool_selected( - self, tool: BaseTool, readonly_context: Optional[ReadonlyContext] - ) -> bool: - """Check if a tool should be included based on filters and context.""" - # Apply tool filter if specified - if self._tool_filter is not None: - if callable(self._tool_filter): - # ToolPredicate function - if not self._tool_filter(tool, readonly_context): - return False - elif isinstance(self._tool_filter, list): - # List of tool names - if tool.name not in self._tool_filter: - return False - - return True - async def close(self) -> None: """Performs cleanup and releases resources held by the toolset.