From 83092ce01307b9f25f8e1beb44c86093eb863bd5 Mon Sep 17 00:00:00 2001 From: skolton Date: Wed, 29 Apr 2026 09:51:31 +0200 Subject: [PATCH] Disallow running Agents using system user --- .basedpyright/baseline.json | 16 ------ splunklib/ai/agent.py | 56 ++++++++++++++++++-- splunklib/ai/tools.py | 35 +++--------- tests/integration/ai/test_agent_mcp_tools.py | 48 ++++++++++++++--- tests/unit/ai/test_security.py | 35 ++++++++++++ 5 files changed, 134 insertions(+), 56 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 6b319a36..4fe754e9 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -141,22 +141,6 @@ } ], "./splunklib/ai/tools.py": [ - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 48, - "endColumn": 56, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { diff --git a/splunklib/ai/agent.py b/splunklib/ai/agent.py index f5283f72..64db7923 100644 --- a/splunklib/ai/agent.py +++ b/splunklib/ai/agent.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import asyncio import os from collections.abc import AsyncGenerator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager @@ -46,6 +47,7 @@ _testing_app_id: str | None = None DEFAULT_TOOL_SETTINGS = ToolSettings(local=False, remote=None) +_SPLUNK_SYSTEM_USER = "splunk-system-user" @final @@ -181,9 +183,14 @@ async def _start_agent(self) -> AsyncGenerator[Self]: "internal error: _impl was not set to None after agent invocation" ) + splunk_username = await asyncio.to_thread( + lambda: _get_splunk_username(self._service) + ) + _validate_agent_privileges(splunk_username) + self.logger.debug(f"Creating agent {self.name=}; {self.trace_id=}") - self._tools = await self._load_tools(stack) + self._tools = await self._load_tools(stack, splunk_username) backend = get_backend() self._impl = await backend.create_agent(self) @@ -194,7 +201,9 @@ async def _start_agent(self) -> AsyncGenerator[Self]: self._impl = None - async def _load_tools(self, stack: AsyncExitStack) -> list[Tool]: + async def _load_tools( + self, stack: AsyncExitStack, splunk_username: str + ) -> list[Tool]: tools: list[Tool] = [] if not self.tool_settings.local and not self.tool_settings.remote: return tools @@ -225,7 +234,9 @@ async def _load_tools(self, stack: AsyncExitStack) -> list[Tool]: if self.tool_settings.remote: self.logger.debug("Probing MCP Server App availability") remote_session = await stack.enter_async_context( - connect_remote_mcp(self._service, app_id, self.trace_id) + connect_remote_mcp( + self._service, app_id, self.trace_id, splunk_username + ) ) if remote_session: @@ -301,6 +312,10 @@ async def invoke_with_data( ) +class PrivilegedExecutionError(Exception): + pass + + def _local_tools_path() -> tuple[str | None, str]: local_tools_path = _testing_local_tools_path app_id = _testing_app_id @@ -317,3 +332,38 @@ def _local_tools_path() -> tuple[str | None, str]: local_tools_path = None return local_tools_path, app_id + + +def _get_splunk_username(service: Service) -> str: + class Content(BaseModel): + username: str + + class Entry(BaseModel): + content: Content + + class ResponseBody(BaseModel): + entry: list[Entry] + + # Query Splunk API for the username. + res = service.get( + path_segment="authentication/current-context", + output_mode="json", + ) + + body = ResponseBody.model_validate_json(str(res.body)) # pyright: ignore[reportUnknownArgumentType] + if len(body.entry) == 0: + return "" + return body.entry[0].content.username + + +def _validate_agent_privileges(username: str) -> None: + """Enforces that the agent is not executed under a system account. + + Raises: + PrivilegedExecutionError: If the current execution context corresponds + to a disallowed system account. + """ + if username == _SPLUNK_SYSTEM_USER: + raise PrivilegedExecutionError( + f"Agent must not be executed by the system user: {_SPLUNK_SYSTEM_USER}" + ) diff --git a/splunklib/ai/tools.py b/splunklib/ai/tools.py index 5846f08e..20f4190b 100644 --- a/splunklib/ai/tools.py +++ b/splunklib/ai/tools.py @@ -247,37 +247,11 @@ def _convert_tool_result( ) -def _get_splunk_username(service: Service) -> str: - if service.username: - return service.username - - class Content(BaseModel): - username: str - - class Entry(BaseModel): - content: Content - - class ResponseBody(BaseModel): - entry: list[Entry] - - # In case service.username is unavailable, query Splunk API for the username. - # This can happen when a service is created with a token, without username/password. - res = service.get( - path_segment="authentication/current-context", - output_mode="json", - ) - - body = ResponseBody.model_validate_json(str(res.body)) - if len(body.entry) == 0: - return "" - return body.entry[0].content.username - - -def _get_mcp_token(service: Service) -> str | None: +def _get_mcp_token(splunk_username: str, service: Service) -> str | None: try: res = service.get( path_segment="mcp_token", - username=_get_splunk_username(service), + username=splunk_username, output_mode="json", ) except HTTPError as e: @@ -324,10 +298,13 @@ async def connect_remote_mcp( service: Service, app_id: str, trace_id: str, + splunk_username: str, ) -> AsyncGenerator[ClientSession | None]: management_url = f"{service.scheme}://{service.host}:{service.port}" mcp_url = f"{management_url}/services/mcp" - mcp_token = await asyncio.to_thread(lambda: _get_mcp_token(service)) + mcp_token = await asyncio.to_thread( + lambda: _get_mcp_token(splunk_username, service) + ) if mcp_token is not None: async with streamable_http_client( url=mcp_url, diff --git a/tests/integration/ai/test_agent_mcp_tools.py b/tests/integration/ai/test_agent_mcp_tools.py index 7bd4518d..3aa1d704 100644 --- a/tests/integration/ai/test_agent_mcp_tools.py +++ b/tests/integration/ai/test_agent_mcp_tools.py @@ -23,6 +23,9 @@ from starlette.routing import Mount, Route from splunklib.ai import Agent +from splunklib.ai.agent import ( + _get_splunk_username, # pyright: ignore[reportPrivateUsage] +) from splunklib.ai.engines.langchain import LOCAL_TOOL_PREFIX from splunklib.ai.messages import ( AIMessage, @@ -50,7 +53,6 @@ ) from splunklib.ai.tools import ( ToolType, - _get_splunk_username, # pyright: ignore[reportPrivateUsage] locate_app, ) from splunklib.client import connect @@ -296,6 +298,12 @@ async def mcp_token_handler(_: Request) -> Response: return JSONResponse(content={"token": AUTH_TOKEN}, status_code=200) +async def current_context_handler(_: Request) -> Response: + return JSONResponse( + content={"entry": [{"content": {"username": "admin"}}]}, status_code=200 + ) + + class TestRemoteTools(AITestCase): @patch( "splunklib.ai.agent._testing_local_tools_path", @@ -364,6 +372,11 @@ async def dispatch( routes=[ Mount("/services/mcp", app=mcp.streamable_http_app()), Route("/services/mcp_token", mcp_token_handler, methods=["GET"]), + Route( + "/services/authentication/current-context", + current_context_handler, + methods=["GET"], + ), ], lifespan=lifespan, middleware=[Middleware(MCPMiddleware)], @@ -376,7 +389,6 @@ async def dispatch( port=port, splunkToken=AUTH_TOKEN, autologin=True, - username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint ), ) @@ -427,7 +439,17 @@ async def dispatch( async def test_remote_tools_mcp_app_unavailable(self) -> None: pytest.importorskip("langchain_openai") - async with run_http_server(Starlette(routes=[])) as (host, port): + async with run_http_server( + Starlette( + routes=[ + Route( + "/services/authentication/current-context", + current_context_handler, + methods=["GET"], + ), + ] + ) + ) as (host, port): service = await asyncio.to_thread( lambda: connect( scheme="http", @@ -435,7 +457,6 @@ async def test_remote_tools_mcp_app_unavailable(self) -> None: port=port, splunkToken=AUTH_TOKEN, autologin=True, - username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint ), ) @@ -489,6 +510,11 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: routes=[ Mount("/services/mcp", app=mcp.streamable_http_app()), Route("/services/mcp_token", mcp_token_handler, methods=["GET"]), + Route( + "/services/authentication/current-context", + current_context_handler, + methods=["GET"], + ), ], lifespan=lifespan, ) @@ -500,7 +526,6 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: port=port, splunkToken=AUTH_TOKEN, autologin=True, - username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint ), ) @@ -579,6 +604,11 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: routes=[ Mount("/services/mcp", app=mcp.streamable_http_app()), Route("/services/mcp_token", mcp_token_handler, methods=["GET"]), + Route( + "/services/authentication/current-context", + current_context_handler, + methods=["GET"], + ), ], lifespan=lifespan, ) @@ -590,7 +620,6 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: port=port, splunkToken=AUTH_TOKEN, autologin=True, - username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint ), ) @@ -732,6 +761,11 @@ async def lifespan(_app: Starlette) -> AsyncGenerator[None, Any]: routes=[ Mount("/services/mcp", app=mcp.streamable_http_app()), Route("/services/mcp_token", mcp_token_handler, methods=["GET"]), + Route( + "/services/authentication/current-context", + current_context_handler, + methods=["GET"], + ), ], lifespan=lifespan, ) @@ -743,8 +777,6 @@ async def lifespan(_app: Starlette) -> AsyncGenerator[None, Any]: port=port, splunkToken=AUTH_TOKEN, autologin=True, - # To avoid mocking `authentication/current-context` endpoint - username="admin", ), ) diff --git a/tests/unit/ai/test_security.py b/tests/unit/ai/test_security.py index c2e57a07..ecb1fbd3 100644 --- a/tests/unit/ai/test_security.py +++ b/tests/unit/ai/test_security.py @@ -17,6 +17,8 @@ import pytest +from splunklib.ai import Agent, OpenAIModel +from splunklib.ai.agent import PrivilegedExecutionError from splunklib.ai.messages import AgentResponse, AIMessage, HumanMessage from splunklib.ai.middleware import ( AgentMiddlewareHandler, @@ -28,6 +30,8 @@ detect_injection, truncate_input, ) +from splunklib.client import Service +from splunklib.data import Record class TestDetectInjection(unittest.TestCase): @@ -168,3 +172,34 @@ async def handler(_request: AgentRequest) -> AgentResponse[Any]: ) await middleware.agent_middleware(request, handler) assert called + + +class TestPrivilegedExecution(unittest.IsolatedAsyncioTestCase): + @pytest.mark.asyncio + async def test_agent_with_system_user(self) -> None: + model = OpenAIModel( + model="test-model", base_url="test-url", api_key="test-api-key" + ) + + def handler(url: str, _message: dict[str, Any], **_kwargs: dict[str, Any]): + assert ( + url + == "https://localhost:8089/services/authentication/current-context?output_mode=json" + ) + return Record( + { + "status": 200, + "headers": [], + "body": '{"entry": [{"content": {"username": "splunk-system-user"}}]}', + } + ) + + service = Service(token="test-token", handler=handler) + + with pytest.raises(PrivilegedExecutionError, match="splunk-system-user"): + async with Agent( + model=model, + system_prompt="Your name is stefan", + service=service, + ): + ...