Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions dr_agent/mcp_backend/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Annotated, List, Optional

import aiohttp
Expand Down Expand Up @@ -33,6 +34,36 @@

dotenv.load_dotenv()


def _url_to_snippet_id(url: str) -> str:
"""Generate a deterministic short snippet ID from a URL using UUID5."""
return str(uuid.uuid5(uuid.NAMESPACE_URL, url))[:8]


def _assign_snippet_ids(items: list, url_fn) -> list:
    """Annotate each result dict in *items* with a unique ``snippet_id``.

    Args:
        items: List of result dicts to annotate in place.
        url_fn: Callable that extracts the URL string from an item.

    IDs are derived deterministically from the item's URL; items with no
    URL fall back to a random 8-character ID. When several items map to
    the same base ID (e.g. multiple snippets from one paper), later
    duplicates get a ``-<counter>`` suffix so IDs stay unique in the list.

    Returns the same list, mutated.
    """
    occurrences: dict[str, int] = {}
    for entry in items:
        source_url = url_fn(entry) or ""
        if source_url:
            identifier = _url_to_snippet_id(source_url)
        else:
            # No URL available — fall back to a random short ID.
            identifier = str(uuid.uuid4())[:8]

        n_seen = occurrences.setdefault(identifier, 0)
        occurrences[identifier] = n_seen + 1
        entry["snippet_id"] = identifier if n_seen == 0 else f"{identifier}-{n_seen}"
    return items


mcp = FastMCP(
"RL-RAG MCP",
include_tags=os.environ.get("MCP_INCLUDE_TAGS", "search,browse,rerank").split(","),
Expand Down Expand Up @@ -80,6 +111,9 @@ def semantic_scholar_search(
limit=min(limit, 100), # Ensure limit doesn't exceed API maximum
)

if "data" in results and results["data"]:
_assign_snippet_ids(results["data"], lambda item: item.get("url"))

return results


Expand Down Expand Up @@ -128,6 +162,12 @@ def semantic_scholar_snippet_search(
limit=limit,
)

if "data" in results and results["data"]:
_assign_snippet_ids(
results["data"],
lambda item: item.get("paper", {}).get("url"),
)

return results


Expand Down Expand Up @@ -166,6 +206,9 @@ def pubmed_search(
offset=offset,
)

if "data" in results and results["data"]:
_assign_snippet_ids(results["data"], lambda item: item.get("url"))

return results


Expand Down Expand Up @@ -254,6 +297,11 @@ def massive_serve_search(
for result in parsed_results
]

# massive_serve results don't have URLs; generate IDs from doc_id
for item in response["data"]:
doc_id_str = str(item.get("doc_id", ""))
item["snippet_id"] = _url_to_snippet_id(doc_id_str)

return response


Expand Down Expand Up @@ -281,6 +329,9 @@ def serper_google_webpage_search(
query=query, num_results=num_results, search_type="search", gl=gl, hl=hl
)

if "organic" in results and results["organic"]:
_assign_snippet_ids(results["organic"], lambda item: item.get("link"))

return results


Expand Down Expand Up @@ -310,6 +361,7 @@ def serper_fetch_webpage_content(

return {
**result,
"snippet_id": _url_to_snippet_id(webpage_url),
"success": True,
}
except Exception as e:
Expand All @@ -318,6 +370,7 @@ def serper_fetch_webpage_content(
"markdown": "",
"metadata": {},
"url": webpage_url,
"snippet_id": _url_to_snippet_id(webpage_url),
"success": False,
"error": str(e),
}
Expand Down Expand Up @@ -347,6 +400,8 @@ def jina_fetch_webpage_content(
- error: Error message if fetch failed
"""
result = fetch_webpage_content_jina(url=webpage_url, timeout=timeout)
if isinstance(result, dict):
result["snippet_id"] = _url_to_snippet_id(webpage_url)
return result


Expand All @@ -367,6 +422,9 @@ def serper_google_scholar_search(
num_results=num_results,
)

if "organic" in results and results["organic"]:
_assign_snippet_ids(results["organic"], lambda item: item.get("link"))

return results


Expand Down Expand Up @@ -410,7 +468,9 @@ async def crawl4ai_fetch_webpage_content(
timeout_ms=timeout_ms,
include_html=include_html,
)
return result
result_dict = result.model_dump() if hasattr(result, "model_dump") else dict(result)
result_dict["snippet_id"] = _url_to_snippet_id(url)
return result_dict


@mcp.tool(tags={"browse", "necessary"})
Expand Down Expand Up @@ -463,7 +523,9 @@ async def crawl4ai_docker_fetch_webpage_content(
use_pruning=use_pruning,
timeout_ms=timeout_ms,
)
return result
result_dict = result.model_dump() if hasattr(result, "model_dump") else dict(result)
result_dict["snippet_id"] = _url_to_snippet_id(url)
return result_dict


@mcp.tool(tags={"browse"})
Expand Down Expand Up @@ -491,7 +553,7 @@ def webthinker_fetch_webpage_content(
keep_links=keep_links,
)

return {"url": url, "text": text}
return {"url": url, "text": text, "snippet_id": _url_to_snippet_id(url)}


@mcp.tool(tags={"browse"})
Expand Down Expand Up @@ -538,7 +600,7 @@ async def webthinker_fetch_webpage_content_async(
keep_links=keep_links,
)

return {"url": url, "text": text}
return {"url": url, "text": text, "snippet_id": _url_to_snippet_id(url)}


if __name__ == "__main__":
Expand Down