Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 312 additions & 4 deletions tests/test_ve_identity_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def test_init_with_api_key_auth(self, mock_identity_client):

assert toolset._auth_config == config
assert toolset._connection_params == connection_params
assert toolset._tool_filter is None
assert toolset._tool_name_prefix is None
assert toolset.tool_filter is None
assert toolset.tool_name_prefix is None

@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
def test_init_with_oauth2_auth(self, mock_identity_client):
Expand All @@ -64,7 +64,7 @@ def test_init_with_tool_filter_list(self, mock_identity_client):
tool_filter=tool_filter,
)

assert toolset._tool_filter == tool_filter
assert toolset.tool_filter == tool_filter

@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
def test_init_with_tool_name_prefix(self, mock_identity_client):
Expand All @@ -79,7 +79,7 @@ def test_init_with_tool_name_prefix(self, mock_identity_client):
tool_name_prefix=prefix,
)

assert toolset._tool_name_prefix == prefix
assert toolset.tool_name_prefix == prefix

def test_init_with_none_connection_params(self):
"""Test that initialization fails with None connection_params."""
Expand Down Expand Up @@ -208,3 +208,311 @@ async def test_close_handles_exception(self, mock_identity_client):

# Should not raise, just log the error
await toolset.close()


class TestVeIdentityMcpToolsetGetToolsWithFilter:
"""Tests for VeIdentityMcpToolset.get_tools with tool_filter applied."""

@pytest.mark.asyncio
@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
@patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager")
@patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool")
async def test_get_tools_with_list_filter(
self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client
):
"""Test that get_tools correctly filters tools using list filter."""
from google.adk.agents.readonly_context import ReadonlyContext

# Setup mocks
mock_session_manager = Mock()
mock_session_manager_class.return_value = mock_session_manager
mock_session = AsyncMock()
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Create mock tools returned by MCP server
mock_tool1 = Mock()
mock_tool1.name = "read_file"
mock_tool2 = Mock()
mock_tool2.name = "write_file"
mock_tool3 = Mock()
mock_tool3.name = "delete_file"

mock_list_tools_result = Mock()
mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3]
mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result)

# Create mock VeIdentityMcpTool instances
mock_mcp_tool1 = Mock()
mock_mcp_tool1.name = "read_file"
mock_mcp_tool2 = Mock()
mock_mcp_tool2.name = "write_file"
mock_mcp_tool3 = Mock()
mock_mcp_tool3.name = "delete_file"

mock_mcp_tool_class.side_effect = [
mock_mcp_tool1,
mock_mcp_tool2,
mock_mcp_tool3,
]

# Create toolset with filter for only read_file and write_file
connection_params = Mock()
config = api_key_auth("test-provider")
tool_filter = ["read_file", "write_file"]

toolset = VeIdentityMcpToolset(
auth_config=config,
connection_params=connection_params,
tool_filter=tool_filter,
)

# Mock the credential
toolset._get_credential = AsyncMock(return_value=Mock())

# Create readonly context
mock_invocation_ctx = Mock()
mock_invocation_ctx.session = Mock()
mock_invocation_ctx.session.state = {}
mock_invocation_ctx.user_id = "test_user"
mock_invocation_ctx.agent = Mock()
mock_invocation_ctx.agent.name = "test_agent"
readonly_context = ReadonlyContext(
invocation_context=mock_invocation_ctx,
)

# Get tools
tools = await toolset.get_tools(readonly_context)

# Verify only filtered tools are returned
assert len(tools) == 2
assert mock_mcp_tool1 in tools
assert mock_mcp_tool2 in tools
assert mock_mcp_tool3 not in tools

@pytest.mark.asyncio
@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
@patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager")
@patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool")
async def test_get_tools_with_predicate_filter(
self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client
):
"""Test that get_tools correctly filters tools using predicate filter."""
from google.adk.agents.readonly_context import ReadonlyContext

# Setup mocks
mock_session_manager = Mock()
mock_session_manager_class.return_value = mock_session_manager
mock_session = AsyncMock()
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Create mock tools returned by MCP server
mock_tool1 = Mock()
mock_tool1.name = "test_read"
mock_tool2 = Mock()
mock_tool2.name = "prod_write"
mock_tool3 = Mock()
mock_tool3.name = "test_delete"

mock_list_tools_result = Mock()
mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3]
mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result)

# Create mock VeIdentityMcpTool instances
mock_mcp_tool1 = Mock()
mock_mcp_tool1.name = "test_read"
mock_mcp_tool2 = Mock()
mock_mcp_tool2.name = "prod_write"
mock_mcp_tool3 = Mock()
mock_mcp_tool3.name = "test_delete"

mock_mcp_tool_class.side_effect = [
mock_mcp_tool1,
mock_mcp_tool2,
mock_mcp_tool3,
]

# Create toolset with predicate filter (only tools starting with "test_")
connection_params = Mock()
config = api_key_auth("test-provider")

def test_only_filter(tool, context):
return tool.name.startswith("test_")

toolset = VeIdentityMcpToolset(
auth_config=config,
connection_params=connection_params,
tool_filter=test_only_filter,
)

# Mock the credential
toolset._get_credential = AsyncMock(return_value=Mock())

# Create readonly context
mock_invocation_ctx = Mock()
mock_invocation_ctx.session = Mock()
mock_invocation_ctx.session.state = {}
mock_invocation_ctx.user_id = "test_user"
mock_invocation_ctx.agent = Mock()
mock_invocation_ctx.agent.name = "test_agent"
readonly_context = ReadonlyContext(
invocation_context=mock_invocation_ctx,
)

# Get tools
tools = await toolset.get_tools(readonly_context)

# Verify only tools matching predicate are returned
assert len(tools) == 2
assert mock_mcp_tool1 in tools
assert mock_mcp_tool3 in tools
assert mock_mcp_tool2 not in tools


class TestVeIdentityMcpToolsetGetToolsWithPrefix:
"""Tests for VeIdentityMcpToolset.get_tools_with_prefix method."""

@pytest.mark.asyncio
@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
@patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager")
@patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool")
async def test_get_tools_with_prefix_applies_correctly(
self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client
):
"""Test that get_tools_with_prefix correctly adds prefix to tool names."""
from google.adk.agents.readonly_context import ReadonlyContext

# Setup mocks
mock_session_manager = Mock()
mock_session_manager_class.return_value = mock_session_manager
mock_session = AsyncMock()
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Create mock tools returned by MCP server
mock_tool1 = Mock()
mock_tool1.name = "read_file"
mock_tool2 = Mock()
mock_tool2.name = "write_file"

mock_list_tools_result = Mock()
mock_list_tools_result.tools = [mock_tool1, mock_tool2]
mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result)

# Create mock VeIdentityMcpTool instances with proper name attributes
mock_mcp_tool1 = Mock()
mock_mcp_tool1.name = "read_file"
mock_mcp_tool2 = Mock()
mock_mcp_tool2.name = "write_file"

mock_mcp_tool_class.side_effect = [mock_mcp_tool1, mock_mcp_tool2]

# Create toolset with prefix
connection_params = Mock()
config = api_key_auth("test-provider")
prefix = "github_"

toolset = VeIdentityMcpToolset(
auth_config=config,
connection_params=connection_params,
tool_name_prefix=prefix,
)

# Mock the credential
toolset._get_credential = AsyncMock(return_value=Mock())

# Create readonly context
mock_invocation_ctx = Mock()
mock_invocation_ctx.session = Mock()
mock_invocation_ctx.session.state = {}
mock_invocation_ctx.user_id = "test_user"
mock_invocation_ctx.agent = Mock()
mock_invocation_ctx.agent.name = "test_agent"
readonly_context = ReadonlyContext(
invocation_context=mock_invocation_ctx,
)

# Get tools with prefix via parent class method
tools = await toolset.get_tools_with_prefix(readonly_context)

# Verify tools have prefix applied (parent class adds underscore separator)
assert len(tools) == 2
assert tools[0].name == "github__read_file"
assert tools[1].name == "github__write_file"

@pytest.mark.asyncio
@patch("veadk.integrations.ve_identity.auth_mixins.IdentityClient")
@patch("veadk.integrations.ve_identity.mcp_toolset.MCPSessionManager")
@patch("veadk.integrations.ve_identity.mcp_toolset.VeIdentityMcpTool")
async def test_get_tools_with_prefix_and_filter_combined(
self, mock_mcp_tool_class, mock_session_manager_class, mock_identity_client
):
"""Test that get_tools_with_prefix works correctly with tool_filter combined."""
from google.adk.agents.readonly_context import ReadonlyContext

# Setup mocks
mock_session_manager = Mock()
mock_session_manager_class.return_value = mock_session_manager
mock_session = AsyncMock()
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Create mock tools returned by MCP server
mock_tool1 = Mock()
mock_tool1.name = "read_file"
mock_tool2 = Mock()
mock_tool2.name = "write_file"
mock_tool3 = Mock()
mock_tool3.name = "delete_file"

mock_list_tools_result = Mock()
mock_list_tools_result.tools = [mock_tool1, mock_tool2, mock_tool3]
mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result)

# Create mock VeIdentityMcpTool instances
mock_mcp_tool1 = Mock()
mock_mcp_tool1.name = "read_file"
mock_mcp_tool2 = Mock()
mock_mcp_tool2.name = "write_file"
mock_mcp_tool3 = Mock()
mock_mcp_tool3.name = "delete_file"

mock_mcp_tool_class.side_effect = [
mock_mcp_tool1,
mock_mcp_tool2,
mock_mcp_tool3,
]

# Create toolset with both prefix and filter
connection_params = Mock()
config = api_key_auth("test-provider")
prefix = "github_"
tool_filter = ["read_file", "write_file"]

toolset = VeIdentityMcpToolset(
auth_config=config,
connection_params=connection_params,
tool_name_prefix=prefix,
tool_filter=tool_filter,
)

# Mock the credential
toolset._get_credential = AsyncMock(return_value=Mock())

# Create readonly context
mock_invocation_ctx = Mock()
mock_invocation_ctx.session = Mock()
mock_invocation_ctx.session.state = {}
mock_invocation_ctx.user_id = "test_user"
mock_invocation_ctx.agent = Mock()
mock_invocation_ctx.agent.name = "test_agent"
readonly_context = ReadonlyContext(
invocation_context=mock_invocation_ctx,
)

# Get tools with prefix via parent class method
tools = await toolset.get_tools_with_prefix(readonly_context)

# Verify only filtered tools have prefix applied (parent class adds underscore separator)
assert len(tools) == 2
tool_names = [t.name for t in tools]
assert "github__read_file" in tool_names
assert "github__write_file" in tool_names
assert "github__delete_file" not in tool_names
27 changes: 3 additions & 24 deletions veadk/integrations/ve_identity/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,16 @@ def __init__(
if not connection_params:
raise ValueError("Missing connection params in VeIdentityMcpToolset.")

# Initialize mixins first
# Initialize mixins and parent class with all parameters
super().__init__(
auth_config=auth_config,
tool_filter=tool_filter,
tool_name_prefix=tool_name_prefix,
)

# Store Identity specific configuration
self._auth_config = auth_config
self._connection_params = connection_params
self._tool_filter = tool_filter
self._tool_name_prefix = tool_name_prefix
self._errlog = errlog

# Create MCP session manager
Expand Down Expand Up @@ -196,31 +196,10 @@ async def get_tools(
auth_config=self._auth_config,
)

# Apply tool name prefix if specified
if self._tool_name_prefix:
mcp_tool._name = f"{self._tool_name_prefix}{mcp_tool.name}"

if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools

def _is_tool_selected(
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
) -> bool:
"""Check if a tool should be included based on filters and context."""
# Apply tool filter if specified
if self._tool_filter is not None:
if callable(self._tool_filter):
# ToolPredicate function
if not self._tool_filter(tool, readonly_context):
return False
elif isinstance(self._tool_filter, list):
# List of tool names
if tool.name not in self._tool_filter:
return False

return True

async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.

Expand Down