diff --git a/src/utils/responses.py b/src/utils/responses.py index 0ed7477bc..1327faeb2 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -8,6 +8,7 @@ from fastapi import HTTPException from llama_stack_api import OpenAIResponseObject +from llama_stack_api.openai_responses import ApprovalFilter from llama_stack_api.openai_responses import ( OpenAIResponseContentPartRefusal as ContentPartRefusal, ) @@ -732,12 +733,20 @@ async def get_mcp_tools( continue authorization = headers.pop("Authorization", None) + + require_approval = ( + mcp_server.require_approval + if isinstance(mcp_server.require_approval, str) + else ApprovalFilter( + always=mcp_server.require_approval.always or None, + never=mcp_server.require_approval.never or None, + ) + ) tools.append( InputToolMCP( - type="mcp", server_label=mcp_server.name, server_url=mcp_server.url, - require_approval="never", + require_approval=require_approval, headers=headers or None, authorization=authorization, ) diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index f338f5401..1a0668211 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -12,6 +12,7 @@ AllowedToolsFilter, OpenAIResponseInputToolChoiceAllowedTools, ) +from llama_stack_api.openai_responses import ApprovalFilter as LlamaStackApprovalFilter from llama_stack_api.openai_responses import ( OpenAIResponseInputTool as InputTool, ) @@ -60,7 +61,7 @@ import constants from models.api.requests import QueryRequest -from models.config import ByokRag, ModelContextProtocolServer +from models.config import ApprovalFilter, ByokRag, ModelContextProtocolServer from utils.responses import ( _build_chunk_attributes, _merge_tools, @@ -405,6 +406,50 @@ async def test_get_mcp_tools_without_auth(self, mocker: MockerFixture) -> None: assert tools_no_auth[0].server_label == "fs" assert tools_no_auth[0].server_url == "http://localhost:3000" assert tools_no_auth[0].headers is None + assert all(tool.require_approval == "never" for tool in tools_no_auth) + + @pytest.mark.asyncio + async def test_get_mcp_tools_require_approval_always( + self, mocker: MockerFixture + ) -> None: + """Test get_mcp_tools passes require_approval='always' from config.""" + server = ModelContextProtocolServer( + name="strict", + url="http://localhost:3000", + provider_id="mcp", + require_approval="always", + ) + mock_config = mocker.Mock() + mock_config.mcp_servers = [server] + mocker.patch("utils.responses.configuration", mock_config) + + tools = await get_mcp_tools(token=None) + assert len(tools) == 1 + assert tools[0].require_approval == "always" + + @pytest.mark.asyncio + async def test_get_mcp_tools_require_approval_filter( + self, mocker: MockerFixture + ) -> None: + """Test get_mcp_tools translates ApprovalFilter to Llama Stack format.""" + server = ModelContextProtocolServer( + name="github", + url="http://localhost:3000", + provider_id="mcp", + require_approval=ApprovalFilter( + always=["create_issue"], + never=["list_repos"], + ), + ) + mock_config = mocker.Mock() + mock_config.mcp_servers = [server] + mocker.patch("utils.responses.configuration", mock_config) + + tools = await get_mcp_tools(token=None) + assert len(tools) == 1 + assert isinstance(tools[0].require_approval, LlamaStackApprovalFilter) + assert tools[0].require_approval.always == ["create_issue"] + assert tools[0].require_approval.never == ["list_repos"] @pytest.mark.asyncio async def test_get_mcp_tools_with_kubernetes_auth(