diff --git a/README.md b/README.md index 23518df..35ec946 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ Add to `~/.cursor/mcp.json` (or `.cursor/mcp.json` in your project): | Tool | Description | |------|-------------| -| `search` | Search a codebase with a natural-language or code query. Pass `repo` as a git URL or local path. | +| `search` | Search a codebase with a natural-language or code query. Pass `repo` as a local directory path or an https:// git URL. | | `find_related` | Given a file path and line number, return chunks semantically similar to the code at that location. | ### Sub-agent support diff --git a/src/semble/index/create.py b/src/semble/index/create.py index b89eff0..172d7e3 100644 --- a/src/semble/index/create.py +++ b/src/semble/index/create.py @@ -11,6 +11,8 @@ from semble.tokens import tokenize from semble.types import Chunk, Encoder +_MAX_FILE_BYTES = 1_000_000 # 1 MB max file size to read and index + def create_index_from_path( path: Path, @@ -38,6 +40,8 @@ def create_index_from_path( for file_path in walk_files(path, extensions, ignore): language = language_for_path(file_path) with contextlib.suppress(OSError): + if file_path.stat().st_size > _MAX_FILE_BYTES: + continue source = file_path.read_text(encoding="utf-8", errors="replace") chunk_path = file_path.relative_to(display_root) if display_root else file_path chunks.extend(chunk_source(source, str(chunk_path), language)) diff --git a/src/semble/index/index.py b/src/semble/index/index.py index fcfef80..4c4c764 100644 --- a/src/semble/index/index.py +++ b/src/semble/index/index.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import subprocess import tempfile from collections import defaultdict @@ -14,6 +15,8 @@ from semble.search import search_bm25, search_hybrid, search_semantic from semble.types import Chunk, Encoder, IndexStats, SearchMode, SearchResult +_GIT_CLONE_TIMEOUT = int(os.environ.get("SEMBLE_CLONE_TIMEOUT", 60)) + class SembleIndex: """Fast local code index with hybrid search.""" @@ -128,15 +131,19 @@ def from_git( :param ignore: Directory names to skip. Defaults to common VCS and build dirs. :param include_text_files: If True, also index non-code text files (.md, .yaml, .json, etc.). :return: An indexed SembleIndex. Chunk file paths are repo-relative (e.g. ``src/foo.py``). - :raises RuntimeError: If git is not on PATH or the clone fails. + :raises RuntimeError: If git is not on PATH, the clone fails, or times out. """ with tempfile.TemporaryDirectory() as tmp_dir: # `--` prevents `url` from being interpreted as a git option (e.g. `--upload-pack=...`). cmd = ["git", "clone", "--depth", "1", *(["--branch", ref] if ref else []), "--", url, tmp_dir] try: - result = subprocess.run(cmd, capture_output=True, text=True, stdin=subprocess.DEVNULL) + result = subprocess.run( + cmd, capture_output=True, text=True, stdin=subprocess.DEVNULL, timeout=_GIT_CLONE_TIMEOUT + ) except FileNotFoundError: raise RuntimeError("git is not installed or not on PATH") from None + except subprocess.TimeoutExpired: + raise RuntimeError(f"git clone timed out for {url!r} (limit: {_GIT_CLONE_TIMEOUT} s)") from None if result.returncode != 0: raise RuntimeError(f"git clone failed for {url!r}:\n{result.stderr.strip()}") model = model or load_model() diff --git a/src/semble/mcp.py b/src/semble/mcp.py index 54310c4..42e814d 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -2,6 +2,7 @@ import asyncio import logging +from collections import OrderedDict from pathlib import Path from typing import Annotated, Literal @@ -17,21 +18,43 @@ logger = logging.getLogger(__name__) _REPO_DESCRIPTION = ( - "Git URL (e.g. https://github.com/org/repo) or local path to index and search. " + "https:// or http:// git URL (e.g. https://github.com/org/repo) or local directory path to index and search. " "Required when no default index was configured at startup. " "The index is cached after the first call, so repeat queries are fast." ) +_CACHE_MAX_SIZE = 10 # Max number of cached indexes to keep in memory + + +async def _get_index( + repo: str | None, + default_source: str | None, + cache: _IndexCache, +) -> SembleIndex: + """Return a cached index for a repo, rejecting unsafe git transport schemes.""" + if repo is not None and _is_git_url(repo) and not repo.startswith(("https://", "http://")): + raise ValueError(f"Only https://, http://, or local directory paths are accepted as `repo`. Got: {repo!r}") + source = repo or default_source + if not source: + raise ValueError( + "No repo specified and no default index. " + "Pass an https:// or http:// git URL or local directory path as `repo`." + ) + try: + return await cache.get(source) + except Exception as exc: + raise ValueError(f"Failed to index {source!r}: {exc}") from exc + def create_server(cache: _IndexCache, default_source: str | None = None) -> FastMCP: """Build and return a configured FastMCP server backed by the given cache.""" server = FastMCP( "semble", instructions=( - "Instant code search for any local or GitHub repository. " + "Instant code search for any local or remote git repository. " "Call `search` to find relevant code; call `find_related` on a result to discover similar code elsewhere. " - "For questions about a library (e.g. a PyPI/npm package), resolve the GitHub URL from your training " - "knowledge and pass it as `repo`. " + "When working in a local project, pass the project root as `repo`. " + "For remote repos, pass an explicit https:// URL. Never guess or infer URLs. " "Prefer these tools over Grep, Glob, or Read for any question about how code works." ), ) @@ -51,16 +74,10 @@ async def search( Pass a git URL or local path as `repo` to index it on demand; indexes are cached for the session. Use this to find where something is implemented, understand a library, or locate related code. """ - source = repo or default_source - if not source: - return ( - "No repo specified and no default index. " - "Pass a git URL (https://github.com/...) or local path as `repo`." - ) try: - index = await cache.get(source) - except Exception as exc: - return f"Failed to index {source!r}: {exc}" + index = await _get_index(repo, default_source, cache) + except ValueError as exc: + return str(exc) results = index.search(query, top_k=top_k, mode=mode) if not results: return "No results found." @@ -81,16 +98,10 @@ async def find_related( Use after `search` to explore related implementations or callers. Pass file_path and line from a prior search result. """ - source = repo or default_source - if not source: - return ( - "No repo specified and no default index. " - "Pass a git URL (https://github.com/...) or local path as `repo`." - ) try: - index = await cache.get(source) - except Exception as exc: - return f"Failed to index {source!r}: {exc}" + index = await _get_index(repo, default_source, cache) + except ValueError as exc: + return str(exc) chunk = _resolve_chunk(index.chunks, file_path, line) if chunk is None: return ( @@ -124,7 +135,7 @@ class _IndexCache: def __init__(self, model: Encoder) -> None: """Initialise an empty cache with a shared embedding model.""" self._model = model - self._tasks: dict[str, asyncio.Task[SembleIndex]] = {} + self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction self._watcher_task: asyncio.Task[None] | None = None def _compute_cache_key(self, source: str, ref: str | None = None) -> str: @@ -155,7 +166,11 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex: """Return an index for the requested source, building and caching it on first access.""" cache_key = self._compute_cache_key(source, ref) - if cache_key not in self._tasks: + if cache_key in self._tasks: + self._tasks.move_to_end(cache_key) + else: + if len(self._tasks) >= _CACHE_MAX_SIZE: + self._tasks.popitem(last=False) if _is_git_url(source): self._tasks[cache_key] = asyncio.create_task( asyncio.to_thread(SembleIndex.from_git, source, ref=ref, model=self._model) diff --git a/tests/test_git.py b/tests/test_git.py index 2467502..7e25002 100644 --- a/tests/test_git.py +++ b/tests/test_git.py @@ -78,10 +78,17 @@ def test_from_path_rejects_invalid_paths( def test_from_git_raises_on_failure(mock_model: Any) -> None: - """from_git raises RuntimeError when the clone fails or git is not installed.""" + """from_git raises RuntimeError when the clone fails, git is not installed, or times out.""" with pytest.raises(RuntimeError, match="git clone failed"): SembleIndex.from_git("/nonexistent/path/that/does/not/exist", model=mock_model) with patch("semble.index.index.subprocess.run", side_effect=FileNotFoundError): with pytest.raises(RuntimeError, match="git is not installed"): SembleIndex.from_git("https://github.com/x/y", model=mock_model) + + with patch( + "semble.index.index.subprocess.run", + side_effect=subprocess.TimeoutExpired(cmd=["git"], timeout=60), + ): + with pytest.raises(RuntimeError, match="timed out"): + SembleIndex.from_git("https://github.com/x/y", model=mock_model) diff --git a/tests/test_index.py b/tests/test_index.py index c759f3d..9bae1cb 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -4,7 +4,7 @@ import pytest from semble import SembleIndex -from semble.index.create import create_index_from_path +from semble.index.create import _MAX_FILE_BYTES, create_index_from_path from semble.types import Encoder @@ -33,6 +33,13 @@ def test_index_empty_returns_zero_chunks(mock_model: Encoder, tmp_path: Path) -> create_index_from_path(tmp_path, mock_model) +def test_oversized_file_is_skipped(mock_model: Encoder, tmp_path: Path) -> None: + """Files exceeding _MAX_FILE_BYTES are silently skipped during indexing.""" + (tmp_path / "big.py").write_bytes(b"x" * (_MAX_FILE_BYTES + 1)) + with pytest.raises(ValueError): # no indexable content remains + create_index_from_path(tmp_path, mock_model) + + def test_index_language_counts(indexed_index: SembleIndex) -> None: """Language breakdown in stats includes python with at least one chunk.""" stats = indexed_index.stats diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 730bca0..f9a1c80 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -4,7 +4,7 @@ import pytest -from semble.mcp import _IndexCache, create_server, serve +from semble.mcp import _CACHE_MAX_SIZE, _IndexCache, create_server, serve from semble.types import Chunk, Encoder, SearchMode, SearchResult from semble.utils import _format_results, _is_git_url, _resolve_chunk from tests.conftest import make_chunk @@ -256,6 +256,43 @@ async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: mock_run.assert_called_once() +@pytest.mark.anyio +@pytest.mark.parametrize( + ("repo", "tool", "extra_args"), + [ + ("file:///home/user/secret", "search", {"query": "foo"}), + ("ssh://internal-host/repo", "search", {"query": "foo"}), + ("git@github.com:org/repo", "search", {"query": "foo"}), + ("file:///home/user/secret", "find_related", {"file_path": "src/foo.py", "line": 1}), + ("ssh://internal-host/repo", "find_related", {"file_path": "src/foo.py", "line": 1}), + ], + ids=["file_search", "ssh_search", "scp_search", "file_find_related", "ssh_find_related"], +) +async def test_tool_rejects_unsafe_repo( + cache: _IndexCache, repo: str, tool: str, extra_args: dict[str, object] +) -> None: + """Both tools reject unsafe git transport schemes (ssh://, file://, SCP-form) supplied as repo.""" + server = create_server(cache, default_source=None) + result = await server.call_tool(tool, {**extra_args, "repo": repo}) + assert "Only https://" in _tool_text(result) + + +@pytest.mark.anyio +async def test_index_cache_lru_eviction(cache: _IndexCache, tmp_path: Path) -> None: + """_IndexCache evicts the least-recently-used entry when the cache is full.""" + dirs = [tmp_path / str(i) for i in range(_CACHE_MAX_SIZE + 1)] + for d in dirs: + d.mkdir() + with patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()): + for d in dirs[:_CACHE_MAX_SIZE]: + await cache.get(str(d)) + first_key = str(dirs[0].resolve()) + assert first_key in cache._tasks + await cache.get(str(dirs[_CACHE_MAX_SIZE])) + assert first_key not in cache._tasks + assert len(cache._tasks) == _CACHE_MAX_SIZE + + def test_cache_evict(cache: _IndexCache, tmp_path: Path) -> None: """evict() removes an existing cache entry by resolved path.""" key = str(tmp_path.resolve()) @@ -280,4 +317,5 @@ async def fake_awatch(_path: str) -> AsyncGenerator: with patch("semble.mcp.watchfiles.awatch", fake_awatch): with patch("semble.mcp.SembleIndex.from_path", side_effect=RuntimeError("build failed")): await cache.start_watcher(str(tmp_path)) + assert cache._watcher_task is not None await cache._watcher_task