From d7643fd2ce57a6a7999f2ade2989fe9753c6d20c Mon Sep 17 00:00:00 2001 From: sciapanCA Date: Wed, 1 Apr 2026 20:39:10 +0200 Subject: [PATCH] COD-661: Add artifact_relations tool --- smoke_test.py | 2 +- src/codealive_mcp_server.py | 3 +- src/tests/test_artifact_relations.py | 346 +++++++++++++++++++++++++++ src/tests/test_fetch_artifacts.py | 131 +++++++++- src/tools/__init__.py | 3 +- src/tools/artifact_relations.py | 164 +++++++++++++ src/tools/fetch_artifacts.py | 73 +++++- 7 files changed, 714 insertions(+), 8 deletions(-) create mode 100644 src/tests/test_artifact_relations.py create mode 100644 src/tools/artifact_relations.py diff --git a/smoke_test.py b/smoke_test.py index b44647a..056a789 100644 --- a/smoke_test.py +++ b/smoke_test.py @@ -133,7 +133,7 @@ async def test_list_tools(self) -> bool: result = await self.session.list_tools() tools = result.tools - expected_tools = {"codebase_consultant", "get_data_sources", "codebase_search", "fetch_artifacts"} + expected_tools = {"codebase_consultant", "get_data_sources", "codebase_search", "fetch_artifacts", "get_artifact_relations"} actual_tools = {tool.name for tool in tools} if expected_tools == actual_tools: diff --git a/src/codealive_mcp_server.py b/src/codealive_mcp_server.py index 6f86f5d..4ac9cd4 100644 --- a/src/codealive_mcp_server.py +++ b/src/codealive_mcp_server.py @@ -26,7 +26,7 @@ # Import core components from core import codealive_lifespan, setup_debug_logging from middleware import N8NRemoveParametersMiddleware -from tools import codebase_consultant, get_data_sources, fetch_artifacts, codebase_search +from tools import codebase_consultant, get_data_sources, fetch_artifacts, codebase_search, get_artifact_relations # Initialize FastMCP server with lifespan and enhanced system instructions mcp = FastMCP( @@ -99,6 +99,7 @@ async def health_check(request: Request) -> JSONResponse: mcp.tool()(get_data_sources) mcp.tool()(codebase_search) mcp.tool()(fetch_artifacts) +mcp.tool()(get_artifact_relations) def main(): diff --git a/src/tests/test_artifact_relations.py b/src/tests/test_artifact_relations.py new file mode 100644 index 0000000..30e1a1a --- /dev/null +++ b/src/tests/test_artifact_relations.py @@ -0,0 +1,346 @@ +"""Tests for the get_artifact_relations tool.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastmcp import Context + +from tools.artifact_relations import get_artifact_relations, _build_relations_xml, PROFILE_MAP + + +class TestProfileMapping: + """Test MCP profile names map to backend enum values.""" + + def test_default_profile_is_calls_only(self): + """callsOnly is the default and maps to CallsOnly.""" + assert PROFILE_MAP["callsOnly"] == "CallsOnly" + + def test_inheritance_only_maps_correctly(self): + assert PROFILE_MAP["inheritanceOnly"] == "InheritanceOnly" + + def test_all_relevant_maps_correctly(self): + assert PROFILE_MAP["allRelevant"] == "AllRelevant" + + def test_references_only_maps_correctly(self): + assert PROFILE_MAP["referencesOnly"] == "ReferencesOnly" + + +class TestBuildRelationsXml: + """Test XML rendering of relation responses.""" + + def test_found_with_grouped_relations(self): + data = { + "sourceIdentifier": "org/repo::path::Symbol", + "profile": "CallsOnly", + "found": True, + "relations": [ + { + "relationType": "OutgoingCalls", + "totalCount": 57, + "returnedCount": 50, + "truncated": True, + "items": [ + { + "identifier": "org/repo::path::Dep", + "filePath": "src/Data/Repository.cs", + "startLine": 88, + "shortSummary": "Stores the aggregate", + } + ], + }, + { + "relationType": "IncomingCalls", + "totalCount": 3, + "returnedCount": 3, + "truncated": False, + "items": [ + { + "identifier": "org/repo::path::Caller", + "filePath": "src/Services/Worker.cs", + "startLine": 142, + } + ], + }, + ], + } + + result = _build_relations_xml(data) + + assert 'sourceIdentifier="org/repo::path::Symbol"' in result + assert 'profile="callsOnly"' in result + assert 'found="true"' in result + assert 'type="outgoing_calls"' in result + assert 'totalCount="57"' in result + assert 'returnedCount="50"' in result + assert 'truncated="true"' in result + assert 'filePath="src/Data/Repository.cs"' in result + assert 'startLine="88"' in result + assert 'shortSummary="Stores the aggregate"' in result + assert 'type="incoming_calls"' in result + assert 'truncated="false"' in result + # Incoming call has no shortSummary + assert result.count("shortSummary") == 1 + + def test_not_found_renders_self_closing(self): + data = { + "sourceIdentifier": "org/repo::path::Missing", + "profile": "CallsOnly", + "found": False, + "relations": [], + } + + result = _build_relations_xml(data) + + assert 'found="false"' in result + assert result.endswith("/>") + assert "", + "profile": "CallsOnly", + "found": True, + "relations": [ + { + "relationType": "OutgoingCalls", + "totalCount": 1, + "returnedCount": 1, + "truncated": False, + "items": [ + { + "identifier": "org/repo::path::Method", + "shortSummary": 'Returns "value" & more', + } + ], + }, + ], + } + + result = _build_relations_xml(data) + + assert "Class<T>" in result + assert "Method<T>" in result + assert "&" in result + assert """ in result + + def test_profile_mapped_back_to_mcp_name(self): + """Backend profile enum names are mapped back to MCP-friendly names.""" + for mcp_name, api_name in PROFILE_MAP.items(): + data = { + "sourceIdentifier": "id", + "profile": api_name, + "found": False, + "relations": [], + } + result = _build_relations_xml(data) + assert f'profile="{mcp_name}"' in result + + +class TestGetArtifactRelationsTool: + """Test the async tool function.""" + + @pytest.mark.asyncio + @patch("tools.artifact_relations.get_api_key_from_context") + async def test_default_profile_sends_calls_only(self, mock_get_api_key): + mock_get_api_key.return_value = "test_key" + + ctx = MagicMock(spec=Context) + ctx.info = AsyncMock() + ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.json.return_value = { + "sourceIdentifier": "org/repo::path::Symbol", + "profile": "CallsOnly", + "found": True, + "relations": [], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + mock_context = MagicMock() + mock_context.client = mock_client + mock_context.base_url = "https://app.codealive.ai" + ctx.request_context.lifespan_context = mock_context + + result = await get_artifact_relations( + ctx=ctx, + identifier="org/repo::path::Symbol", + ) + + # Verify the API was called with CallsOnly profile + call_args = mock_client.post.call_args + assert call_args[1]["json"]["profile"] == "CallsOnly" + assert 'found="true"' in result + + @pytest.mark.asyncio + @patch("tools.artifact_relations.get_api_key_from_context") + async def test_explicit_profile_maps_correctly(self, mock_get_api_key): + mock_get_api_key.return_value = "test_key" + + ctx = MagicMock(spec=Context) + ctx.info = AsyncMock() + ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.json.return_value = { + "sourceIdentifier": "id", + "profile": "InheritanceOnly", + "found": True, + "relations": [], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + mock_context = MagicMock() + mock_context.client = mock_client + mock_context.base_url = "https://app.codealive.ai" + ctx.request_context.lifespan_context = mock_context + + await get_artifact_relations( + ctx=ctx, + identifier="id", + profile="inheritanceOnly", + ) + + call_args = mock_client.post.call_args + assert call_args[1]["json"]["profile"] == "InheritanceOnly" + + @pytest.mark.asyncio + async def test_empty_identifier_returns_error(self): + ctx = MagicMock(spec=Context) + result = await get_artifact_relations(ctx=ctx, identifier="") + assert "" in result + assert "required" in result + + @pytest.mark.asyncio + async def test_unsupported_profile_returns_error(self): + ctx = MagicMock(spec=Context) + result = await get_artifact_relations( + ctx=ctx, identifier="id", profile="invalidProfile" + ) + assert "" in result + assert "Unsupported profile" in result + + @pytest.mark.asyncio + @patch("tools.artifact_relations.get_api_key_from_context") + async def test_api_error_returns_error_xml(self, mock_get_api_key): + import httpx + + mock_get_api_key.return_value = "test_key" + + ctx = MagicMock(spec=Context) + ctx.info = AsyncMock() + ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Unauthorized", request=MagicMock(), response=mock_response + ) + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + mock_context = MagicMock() + mock_context.client = mock_client + mock_context.base_url = "https://app.codealive.ai" + ctx.request_context.lifespan_context = mock_context + + result = await get_artifact_relations(ctx=ctx, identifier="id") + + assert "" in result + assert "401" in result + + @pytest.mark.asyncio + @patch("tools.artifact_relations.get_api_key_from_context") + async def test_not_found_response_renders_correctly(self, mock_get_api_key): + mock_get_api_key.return_value = "test_key" + + ctx = MagicMock(spec=Context) + ctx.info = AsyncMock() + ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.json.return_value = { + "sourceIdentifier": "org/repo::path::Missing", + "profile": "CallsOnly", + "found": False, + "relations": [], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + mock_context = MagicMock() + mock_context.client = mock_client + mock_context.base_url = "https://app.codealive.ai" + ctx.request_context.lifespan_context = mock_context + + result = await get_artifact_relations(ctx=ctx, identifier="org/repo::path::Missing") + + assert 'found="false"' in result + assert " element.""" + + def test_content_wrapped_in_element(self): + data = {"artifacts": [ + {"identifier": "repo::file.py::func", "content": "code here", "contentByteSize": 9} + ]} + result = _build_artifacts_xml(data) + assert "" in result + assert "" in result + assert "1 | code here" in result + + def test_artifact_structure_has_content_child(self): + data = {"artifacts": [ + {"identifier": "repo::f.py::fn", "content": "x = 1", "contentByteSize": 5} + ]} + result = _build_artifacts_xml(data) + assert "" in result + assert "" in result + + +class TestBuildArtifactsXmlRelations: + """Test relations rendering in _build_artifacts_xml.""" + + def test_relations_with_outgoing_and_incoming(self): + data = {"artifacts": [{ + "identifier": "repo::src/a.ts::FuncA", + "content": "code", + "contentByteSize": 4, + "relations": { + "outgoingCallsCount": 12, + "outgoingCalls": [ + {"identifier": "repo::src/b.ts::FuncB", "summary": "Validates token"}, + {"identifier": "repo::src/c.ts::FuncC", "summary": "Logs event"}, + ], + "incomingCallsCount": 3, + "incomingCalls": [ + {"identifier": "repo::src/d.ts::FuncD", "summary": "Entry point"}, + ], + } + }]} + result = _build_artifacts_xml(data) + assert "" in result + assert '' in result + assert '' in result + assert '' in result + assert '' in result + assert '' in result + assert 'identifier="repo::src/b.ts::FuncB" summary="Validates token"' in result + assert 'identifier="repo::src/d.ts::FuncD" summary="Entry point"' in result + + def test_relations_with_only_outgoing(self): + data = {"artifacts": [{ + "identifier": "repo::src/a.ts::FuncA", + "content": "code", + "contentByteSize": 4, + "relations": { + "outgoingCallsCount": 2, + "outgoingCalls": [ + {"identifier": "repo::src/b.ts::FuncB", "summary": "Does stuff"}, + ], + "incomingCallsCount": None, + "incomingCalls": None, + } + }]} + result = _build_artifacts_xml(data) + assert "" in result + assert "" not in result + assert "" in result + + def test_relations_absent_omits_relations_element(self): + data = {"artifacts": [{ + "identifier": "repo::src/a.ts", + "content": "code", + "contentByteSize": 4, + }]} + result = _build_artifacts_xml(data) + assert "" not in result + + def test_relations_call_without_summary_omits_summary_attr(self): + data = {"artifacts": [{ + "identifier": "repo::src/a.ts::FuncA", + "content": "code", + "contentByteSize": 4, + "relations": { + "outgoingCallsCount": 1, + "outgoingCalls": [ + {"identifier": "repo::src/b.ts::FuncB", "summary": None}, + ], + "incomingCallsCount": None, + "incomingCalls": None, + } + }]} + result = _build_artifacts_xml(data) + assert 'identifier="repo::src/b.ts::FuncB"/>' in result + assert 'summary' not in result.split('FuncB')[1].split('/>')[0] + + def test_relations_escapes_xml_in_summary(self): + data = {"artifacts": [{ + "identifier": "repo::src/a.ts::FuncA", + "content": "code", + "contentByteSize": 4, + "relations": { + "outgoingCallsCount": 1, + "outgoingCalls": [ + {"identifier": "repo::src/b.ts::FuncB", "summary": 'Checks if x < 10 & y > 5'}, + ], + "incomingCallsCount": None, + "incomingCalls": None, + } + }]} + result = _build_artifacts_xml(data) + assert "<" in result + assert "&" in result + + @pytest.mark.asyncio @patch('tools.fetch_artifacts.get_api_key_from_context') async def test_fetch_artifacts_returns_xml(mock_get_api_key): @@ -119,7 +247,8 @@ async def test_fetch_artifacts_returns_xml(mock_get_api_key): assert isinstance(result, str) assert "" in result assert "" in result - # Found artifact has line-numbered content + # Found artifact has line-numbered content wrapped in + assert "" in result assert "1 | def login(user, pwd):" in result assert "2 | return True" in result assert 'contentByteSize="38"' in result diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 8e604fc..c3904a1 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -1,8 +1,9 @@ """Tool implementations for CodeAlive MCP server.""" +from .artifact_relations import get_artifact_relations from .chat import codebase_consultant from .datasources import get_data_sources from .fetch_artifacts import fetch_artifacts from .search import codebase_search -__all__ = ['codebase_consultant', 'get_data_sources', 'fetch_artifacts', 'codebase_search'] \ No newline at end of file +__all__ = ['codebase_consultant', 'get_data_sources', 'fetch_artifacts', 'codebase_search', 'get_artifact_relations'] \ No newline at end of file diff --git a/src/tools/artifact_relations.py b/src/tools/artifact_relations.py new file mode 100644 index 0000000..ae07ed8 --- /dev/null +++ b/src/tools/artifact_relations.py @@ -0,0 +1,164 @@ +"""Artifact relations tool implementation.""" + +import html +from typing import Optional +from urllib.parse import urljoin + +import httpx +from fastmcp import Context + +from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response +from utils import handle_api_error + +# Map MCP profile names to backend enum values +PROFILE_MAP = { + "callsOnly": "CallsOnly", + "inheritanceOnly": "InheritanceOnly", + "allRelevant": "AllRelevant", + "referencesOnly": "ReferencesOnly", +} + +# Backend relation type to MCP-friendly snake_case +RELATION_TYPE_MAP = { + "OutgoingCalls": "outgoing_calls", + "IncomingCalls": "incoming_calls", + "Ancestors": "ancestors", + "Descendants": "descendants", + "References": "references", +} + + +async def get_artifact_relations( + ctx: Context, + identifier: str, + profile: str = "callsOnly", + max_count_per_type: int = 50, +) -> str: + """ + Retrieve relation groups for a single artifact by profile. + + Use this tool to explore an artifact's call graph, inheritance hierarchy, + or references. This is a drill-down tool — use it AFTER `codebase_search` + or `fetch_artifacts` when you need to understand how an artifact relates + to others in the codebase. + + Args: + identifier: Fully qualified artifact identifier from search or fetch results. + profile: Relation profile to expand. One of: + - "callsOnly" (default): outgoing and incoming calls + - "inheritanceOnly": ancestors and descendants + - "allRelevant": calls + inheritance (4 groups) + - "referencesOnly": symbol references + max_count_per_type: Maximum related artifacts per relation type (1–1000, default 50). + + Returns: + XML with grouped relations: + + + + + + + + + + When the artifact is not found or inaccessible: + + """ + if not identifier: + return "Artifact identifier is required." + + api_profile = PROFILE_MAP.get(profile) + if api_profile is None: + supported = ", ".join(PROFILE_MAP.keys()) + return f'Unsupported profile "{profile}". Use one of: {supported}' + + context: CodeAliveContext = ctx.request_context.lifespan_context + + try: + api_key = get_api_key_from_context(ctx) + headers = {"Authorization": f"Bearer {api_key}"} + + body = { + "identifier": identifier, + "profile": api_profile, + "maxCountPerType": max_count_per_type, + } + + await ctx.info(f"Fetching {profile} relations for artifact") + + full_url = urljoin(context.base_url, "/api/search/artifact-relations") + request_id = log_api_request("POST", full_url, headers, body=body) + + response = await context.client.post( + "/api/search/artifact-relations", json=body, headers=headers + ) + + log_api_response(response, request_id) + response.raise_for_status() + + return _build_relations_xml(response.json()) + + except (httpx.HTTPStatusError, Exception) as e: + error_msg = await handle_api_error(ctx, e, "get artifact relations") + return f"{error_msg}" + + +def _build_relations_xml(data: dict) -> str: + """Build XML representation of artifact relations response.""" + source_id = html.escape(data.get("sourceIdentifier", "")) + profile = html.escape(data.get("profile", "")) + found = data.get("found", False) + + # Map profile back to MCP-friendly name + mcp_profile = profile + for mcp_name, api_name in PROFILE_MAP.items(): + if api_name == profile: + mcp_profile = mcp_name + break + + if not found: + return f'' + + relations = data.get("relations", []) + if not relations: + return f'' + + xml_parts = [ + f'' + ] + + for group in relations: + relation_type = group.get("relationType", "") + mcp_type = RELATION_TYPE_MAP.get(relation_type, relation_type.lower()) + total_count = group.get("totalCount", 0) + returned_count = group.get("returnedCount", 0) + truncated = str(group.get("truncated", False)).lower() + + xml_parts.append( + f' ' + ) + + for item in group.get("items", []): + attrs = [f'identifier="{html.escape(item.get("identifier", ""))}"'] + + file_path = item.get("filePath") + if file_path is not None: + attrs.append(f'filePath="{html.escape(file_path)}"') + + start_line = item.get("startLine") + if start_line is not None: + attrs.append(f'startLine="{start_line}"') + + short_summary = item.get("shortSummary") + if short_summary is not None: + attrs.append(f'shortSummary="{html.escape(short_summary)}"') + + xml_parts.append(f' ') + + xml_parts.append(' ') + + xml_parts.append('') + return "\n".join(xml_parts) diff --git a/src/tools/fetch_artifacts.py b/src/tools/fetch_artifacts.py index 99d1253..18bddad 100644 --- a/src/tools/fetch_artifacts.py +++ b/src/tools/fetch_artifacts.py @@ -35,11 +35,27 @@ async def fetch_artifacts( Chunk: "my-org/backend::README.md::0042" Returns: - XML with full content for each found artifact: + XML with full content and call relations for each found artifact: - content here + + numbered source code + + + + + + + + + + Only artifacts with content are included in the response. + The `` element shows the artifact's call graph: + - **outgoing_calls**: functions this artifact calls (its dependencies) + - **incoming_calls**: functions that call this artifact (its blast radius) + Each shows up to 3 related artifacts with summaries. The `count` attribute + gives the total. Relations are omitted for non-function artifacts. Note: - Maximum 20 identifiers per request to avoid excessive payloads. @@ -107,7 +123,8 @@ def _add_line_numbers(content: str, start_line: int = 1) -> str: def _build_artifacts_xml(data: dict) -> str: """Build XML representation of fetched artifacts. - Backend DTO: Identifier (string), Content (string?), ContentByteSize (long?). + Backend DTO: Identifier (string), Content (string?), ContentByteSize (long?), + Relations (object?). Content is null when artifact is not found or has no content. Only artifacts with content are included in output. """ @@ -125,10 +142,58 @@ def _build_artifacts_xml(data: dict) -> str: attrs = [f'identifier="{identifier}"'] if content_byte_size is not None: attrs.append(f'contentByteSize="{content_byte_size}"') + start_line = artifact.get("startLine") or 1 numbered_content = _add_line_numbers(content, start_line) escaped_content = html.escape(numbered_content) - xml_parts.append(f' {escaped_content}') + + xml_parts.append(f' ') + xml_parts.append(f' {escaped_content}') + + relations = artifact.get("relations") + if relations is not None: + relations_xml = _build_relations_xml(relations) + if relations_xml: + xml_parts.append(relations_xml) + + xml_parts.append(' ') xml_parts.append("") return "\n".join(xml_parts) + + +def _build_relations_xml(relations: dict) -> str | None: + """Build XML for artifact call relations. + + Returns None if no relation types are present. + """ + parts = [] + + for rel_type in ("outgoingCalls", "incomingCalls"): + tag = "outgoing_calls" if rel_type == "outgoingCalls" else "incoming_calls" + count = relations.get(f"{rel_type}Count") + calls = relations.get(rel_type) + + if count is None: + continue + + call_elements = [] + if calls: + for call in calls: + call_id = html.escape(call.get("identifier", "")) + summary = call.get("summary") + if summary is not None: + call_elements.append( + f' ' + ) + else: + call_elements.append(f' ') + + parts.append(f' <{tag} count="{count}">') + parts.extend(call_elements) + parts.append(f' ') + + if not parts: + return None + + return " \n" + "\n".join(parts) + "\n "