diff --git a/cli/__init__.py b/cli/__init__.py deleted file mode 100644 index 07e1aac..0000000 --- a/cli/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""AsyncReview CLI - Review GitHub PRs and Issues from the command line.""" - -__version__ = "0.1.0" diff --git a/cli/github_fetcher.py b/cli/github_fetcher.py deleted file mode 100644 index a9a9555..0000000 --- a/cli/github_fetcher.py +++ /dev/null @@ -1,316 +0,0 @@ -"""GitHub URL parsing and content fetching for PR code review.""" - -import re -from typing import Literal - -import httpx - -# Import config from cr package -from cr.config import GITHUB_TOKEN, GITHUB_API_BASE - - -UrlType = Literal["pr", "issue"] - - -def parse_github_url(url: str) -> tuple[str, str, int, UrlType]: - """Parse a GitHub URL into (owner, repo, number, type). - - Args: - url: GitHub Issue or PR URL - - Returns: - Tuple of (owner, repo, number, type) - - Raises: - ValueError: If URL format is invalid - - Examples: - >>> parse_github_url("https://github.com/vercel-labs/json-render/pull/35") - ('vercel-labs', 'json-render', 35, 'pr') - >>> parse_github_url("https://github.com/AsyncFuncAI/AsyncReview/issues/1") - ('AsyncFuncAI', 'AsyncReview', 1, 'issue') - """ - # Try PR URL first (primary use case) - pr_pattern = r"github\.com/([^/]+)/([^/]+)/pull/(\d+)" - pr_match = re.search(pr_pattern, url) - if pr_match: - return pr_match.group(1), pr_match.group(2), int(pr_match.group(3)), "pr" - - # Try Issue URL - issue_pattern = r"github\.com/([^/]+)/([^/]+)/issues/(\d+)" - issue_match = re.search(issue_pattern, url) - if issue_match: - return issue_match.group(1), issue_match.group(2), int(issue_match.group(3)), "issue" - - raise ValueError( - f"Invalid GitHub URL: {url}\n" - "Expected format: https://github.com/owner/repo/pull/123 or .../issues/123" - ) - - -def _get_headers() -> dict[str, str]: - """Get HTTP headers for GitHub API requests.""" - headers = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "asyncreview-cli", - } - if GITHUB_TOKEN: - headers["Authorization"] = f"token {GITHUB_TOKEN}" - return headers - - -async def fetch_pr(owner: str, repo: str, number: int) -> dict: - """Fetch PR with full code review context. - - Returns dict with: - - metadata: title, body, author, state, etc. 
- - files: list of changed files with patches - - commits: commit history - - comments: PR discussion comments - """ - async with httpx.AsyncClient() as client: - # Fetch PR metadata - pr_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls/{number}", - headers=_get_headers(), - timeout=30.0, - ) - pr_resp.raise_for_status() - pr_data = pr_resp.json() - - # Fetch changed files with patches - files_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls/{number}/files", - headers=_get_headers(), - params={"per_page": 100}, - timeout=30.0, - ) - files_resp.raise_for_status() - files_data = files_resp.json() - - # Fetch commits - commits_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/pulls/{number}/commits", - headers=_get_headers(), - params={"per_page": 100}, - timeout=30.0, - ) - commits_list = [] - if commits_resp.status_code == 200: - commits_data = commits_resp.json() - commits_list = [ - { - "sha": c["sha"][:7], - "message": c["commit"]["message"].split("\n")[0], # First line only - "author": c["commit"]["author"]["name"], - } - for c in commits_data - ] - - # Fetch PR comments - comments_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues/{number}/comments", - headers=_get_headers(), - params={"per_page": 50}, - timeout=30.0, - ) - comments_list = [] - if comments_resp.status_code == 200: - comments_data = comments_resp.json() - comments_list = [ - { - "author": c["user"]["login"], - "body": c["body"], - } - for c in comments_data - ] - - # Build structured result - files = [ - { - "path": f["filename"], - "status": f.get("status", "modified"), - "additions": f.get("additions", 0), - "deletions": f.get("deletions", 0), - "patch": f.get("patch", ""), - } - for f in files_data - ] - - return { - "type": "pr", - "owner": owner, - "repo": repo, - "number": number, - "title": pr_data.get("title", ""), - "body": pr_data.get("body") or "", - "author": pr_data["user"]["login"], - "state": pr_data.get("state", "open"), - "base_branch": pr_data["base"]["ref"], - "head_branch": pr_data["head"]["ref"], - "files": files, - "commits": commits_list, - "comments": comments_list, - "additions": pr_data.get("additions", 0), - "deletions": pr_data.get("deletions", 0), - "changed_files_count": pr_data.get("changed_files", 0), - } - - -async def fetch_issue(owner: str, repo: str, number: int) -> dict: - """Fetch issue content and comments (secondary use case).""" - async with httpx.AsyncClient() as client: - # Fetch issue metadata - issue_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues/{number}", - headers=_get_headers(), - timeout=30.0, - ) - issue_resp.raise_for_status() - issue_data = issue_resp.json() - - # Fetch comments - comments_resp = await client.get( - f"{GITHUB_API_BASE}/repos/{owner}/{repo}/issues/{number}/comments", - headers=_get_headers(), - params={"per_page": 50}, - timeout=30.0, - ) - comments_list = [] - if comments_resp.status_code == 200: - comments_data = comments_resp.json() - comments_list = [ - { - "author": c["user"]["login"], - "body": c["body"], - } - for c in comments_data - ] - - return { - "type": "issue", - "owner": owner, - "repo": repo, - "number": number, - "title": issue_data.get("title", ""), - "body": issue_data.get("body") or "", - "author": issue_data["user"]["login"], - "state": issue_data.get("state", "open"), - "labels": [l["name"] for l in issue_data.get("labels", [])], - "comments": comments_list, - } - - -def build_pr_context(data: dict) -> str: - 
"""Build a structured text representation of a PR for RLM input. - - Optimized for code review - includes full diff patches. - """ - lines = [ - f"# Pull Request: {data['title']}", - f"", - f"**Repository:** {data['owner']}/{data['repo']}", - f"**Author:** {data['author']}", - f"**Branch:** {data['head_branch']} → {data['base_branch']}", - f"**Changes:** +{data['additions']} -{data['deletions']} across {data['changed_files_count']} files", - f"", - ] - - # PR description - if data["body"]: - lines.extend([ - "## Description", - "", - data["body"], - "", - ]) - - # Commits - if data["commits"]: - lines.extend([ - "## Commits", - "", - ]) - for commit in data["commits"]: - lines.append(f"- `{commit['sha']}` {commit['message']} ({commit['author']})") - lines.append("") - - # Changed files with patches - lines.extend([ - "## Changed Files", - "", - ]) - - for file in data["files"]: - status_icon = {"added": "+", "removed": "-", "modified": "~"}.get(file["status"], "~") - lines.append(f"### [{status_icon}] {file['path']}") - lines.append(f"*+{file['additions']} -{file['deletions']}*") - lines.append("") - - if file["patch"]: - lines.append("```diff") - lines.append(file["patch"]) - lines.append("```") - lines.append("") - - # Comments/Discussion - if data["comments"]: - lines.extend([ - "## Discussion", - "", - ]) - for comment in data["comments"]: - lines.append(f"**{comment['author']}:**") - lines.append(comment["body"]) - lines.append("") - - return "\n".join(lines) - - -def build_issue_context(data: dict) -> str: - """Build a text representation of an issue for RLM input.""" - lines = [ - f"# Issue: {data['title']}", - f"", - f"**Repository:** {data['owner']}/{data['repo']}", - f"**Author:** {data['author']}", - f"**State:** {data['state']}", - ] - - if data["labels"]: - lines.append(f"**Labels:** {', '.join(data['labels'])}") - - lines.append("") - - # Issue body - if data["body"]: - lines.extend([ - "## Description", - "", - data["body"], - "", - ]) - - # Comments - if data["comments"]: - lines.extend([ - "## Discussion", - "", - ]) - for comment in data["comments"]: - lines.append(f"**{comment['author']}:**") - lines.append(comment["body"]) - lines.append("") - - return "\n".join(lines) - - -def build_review_context(data: dict) -> str: - """Build a structured text representation for RLM input. - - Dispatches to PR or Issue context builder based on type. - """ - if data["type"] == "pr": - return build_pr_context(data) - else: - return build_issue_context(data) diff --git a/cli/main.py b/cli/main.py deleted file mode 100644 index 930cd3b..0000000 --- a/cli/main.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -"""AsyncReview CLI - Review GitHub PRs and Issues from the command line. - -Primary use case: PR code review with full diff context. - -Examples: - # Quick PR review - asyncreview review --url https://github.com/org/repo/pull/123 -q "Any security concerns?" - - # Output as markdown for docs - asyncreview review --url https://github.com/org/repo/pull/123 -q "Summarize changes" --output markdown - - # Quiet mode for scripting - asyncreview review --url https://github.com/org/repo/pull/123 -q "Review this" --quiet --output json -""" - -import argparse -import asyncio -import sys - -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown - -from . 
import __version__ -from .github_fetcher import parse_github_url -from .output_formatter import format_output -from .virtual_runner import VirtualReviewRunner - - -console = Console() - - -def print_step(step_num: int, reasoning: str, code: str): - """Print RLM step progress (when not in quiet mode).""" - console.print(f"\n[cyan]Step {step_num}[/cyan]", style="bold") - if reasoning: - # Truncate for display - display = reasoning[:200] + "..." if len(reasoning) > 200 else reasoning - console.print(f"[dim]{display}[/dim]") - - -def print_info(message: str): - """Print an info message.""" - console.print(f"[dim]{message}[/dim]") - - -def print_error(message: str): - """Print an error message.""" - console.print(f"[red]Error: {message}[/red]") - - -async def run_review( - url: str, - question: str, - output_format: str = "text", - quiet: bool = False, - model: str | None = None, -): - """Run a review on a GitHub URL.""" - # Parse URL first to validate - try: - owner, repo, number, url_type = parse_github_url(url) - except ValueError as e: - print_error(str(e)) - sys.exit(1) - - if not quiet: - type_label = "PR" if url_type == "pr" else "Issue" - print_info(f"Reviewing {type_label}: {owner}/{repo}#{number}") - print_info(f"Question: {question}") - console.print() - - # Create runner - runner = VirtualReviewRunner( - model=model, - quiet=quiet, - on_step=None if quiet else print_step, - ) - - try: - answer, sources, metadata = await runner.review(url, question) - except Exception as e: - print_error(f"Review failed: {e}") - sys.exit(1) - - # Format and print output - model_name = metadata.get("model", model or "unknown") - output = format_output( - answer=answer, - sources=sources, - model=model_name, - output_format=output_format, - metadata=metadata if output_format == "json" else None, - ) - - if quiet or output_format == "json": - # Raw output for scripting - print(output) - else: - # Rich formatted output - console.print() - if output_format == "markdown": - console.print(Panel(Markdown(output), title="Review", border_style="green")) - else: - console.print(Panel(output, title="Review", border_style="green")) - - -def main(): - """Main CLI entry point.""" - parser = argparse.ArgumentParser( - description="AsyncReview CLI - Review GitHub PRs and Issues", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - asyncreview review --url https://github.com/org/repo/pull/123 -q "Any risks?" - asyncreview review --url https://github.com/org/repo/issues/42 -q "What's needed?" 
--output markdown - asyncreview review --url -q "Review" --quiet --output json - """, - ) - - parser.add_argument( - "--version", "-V", - action="version", - version=f"asyncreview {__version__}", - ) - - subparsers = parser.add_subparsers(dest="command", help="Available commands") - - # review command - review_parser = subparsers.add_parser( - "review", - help="Review a GitHub PR or Issue", - ) - review_parser.add_argument( - "--url", "-u", - type=str, - required=True, - help="GitHub PR or Issue URL", - ) - review_parser.add_argument( - "--question", "-q", - type=str, - required=True, - help="Question to ask about the PR/Issue", - ) - review_parser.add_argument( - "--output", "-o", - type=str, - choices=["text", "markdown", "json"], - default="text", - help="Output format (default: text)", - ) - review_parser.add_argument( - "--quiet", - action="store_true", - help="Suppress progress output, print only the result", - ) - review_parser.add_argument( - "--model", "-m", - type=str, - default=None, - help="Model to use (e.g. gemini-3.0-pro-preview)", - ) - - args = parser.parse_args() - - if args.command == "review": - asyncio.run(run_review( - url=args.url, - question=args.question, - output_format=args.output, - quiet=args.quiet, - model=args.model, - )) - else: - parser.print_help() - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/cli/output_formatter.py b/cli/output_formatter.py deleted file mode 100644 index 61f3f5c..0000000 --- a/cli/output_formatter.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Output formatting for different use cases.""" - -import json -from typing import Any - - -def format_text(answer: str, sources: list[str], model: str) -> str: - """Plain text output for terminal display. - - Clean output without rich formatting. - """ - lines = [answer] - - if sources: - lines.append("") - lines.append("Sources:") - for source in sources: - lines.append(f" • {source}") - - lines.append("") - lines.append(f"— AsyncReview • {model}") - - return "\n".join(lines) - - -def format_markdown(answer: str, sources: list[str], model: str) -> str: - """Markdown output for docs/skills. - - Suitable for embedding in documentation or Claude Code skills. - """ - lines = [answer] - - if sources: - lines.append("") - lines.append("### Sources") - for source in sources: - lines.append(f"- `{source}`") - - lines.append("") - lines.append("---") - lines.append(f"*AsyncReview • {model}*") - - return "\n".join(lines) - - -def format_json( - answer: str, - sources: list[str], - model: str, - metadata: dict[str, Any] | None = None, -) -> str: - """JSON output for scripting and automation. - - Returns a JSON string with answer, sources, model, and optional metadata. - """ - result = { - "answer": answer, - "sources": sources, - "model": model, - } - - if metadata: - result["metadata"] = metadata - - return json.dumps(result, indent=2) - - -def format_output( - answer: str, - sources: list[str], - model: str, - output_format: str = "text", - metadata: dict[str, Any] | None = None, -) -> str: - """Format output based on specified format. 
- - Args: - answer: The review answer text - sources: List of source citations - model: Model name used for review - output_format: One of "text", "markdown", "json" - metadata: Optional metadata for JSON output - - Returns: - Formatted output string - """ - if output_format == "markdown": - return format_markdown(answer, sources, model) - elif output_format == "json": - return format_json(answer, sources, model, metadata) - else: - return format_text(answer, sources, model) diff --git a/cli/virtual_runner.py b/cli/virtual_runner.py deleted file mode 100644 index b9ad29e..0000000 --- a/cli/virtual_runner.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Virtual review runner - runs RLM reviews on GitHub content without local repo.""" - -import asyncio -import logging -from typing import Callable - -import dspy - -from cr.config import MAIN_MODEL, SUB_MODEL, MAX_ITERATIONS, MAX_LLM_CALLS - -from .github_fetcher import ( - parse_github_url, - fetch_pr, - fetch_issue, - build_review_context, -) - - -class VirtualReviewRunner: - """Run RLM code reviews on GitHub PRs without a local repository. - - Creates a 'virtual' codebase context from GitHub API data. - """ - - def __init__( - self, - model: str | None = None, - quiet: bool = False, - on_step: Callable[[int, str, str], None] | None = None, - ): - """Initialize the virtual runner. - - Args: - model: Override model (e.g. "gemini-3.0-pro-preview") - quiet: If True, suppress progress output - on_step: Optional callback for RLM step updates - """ - self.model = model or MAIN_MODEL - self.quiet = quiet - self.on_step = on_step - self._rlm = None - self._configured = False - - def _ensure_configured(self): - """Configure DSPy and RLM on first use.""" - if self._configured: - return - - # Configure logging based on quiet mode - if self.quiet: - logging.getLogger("dspy").setLevel(logging.WARNING) - logging.getLogger("dspy.predict.rlm").setLevel(logging.WARNING) - logging.getLogger("httpx").setLevel(logging.WARNING) - else: - logging.getLogger("dspy.predict.rlm").setLevel(logging.INFO) - - # Suppress noisy loggers - for name in ("httpx", "anthropic", "google", "urllib3"): - logging.getLogger(name).setLevel(logging.WARNING) - - # Configure DSPy with specified model - model_name = self.model - if not model_name.startswith("gemini/"): - model_name = f"gemini/{model_name}" - - dspy.configure(lm=dspy.LM(model_name)) - - # Create RLM with custom interpreter that has Deno 2.x fix - from dspy.primitives.python_interpreter import PythonInterpreter - from cr.rlm_runner import build_deno_command - - deno_command = build_deno_command() - interpreter = PythonInterpreter(deno_command=deno_command) - - self._rlm = dspy.RLM( - signature="context, question -> answer, sources", - max_iterations=MAX_ITERATIONS, - max_llm_calls=MAX_LLM_CALLS, - sub_lm=dspy.LM(f"gemini/{SUB_MODEL}" if not SUB_MODEL.startswith("gemini/") else SUB_MODEL), - verbose=not self.quiet, - interpreter=interpreter, - ) - self._configured = True - - async def review(self, url: str, question: str) -> tuple[str, list[str], dict]: - """Review a GitHub URL (PR or Issue). 
- - Args: - url: GitHub PR or Issue URL - question: Question to ask about the content - - Returns: - Tuple of (answer, sources, metadata) - """ - # Parse URL to determine type - owner, repo, number, url_type = parse_github_url(url) - - # Fetch content - if url_type == "pr": - data = await fetch_pr(owner, repo, number) - else: - data = await fetch_issue(owner, repo, number) - - # Build context - context = build_review_context(data) - - # Run RLM - self._ensure_configured() - - # Run in thread pool since RLM is sync - loop = asyncio.get_event_loop() - result = await loop.run_in_executor( - None, - lambda: self._run_rlm(context, question) - ) - - answer, sources = result - - metadata = { - "type": url_type, - "owner": owner, - "repo": repo, - "number": number, - "title": data.get("title", ""), - "model": self.model, - } - - return answer, sources, metadata - - def _run_rlm(self, context: str, question: str) -> tuple[str, list[str]]: - """Run the RLM synchronously.""" - result = self._rlm(context=context, question=question) - - answer = getattr(result, "answer", str(result)) - sources = getattr(result, "sources", []) - - if isinstance(sources, str): - sources = [s.strip() for s in sources.split(",") if s.strip()] - - return answer, sources - - async def review_pr(self, url: str, question: str) -> tuple[str, list[str], dict]: - """Review a GitHub PR with full diff context. - - This is the primary use case - builds comprehensive context including - all changed files with their patches, PR description, and commit history. - """ - return await self.review(url, question) - - async def review_issue(self, url: str, question: str) -> tuple[str, list[str], dict]: - """Review a GitHub issue (secondary use case).""" - return await self.review(url, question) diff --git a/npx/python/cli/local_repo_tools.py b/npx/python/cli/local_repo_tools.py index 6b15f31..d54e534 100644 --- a/npx/python/cli/local_repo_tools.py +++ b/npx/python/cli/local_repo_tools.py @@ -6,6 +6,7 @@ import asyncio import os +import re import subprocess from typing import Any @@ -197,14 +198,580 @@ async def search_code(self, query: str) -> list[dict[str, Any]]: return results + async def get_symbol_definition(self, symbol: str, context_file: str = "") -> str: + """Find the definition of a symbol (function or class). + + Uses grep to find 'def symbol' or 'class symbol' patterns. + Returns path + snippet or error message. + """ + if not symbol or not symbol.strip(): + return "[ERROR: empty symbol]" + + symbol = symbol.strip() + # Extract simple name from dotted paths (e.g., "dspy.adapters.DataFrame" -> "DataFrame") + if "." in symbol: + symbol = symbol.rsplit(".", 1)[-1] + + # Build grep command to find definitions. Use extended regex to allow OR. + args = ["grep", "-Ern"] + for ext in SEARCH_EXTENSIONS: + args.append(f"--include=*{ext}") + args.append("--") + # Search for "def symbol" or "class symbol" as a standalone identifier. + symbol_pattern = rf"(def|class)\s+{re.escape(symbol)}\b" + args.append(symbol_pattern) + + # Narrow search scope when context_file is provided. 
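+        # Example (assuming ".py" is in SEARCH_EXTENSIONS): for symbol "RLM"
+        # with context_file="cli/main.py", this runs roughly:
+        #   grep -Ern --include=*.py -- '(def|class)\s+RLM\b' <root>/cli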
+ if context_file and context_file.strip(): + context_dir = os.path.dirname(context_file.strip()) + scope_abs = self._resolve_path(context_dir) if context_dir else self.root_path + args.append(scope_abs if scope_abs else self.root_path) + else: + args.append(self.root_path) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + ) + except subprocess.TimeoutExpired: + return "[ERROR: search timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: symbol '{symbol}' not found]" + + # Parse first match + lines = result.stdout.splitlines() + if not lines: + return f"[ERROR: symbol '{symbol}' not found]" + + first_match = lines[0] + parts = first_match.split(":", 2) + if len(parts) >= 3: + file_path = parts[0] + line_num = parts[1] + rel_path = os.path.relpath(file_path, self.root_path) + snippet = parts[2][:200] + return f"local:{rel_path}#L{line_num}\n{snippet}" + + return f"[ERROR: could not parse definition]" + + async def find_usages(self, symbol: str, scope_path: str = ".") -> str: + """Find all usages of a symbol in the codebase. + + Uses grep to find references. Returns formatted list of matches. + """ + if not symbol or not symbol.strip(): + return "[ERROR: empty symbol]" + + symbol = symbol.strip() + # Extract simple name from dotted paths (e.g., "dspy.predict.rlm.RLM" -> "RLM") + if "." in symbol: + symbol = symbol.rsplit(".", 1)[-1] + + # Resolve scope path + scope_abs = self._resolve_path(scope_path) if scope_path != "." else self.root_path + if scope_abs is None: + scope_abs = self.root_path + + # Build grep command + args = ["grep", "-rn"] + for ext in SEARCH_EXTENSIONS: + args.append(f"--include=*{ext}") + args.append("--") + args.append(symbol) + args.append(scope_abs) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + ) + except subprocess.TimeoutExpired: + return "[ERROR: search timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: no usages found for '{symbol}']" + + # Format results + results = [] + for line in result.stdout.splitlines()[:20]: # Limit to 20 results + parts = line.split(":", 2) + if len(parts) >= 3: + file_path = parts[0] + line_num = parts[1] + rel_path = os.path.relpath(file_path, self.root_path) + snippet = parts[2][:100] + results.append(f" {rel_path}:{line_num} {snippet}") + + if not results: + return f"[ERROR: no usages found for '{symbol}']" + + return f"Found {len(results)} usages of '{symbol}':\n" + "\n".join(results) + + async def get_type_hierarchy(self, class_name: str) -> str: + """Get the type hierarchy (parent classes) for a class. + + Finds class definition and parses parent classes, including one level of parent resolution. + """ + if not class_name or not class_name.strip(): + return "[ERROR: empty class name]" + + class_name = class_name.strip() + # Extract simple name from dotted paths (e.g., "dspy.adapters.DataFrame" -> "DataFrame") + if "." 
in class_name: + class_name = class_name.rsplit(".", 1)[-1] + + # Find class definition + args = ["grep", "-rn"] + for ext in SEARCH_EXTENSIONS: + args.append(f"--include=*{ext}") + args.append("--") + args.append(f"class {class_name}") + args.append(self.root_path) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + ) + except subprocess.TimeoutExpired: + return "[ERROR: search timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: class '{class_name}' not found]" + + # Parse first match to extract parent classes + lines = result.stdout.splitlines() + if not lines: + return f"[ERROR: class '{class_name}' not found]" + + first_match = lines[0] + parts = first_match.split(":", 2) + if len(parts) < 3: + return f"[ERROR: could not parse class definition]" + + file_path = parts[0] + line_num = parts[1] + rel_path = os.path.relpath(file_path, self.root_path) + + # Read the actual file and apply regex to full content to handle multi-line defs + import re + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + file_content = f.read() + except Exception: + return f"[ERROR: could not read {rel_path}]" + + # Extract parent classes from "class X(Parent1, Parent2):" pattern + # Use re.DOTALL to handle multi-line class definitions + pattern = rf"class\s+{re.escape(class_name)}\s*\(([^)]+)\)" + match = re.search(pattern, file_content, re.DOTALL) + parents = [] + if match: + parent_str = match.group(1) + # Strip whitespace and newlines from each parent, filter out empty strings + parents = [p.strip() for p in parent_str.split(",") if p.strip()] + + hierarchy = f"Type hierarchy for '{class_name}':\n" + hierarchy += f" {class_name} extends: {', '.join(parents) if parents else '(no parents)'}\n" + hierarchy += f" Parent details:\n" + + # Resolve one level of parent classes + for parent in parents: + parent_info = await self._resolve_parent_class_local(parent) + if parent_info: + hierarchy += f" {parent_info}\n" + else: + hierarchy += f" {parent} (no parents found)\n" + + return hierarchy + + async def _resolve_parent_class_local(self, parent_name: str) -> str | None: + """Resolve one level of parent class hierarchy in local filesystem. + + Searches for the parent class definition and extracts its parents. + Returns a string like "ParentClass extends: GrandParent" or None if not found. 
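+
+        Example: if the scanned tree contains "class Predict(Module):", then
+        _resolve_parent_class_local("Predict") returns "Predict extends: Module".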
+ """ + if not parent_name or not parent_name.strip(): + return None + + parent_name = parent_name.strip() + + # Find parent class definition + args = ["grep", "-rn"] + for ext in SEARCH_EXTENSIONS: + args.append(f"--include=*{ext}") + args.append("--") + args.append(f"class {parent_name}") + args.append(self.root_path) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + ) + except (subprocess.TimeoutExpired, Exception): + return None + + if result.returncode != 0: + return None + + lines = result.stdout.splitlines() + if not lines: + return None + + first_match = lines[0] + parts = first_match.split(":", 2) + if len(parts) < 3: + return None + + file_path = parts[0] + + # Read the actual file and apply regex to full content + import re + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + file_content = f.read() + except Exception: + return None + + # Extract parent classes of the parent class + pattern = rf"class\s+{re.escape(parent_name)}\s*\(([^)]+)\)" + match = re.search(pattern, file_content, re.DOTALL) + + if not match: + return None + + parent_str = match.group(1) + grandparents = [p.strip() for p in parent_str.split(",") if p.strip()] + + return f"{parent_name} extends: {', '.join(grandparents)}" + + async def get_call_graph(self, func_name: str, depth: int = 1) -> str: + """Get the call graph for a function (functions it calls and callers). + + Limited to depth 1 for performance. + """ + if not func_name or not func_name.strip(): + return "[ERROR: empty function name]" + + func_name = func_name.strip() + # Extract simple name from dotted paths (e.g., "module.sub.my_func" -> "my_func") + if "." in func_name: + func_name = func_name.rsplit(".", 1)[-1] + + # Find function definition + args = ["grep", "-rn"] + for ext in SEARCH_EXTENSIONS: + args.append(f"--include=*{ext}") + args.append("--") + args.append(f"def {func_name}") + args.append(self.root_path) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + ) + except subprocess.TimeoutExpired: + return "[ERROR: search timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: function '{func_name}' not found]" + + lines = result.stdout.splitlines() + if not lines: + return f"[ERROR: function '{func_name}' not found]" + + first_match = lines[0] + parts = first_match.split(":", 2) + if len(parts) < 3: + return f"[ERROR: could not parse function definition]" + + file_path = parts[0] + line_num = parts[1] + rel_path = os.path.relpath(file_path, self.root_path) + + # Find callers of this function + caller_args = ["grep", "-rn"] + for ext in SEARCH_EXTENSIONS: + caller_args.append(f"--include=*{ext}") + caller_args.append("--") + caller_args.append(f"{func_name}(") + caller_args.append(self.root_path) + + callers = [] + try: + caller_result = subprocess.run( + caller_args, + capture_output=True, + text=True, + timeout=10, + ) + if caller_result.returncode == 0: + for line in caller_result.stdout.splitlines()[:10]: + parts = line.split(":", 2) + if len(parts) >= 3: + caller_file = parts[0] + caller_line = parts[1] + caller_rel = os.path.relpath(caller_file, self.root_path) + callers.append(f" {caller_rel}:{caller_line}") + except Exception: + pass # Soft fail on caller search + + result_str = f"local:{rel_path}#L{line_num}\ndef {func_name}(...)" + if callers: + result_str += f"\n\nCallers ({len(callers)}):\n" + "\n".join(callers) + else: + result_str += "\n\nNo callers 
found" + + # If depth >= 1, find outgoing calls from this function + if depth >= 1: + outgoing = await self._get_outgoing_calls(file_path, func_name) + if outgoing: + result_str += f"\n\nCalls (outgoing):\n" + "\n".join(f" {call}" for call in outgoing) + else: + result_str += "\n\nCalls (outgoing): none found" + + return result_str + + async def _get_outgoing_calls(self, file_path: str, func_name: str) -> list[str]: + """Extract outgoing calls from a function definition. + + Returns list of function names called by func_name. + """ + # Extract simple name from dotted paths + if "." in func_name: + func_name = func_name.rsplit(".", 1)[-1] + + # Python keywords to filter out + keywords = { + "if", "for", "while", "return", "print", "range", "len", "str", + "int", "list", "dict", "set", "tuple", "type", "isinstance", + "hasattr", "getattr", "setattr", "super", "enumerate", "zip", + "map", "filter", "sorted", "reversed", "any", "all", "min", + "max", "sum", "abs", "round", "open", "format", "repr", "hash", + "id", "input", "next", "iter" + } + + # Read the file + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + except Exception: + return [] + + # Extract function body + func_body = self._extract_function_body(content, func_name) + if not func_body: + return [] + + # Find all function calls using regex: word followed by ( + pattern = r"\b(\w+)\s*\(" + matches = re.findall(pattern, func_body) + + # Filter out keywords and duplicates + calls = [] + seen = set() + for match in matches: + if match not in keywords and match not in seen: + calls.append(match) + seen.add(match) + + return calls[:10] # Limit to 10 results + + def _extract_function_body(self, content: str, func_name: str) -> str: + """Extract the body of a function from file content. + + Returns the function body as a string, or empty string if not found. + """ + lines = content.splitlines() + func_start = None + + # Find the function definition line + for i, line in enumerate(lines): + if re.match(rf"def\s+{re.escape(func_name)}\s*\(", line): + func_start = i + break + + if func_start is None: + return "" + + # Get the indentation level of the function definition + def_line = lines[func_start] + def_indent = len(def_line) - len(def_line.lstrip()) + + # Extract lines until we hit a line with same or less indentation (next function/class) + body_lines = [def_line] + for i in range(func_start + 1, len(lines)): + line = lines[i] + # Skip empty lines + if not line.strip(): + body_lines.append(line) + continue + # Check indentation + line_indent = len(line) - len(line.lstrip()) + if line_indent <= def_indent and line.strip(): + # Hit next function/class at same level + break + body_lines.append(line) + + return "\n".join(body_lines) + + async def get_pr_comments(self, pr_number: int) -> str: + """Get PR comments (not available in local mode).""" + return "[Not available for local repos — use --url mode]" + + async def get_blame(self, path: str, line_range: str = "") -> str: + """Get git blame information for a file or line range. + + Uses 'git blame' subprocess. Line range format: "10,20" for lines 10-20. 
+ """ + if not path or not path.strip(): + return "[ERROR: empty path]" + + abs_path = self._resolve_path(path) + if abs_path is None: + return "[ERROR: invalid path]" + + if not os.path.exists(abs_path): + return "[ERROR: file not found]" + + # Build git blame command + args = ["git", "blame"] + if line_range: + # Parse line range "start,end" + try: + parts = line_range.split(",") + if len(parts) == 2: + start, end = parts[0].strip(), parts[1].strip() + args.append(f"-L{start},{end}") + except Exception: + pass # Ignore malformed line range + + args.append(abs_path) + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + cwd=self.root_path, + ) + except subprocess.TimeoutExpired: + return "[ERROR: blame timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: git blame failed]" + + # Limit output to first 30 lines + lines = result.stdout.splitlines()[:30] + return "\n".join(lines) if lines else "[ERROR: no blame output]" + + async def get_commit_history(self, path: str, limit: int = 5) -> str: + """Get commit history for a file. + + Uses 'git log --oneline' subprocess. + """ + if not path or not path.strip(): + return "[ERROR: empty path]" + + abs_path = self._resolve_path(path) + if abs_path is None: + return "[ERROR: invalid path]" + + if not os.path.exists(abs_path): + return "[ERROR: file not found]" + + # Clamp limit + limit = max(1, min(limit, 50)) + + # Build git log command + args = ["git", "log", "--oneline", f"-n{limit}", "--", abs_path] + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + cwd=self.root_path, + ) + except subprocess.TimeoutExpired: + return "[ERROR: log timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return "[ERROR: git log failed]" + + lines = result.stdout.splitlines() + if not lines: + return "[ERROR: no commit history found]" + + return "\n".join(lines) + + async def get_related_issues(self, query_text: str) -> str: + """Search for related issues in git log (not available in local mode for GitHub issues).""" + if not query_text or not query_text.strip(): + return "[ERROR: empty query]" + + query_text = query_text.strip() + + # Search git log messages for the query + args = ["git", "log", "--oneline", "--all", "--grep", query_text] + + try: + result = subprocess.run( + args, + capture_output=True, + text=True, + timeout=10, + cwd=self.root_path, + ) + except subprocess.TimeoutExpired: + return "[ERROR: search timeout]" + except Exception as e: + return f"[ERROR: {str(e)[:50]}]" + + if result.returncode != 0: + return f"[ERROR: no matching commits found for '{query_text}']" + + lines = result.stdout.splitlines()[:20] # Limit to 20 results + if not lines: + return f"[ERROR: no matching commits found for '{query_text}']" + + return f"Found {len(lines)} matching commits:\n" + "\n".join(lines) + async def close(self): """No-op for local tools (no HTTP client to close).""" pass - + def format_source(self, path: str, content: str | None = None, needle: str | None = None) -> str: """Format a source citation as local:path#Lx-Ly.""" line_range = "" if content: line_range = find_line_range(content, needle) return f"local:{path}{line_range}" - diff --git a/npx/python/cli/main.py b/npx/python/cli/main.py index 9f8207c..3d6e526 100644 --- a/npx/python/cli/main.py +++ b/npx/python/cli/main.py @@ -114,6 +114,8 @@ async def run_review( try: answer, sources, metadata = await 
runner.review(url, actual_question) except Exception as e: + import traceback + traceback.print_exc() print_error(f"Review failed: {e}") sys.exit(1) @@ -199,6 +201,8 @@ async def run_local_review( try: answer, sources, metadata = await runner.review_local(abs_path, actual_question) except Exception as e: + import traceback + traceback.print_exc() print_error(f"Review failed: {e}") sys.exit(1) diff --git a/npx/python/cli/repo_tools.py b/npx/python/cli/repo_tools.py index fcb31e1..9a04bd2 100644 --- a/npx/python/cli/repo_tools.py +++ b/npx/python/cli/repo_tools.py @@ -5,7 +5,9 @@ """ import asyncio +import logging import os +import re from typing import Any import httpx @@ -22,6 +24,7 @@ # --- State (per-run) --- _file_cache: dict[tuple[str, str], str] = {} # (ref, path) -> content _semaphore = asyncio.Semaphore(5) # Max 5 concurrent GitHub calls +logger = logging.getLogger(__name__) def _get_headers() -> dict[str, str]: @@ -136,28 +139,52 @@ def find_line_range(content: str, needle: str | None = None) -> str: class RepoTools: """Tools for exploring a GitHub repository beyond the PR diff.""" - def __init__(self, owner: str, repo: str, head_sha: str): + def __init__(self, owner: str, repo: str, head_sha: str, pr_number: int | None = None): """Initialize with repo context. - + Args: owner: Repository owner repo: Repository name head_sha: PR head commit SHA for consistent reads + pr_number: Optional PR number for PR-specific tools """ self.owner = owner self.repo = repo self.head_sha = head_sha + self.pr_number = pr_number self._client: httpx.AsyncClient | None = None - + self._client_loop: asyncio.AbstractEventLoop | None = None + async def _get_client(self) -> httpx.AsyncClient: - if self._client is None: + """Return an httpx AsyncClient, event-loop-aware. + + Tool calls are bridged via _sync_call → asyncio.run(), which creates a + new event loop each time. A client cached from a previous event loop + cannot be safely reused. This method detects loop changes and recreates + the client when needed, while caching within the same loop (useful when + a single tool call invokes _get_client multiple times, e.g. + get_call_graph → _get_outgoing_calls). + """ + loop = asyncio.get_running_loop() + if self._client is None or self._client_loop is not loop: + # Close stale client from a previous event loop + if self._client is not None: + try: + await self._client.aclose() + except Exception: + pass self._client = httpx.AsyncClient() + self._client_loop = loop return self._client - + async def close(self): if self._client: - await self._client.aclose() + try: + await self._client.aclose() + except Exception: + pass self._client = None + self._client_loop = None async def fetch_file(self, path: str) -> str: """Fetch any file from the repo at the PR's head commit. 
@@ -291,9 +318,7 @@ async def search_code(self, query: str) -> list[dict[str, Any]]: url = f"{GITHUB_API_BASE}/search/code" # Debug logging for bundled mode troubleshooting - print(f"[DEBUG-SEARCH] Query: '{search_query}'") - print(f"[DEBUG-SEARCH] URL: {url}") - print(f"[DEBUG-SEARCH] GITHUB_TOKEN present: {bool(GITHUB_TOKEN)}") + logger.debug("search_code query=%r url=%s token_present=%s", search_query, url, bool(GITHUB_TOKEN)) try: async with _semaphore: @@ -306,16 +331,16 @@ async def search_code(self, query: str) -> list[dict[str, Any]]: params={"q": search_query, "per_page": 10}, timeout=30.0, ) - print(f"[DEBUG-SEARCH] Response status: {resp.status_code}") + logger.debug("search_code response_status=%s", resp.status_code) except Exception as e: - print(f"[DEBUG-SEARCH] Exception: {e}") + logger.debug("search_code exception=%s", e) return [] # Soft fail if _is_rate_limited(resp): - print(f"[DEBUG-SEARCH] Rate limited!") + logger.debug("search_code rate_limited") return [] if resp.status_code != 200: - print(f"[DEBUG-SEARCH] Non-200 response: {resp.text[:500]}") + logger.debug("search_code non_200=%s body=%r", resp.status_code, resp.text[:500]) return [] # Soft fail data = resp.json() @@ -332,7 +357,750 @@ async def search_code(self, query: str) -> list[dict[str, Any]]: results.append(entry) return results - + + async def get_symbol_definition(self, symbol: str, context_file: str = "") -> str: + """Search for a symbol definition (function or class). + + Uses GitHub code search to find 'def {symbol}' or 'class {symbol}' patterns. + When context_file is provided, prioritizes results from that directory. + Returns file path + content snippet or error stub. + """ + if not symbol or not symbol.strip(): + return "[ERROR: empty symbol]" + + symbol = symbol.strip() + # Extract simple name from dotted paths (e.g., "dspy.adapters.DataFrame" -> "DataFrame") + if "." in symbol: + symbol = symbol.rsplit(".", 1)[-1] + client = await self._get_client() + + # Build search queries with optional path qualifier. + # Avoid regex-like OR expressions because GitHub code search expects + # query syntax, not raw regex. 
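+        # Two plain-text queries are tried ("def <symbol>" and "class <symbol>"),
+        # optionally narrowed with a path: qualifier; the first query that
+        # returns items wins.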
+ def_query = f"def {symbol} repo:{self.owner}/{self.repo}" + class_query = f"class {symbol} repo:{self.owner}/{self.repo}" + + # If context_file is provided, extract directory and try scoped search first + context_dir = "" + if context_file and context_file.strip(): + context_dir = os.path.dirname(context_file).strip() + + url = f"{GITHUB_API_BASE}/search/code" + queries: list[str] = [] + if context_dir: + queries.extend([f"{def_query} path:{context_dir}", f"{class_query} path:{context_dir}"]) + queries.extend([def_query, class_query]) + + items: list[dict[str, Any]] = [] + last_status: int | None = None + for search_query in queries: + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 5}, + timeout=30.0, + ) + except Exception: + return "[ERROR: search failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + last_status = resp.status_code + continue + + data = resp.json() + items = data.get("items", []) + if items: + break + + if not items: + if last_status is not None: + return f"[ERROR: {last_status}]" + return f"[ERROR: symbol '{symbol}' not found]" + + # Return first match with path and fragment + item = items[0] + path = item.get("path", "") + fragment = "" + text_matches = item.get("text_matches", []) + if text_matches: + fragment = text_matches[0].get("fragment", "")[:500] + + result = f"Found in: {path}\n" + if fragment: + result += f"Definition:\n{fragment}" + return result + + async def find_usages(self, symbol: str, scope_path: str = ".") -> str: + """Search for usages of a symbol in the repository. + + When scope_path is provided and not ".", narrows search to that path. + Returns list of files + fragments where symbol is referenced. + """ + if not symbol or not symbol.strip(): + return "[ERROR: empty symbol]" + + symbol = symbol.strip() + # Extract simple name from dotted paths (e.g., "dspy.predict.rlm.RLM" -> "RLM") + if "." in symbol: + symbol = symbol.rsplit(".", 1)[-1] + client = await self._get_client() + + # Build search query with optional path qualifier + search_query = f"{symbol} repo:{self.owner}/{self.repo}" + + # Add path qualifier if scope_path is provided and not "." + if scope_path and scope_path.strip() and scope_path.strip() != ".": + search_query = f"{search_query} path:{scope_path.strip()}" + + url = f"{GITHUB_API_BASE}/search/code" + + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 10}, + timeout=30.0, + ) + except Exception: + return "[ERROR: search failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + data = resp.json() + items = data.get("items", []) + + if not items: + return f"[ERROR: no usages of '{symbol}' found]" + + # Format results + result = f"Found {len(items)} usages of '{symbol}':\n" + for item in items[:10]: + path = item.get("path", "") + result += f" - {path}\n" + + return result + + async def get_type_hierarchy(self, class_name: str) -> str: + """Get the type hierarchy (parent classes) for a class. + + Searches for the class definition, parses parent classes using regex, + and recursively resolves parent classes. 
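+
+        Resolution stops after one level: each direct parent's own parents
+        are reported via _resolve_parent_class, but the chain is not walked
+        further.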
+ """ + if not class_name or not class_name.strip(): + return "[ERROR: empty class name]" + + class_name = class_name.strip() + # Extract simple name from dotted paths (e.g., "dspy.adapters.DataFrame" -> "DataFrame") + if "." in class_name: + class_name = class_name.rsplit(".", 1)[-1] + client = await self._get_client() + + # Search for class definition + search_query = f"class {class_name} repo:{self.owner}/{self.repo}" + url = f"{GITHUB_API_BASE}/search/code" + + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 5}, + timeout=30.0, + ) + except Exception: + return "[ERROR: search failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + data = resp.json() + items = data.get("items", []) + + if not items: + return f"[ERROR: class '{class_name}' not found]" + + # Fetch the file containing the class + path = items[0].get("path", "") + file_content = await self.fetch_file(path) + + if file_content.startswith("[ERROR:") or file_content.startswith("[SKIPPED:"): + return f"[ERROR: could not fetch {path}]" + + # Parse parent classes using regex: class ClassName(Parent1, Parent2): + # Use re.DOTALL to handle multi-line class definitions + pattern = rf"class\s+{re.escape(class_name)}\s*\(([^)]+)\)" + match = re.search(pattern, file_content, re.DOTALL) + + if not match: + return f"Class '{class_name}' has no parent classes (or is not found)" + + parents_str = match.group(1) + # Strip whitespace and newlines from each parent, filter out empty strings + parents = [p.strip() for p in parents_str.split(",") if p.strip()] + + result = f"Type hierarchy for '{class_name}':\n" + result += f" {class_name} extends: {', '.join(parents)}\n" + + # Resolve one level of parent classes + result += " Parent details:\n" + for parent in parents: + parent_hierarchy = await self._resolve_parent_class(parent) + if parent_hierarchy: + result += f" {parent_hierarchy}\n" + else: + result += f" {parent} (no parents found)\n" + + return result + + async def _resolve_parent_class(self, parent_name: str) -> str | None: + """Resolve one level of parent class hierarchy. + + Searches for the parent class definition and extracts its parents. + Returns a string like "ParentClass extends: GrandParent" or None if not found. 
+ """ + if not parent_name or not parent_name.strip(): + return None + + parent_name = parent_name.strip() + client = await self._get_client() + + # Search for parent class definition + search_query = f"class {parent_name} repo:{self.owner}/{self.repo}" + url = f"{GITHUB_API_BASE}/search/code" + + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 5}, + timeout=30.0, + ) + except Exception: + return None + + if _is_rate_limited(resp) or resp.status_code != 200: + return None + + data = resp.json() + items = data.get("items", []) + + if not items: + return None + + # Fetch the file containing the parent class + path = items[0].get("path", "") + file_content = await self.fetch_file(path) + + if file_content.startswith("[ERROR:") or file_content.startswith("[SKIPPED:"): + return None + + # Parse parent classes of the parent class + pattern = rf"class\s+{re.escape(parent_name)}\s*\(([^)]+)\)" + match = re.search(pattern, file_content, re.DOTALL) + + if not match: + return None + + parents_str = match.group(1) + grandparents = [p.strip() for p in parents_str.split(",") if p.strip()] + + return f"{parent_name} extends: {', '.join(grandparents)}" + + async def get_call_graph(self, func_name: str, depth: int = 1) -> str: + """Get the call graph for a function. + + Searches for calls to func_name, and if depth > 0, searches for what + func_name calls by fetching its definition. + """ + if not func_name or not func_name.strip(): + return "[ERROR: empty function name]" + + func_name = func_name.strip() + # Extract simple name from dotted paths (e.g., "module.sub.my_func" -> "my_func") + if "." in func_name: + func_name = func_name.rsplit(".", 1)[-1] + client = await self._get_client() + + # Search for calls to the function + search_query = f"\"{func_name}(\" repo:{self.owner}/{self.repo}" + url = f"{GITHUB_API_BASE}/search/code" + + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 10}, + timeout=30.0, + ) + except Exception: + return "[ERROR: search failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + data = resp.json() + items = data.get("items", []) + + result = f"Call graph for '{func_name}':\n" + + if not items: + result += f" No calls found\n" + else: + # List files that call this function + result += f" Called in {len(items)} locations:\n" + for item in items[:10]: + path = item.get("path", "") + result += f" - {path}\n" + + # If depth >= 1, find outgoing calls from this function + if depth >= 1: + outgoing = await self._get_outgoing_calls(func_name) + if outgoing: + result += f"\n Calls (outgoing):\n" + for call in outgoing: + result += f" - {call}\n" + else: + result += f"\n Calls (outgoing): none found\n" + + return result + + async def _get_outgoing_calls(self, func_name: str) -> list[str]: + """Extract outgoing calls from a function definition. + + Returns list of function names called by func_name. 
+ """ + # Python keywords to filter out + keywords = { + "if", "for", "while", "return", "print", "range", "len", "str", + "int", "list", "dict", "set", "tuple", "type", "isinstance", + "hasattr", "getattr", "setattr", "super", "enumerate", "zip", + "map", "filter", "sorted", "reversed", "any", "all", "min", + "max", "sum", "abs", "round", "open", "format", "repr", "hash", + "id", "input", "next", "iter" + } + + # Extract simple name from dotted paths + if "." in func_name: + func_name = func_name.rsplit(".", 1)[-1] + + # Search for function definition + search_query = f"def {func_name} repo:{self.owner}/{self.repo}" + url = f"{GITHUB_API_BASE}/search/code" + + client = await self._get_client() + + try: + async with _semaphore: + resp = await client.get( + url, + headers={ + **_get_headers(), + "Accept": "application/vnd.github.text-match+json", + }, + params={"q": search_query, "per_page": 5}, + timeout=30.0, + ) + except Exception: + return [] + + if _is_rate_limited(resp) or resp.status_code != 200: + return [] + + data = resp.json() + items = data.get("items", []) + + if not items: + return [] + + # Fetch the file containing the function + path = items[0].get("path", "") + file_content = await self.fetch_file(path) + + if file_content.startswith("[ERROR:") or file_content.startswith("[SKIPPED:"): + return [] + + # Extract function body + func_body = self._extract_function_body(file_content, func_name) + if not func_body: + return [] + + # Find all function calls using regex: word followed by ( + pattern = r"\b(\w+)\s*\(" + matches = re.findall(pattern, func_body) + + # Filter out keywords and duplicates + calls = [] + seen = set() + for match in matches: + if match not in keywords and match not in seen: + calls.append(match) + seen.add(match) + + return calls[:10] # Limit to 10 results + + def _extract_function_body(self, content: str, func_name: str) -> str: + """Extract the body of a function from file content. + + Returns the function body as a string, or empty string if not found. + """ + lines = content.splitlines() + func_start = None + + # Find the function definition line + for i, line in enumerate(lines): + if re.match(rf"def\s+{re.escape(func_name)}\s*\(", line): + func_start = i + break + + if func_start is None: + return "" + + # Get the indentation level of the function definition + def_line = lines[func_start] + def_indent = len(def_line) - len(def_line.lstrip()) + + # Extract lines until we hit a line with same or less indentation (next function/class) + body_lines = [def_line] + for i in range(func_start + 1, len(lines)): + line = lines[i] + # Skip empty lines + if not line.strip(): + body_lines.append(line) + continue + # Check indentation + line_indent = len(line) - len(line.lstrip()) + if line_indent <= def_indent and line.strip(): + # Hit next function/class at same level + break + body_lines.append(line) + + return "\n".join(body_lines) + + async def get_pr_comments(self, pr_number: int | None = None) -> str: + """Get all comments and reviews from a PR. + + Uses /repos/{owner}/{repo}/pulls/{pr_number}/reviews and /comments endpoints. 
+ """ + pr_num = pr_number or self.pr_number + if not pr_num: + return "[ERROR: no PR number provided]" + + client = await self._get_client() + + # Fetch reviews + reviews_url = f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/pulls/{pr_num}/reviews" + comments_url = f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/pulls/{pr_num}/comments" + + result = f"PR #{pr_num} Comments and Reviews:\n" + + try: + async with _semaphore: + reviews_resp = await client.get(reviews_url, headers=_get_headers(), timeout=30.0) + comments_resp = await client.get(comments_url, headers=_get_headers(), timeout=30.0) + except Exception: + return "[ERROR: failed to fetch PR comments]" + + if reviews_resp.status_code == 200: + reviews = reviews_resp.json() + result += f"\nReviews ({len(reviews)}):\n" + for review in reviews[:10]: + author = review.get("user", {}).get("login", "unknown") + state = review.get("state", "UNKNOWN") + body = review.get("body", "")[:200] + result += f" - {author} ({state}): {body}\n" + + if comments_resp.status_code == 200: + comments = comments_resp.json() + result += f"\nComments ({len(comments)}):\n" + for comment in comments[:10]: + author = comment.get("user", {}).get("login", "unknown") + body = comment.get("body", "")[:200] + result += f" - {author}: {body}\n" + + return result + + def _parse_line_range(self, line_range: str) -> tuple[int, int] | None: + """Parse line range from "10-20" or "10,20" format. + + Returns (start, end) tuple or None if parsing fails. + """ + if not line_range or not line_range.strip(): + return None + + line_range = line_range.strip() + + # Try "10-20" format + if "-" in line_range: + parts = line_range.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + if start > 0 and end > 0 and start <= end: + return (start, end) + except ValueError: + pass + + # Try "10,20" format + if "," in line_range: + parts = line_range.split(",") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + if start > 0 and end > 0 and start <= end: + return (start, end) + except ValueError: + pass + + return None + + def _patch_touches_lines(self, patch: str, start_line: int, end_line: int) -> bool: + """Check if a unified diff patch touches the given line range. + + Parses @@ -start,count +start,count @@ headers to determine + if the patch modifies lines in the requested range. + """ + if not patch: + return False + + # Find all hunk headers: @@ -start,count +start,count @@ + # The second number is the line range in the new file + hunk_pattern = r"@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@" + + for match in re.finditer(hunk_pattern, patch): + hunk_start = int(match.group(1)) + hunk_count = int(match.group(2)) if match.group(2) else 1 + hunk_end = hunk_start + hunk_count - 1 + + # Check if this hunk overlaps with the requested range + if hunk_start <= end_line and hunk_end >= start_line: + return True + + return False + + async def get_blame(self, path: str, line_range: str = "") -> str: + """Get blame information for a file or line range. + + Parses line_range like "10-20" or "10,20" and returns commit info + for commits that touched those lines. If line_range is empty, + returns recent commits for the entire file. 
+ """ + clean_path = sanitize_path(path) + if clean_path is None: + return "[ERROR: invalid path]" + + client = await self._get_client() + + # GitHub doesn't have a direct blame API, so we fetch commits for the file + url = f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/commits" + + try: + async with _semaphore: + resp = await client.get( + url, + headers=_get_headers(), + params={"path": clean_path, "per_page": 10}, + timeout=30.0, + ) + except Exception: + return "[ERROR: blame fetch failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + commits = resp.json() + + # If no line_range specified, return all recent commits + if not line_range or not line_range.strip(): + result = f"Blame for {clean_path} (recent commits):\n" + for commit in commits[:10]: + sha = commit.get("sha", "")[:7] + author = commit.get("commit", {}).get("author", {}).get("name", "unknown") + message = commit.get("commit", {}).get("message", "")[:100] + result += f" {sha} - {author}: {message}\n" + return result + + # Parse line_range + parsed_range = self._parse_line_range(line_range) + if parsed_range is None: + # Parsing failed - fall back to all commits with a note + result = f"Blame for {clean_path} (invalid line range '{line_range}', showing recent commits):\n" + for commit in commits[:10]: + sha = commit.get("sha", "")[:7] + author = commit.get("commit", {}).get("author", {}).get("name", "unknown") + message = commit.get("commit", {}).get("message", "")[:100] + result += f" {sha} - {author}: {message}\n" + return result + + start_line, end_line = parsed_range + + # Filter commits to those that touched the requested line range + matching_commits = [] + for commit in commits[:10]: + sha = commit.get("sha", "") + + # Fetch commit details to get the patch + commit_url = f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/commits/{sha}" + try: + async with _semaphore: + commit_resp = await client.get( + commit_url, + headers=_get_headers(), + timeout=30.0, + ) + except Exception: + # If we can't fetch commit details, skip it + continue + + if _is_rate_limited(commit_resp): + # Rate limited - return what we have so far + break + if commit_resp.status_code != 200: + # Skip this commit + continue + + commit_data = commit_resp.json() + files = commit_data.get("files", []) + + # Check if any file in this commit touches our line range + for file_info in files: + if file_info.get("filename") == clean_path: + patch = file_info.get("patch", "") + if self._patch_touches_lines(patch, start_line, end_line): + matching_commits.append(commit) + break + + # Format result + result = f"Blame for {clean_path} (lines {line_range}):\n" + + if not matching_commits: + result += f" No commits found that touched lines {line_range}\n" + else: + for commit in matching_commits: + sha = commit.get("sha", "")[:7] + author = commit.get("commit", {}).get("author", {}).get("name", "unknown") + message = commit.get("commit", {}).get("message", "")[:100] + result += f" {sha} - {author}: {message}\n" + + return result + + async def get_commit_history(self, path: str, limit: int = 5) -> str: + """Get commit history for a file. + + Uses /repos/{owner}/{repo}/commits?path={path}&per_page={limit} endpoint. 
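+
+        Example (illustrative; the path and output text are assumed):
+            text = await tools.get_commit_history("src/file.py", limit=2)
+            # -> "Commit history for src/file.py (last 2 commits):\n  abc1234 (2026-01-01) - Dev: Refactor loader\n"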
+ """ + clean_path = sanitize_path(path) + if clean_path is None: + return "[ERROR: invalid path]" + + client = await self._get_client() + url = f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/commits" + + try: + async with _semaphore: + resp = await client.get( + url, + headers=_get_headers(), + params={"path": clean_path, "per_page": min(limit, 100)}, + timeout=30.0, + ) + except Exception: + return "[ERROR: commit history fetch failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + commits = resp.json() + + result = f"Commit history for {clean_path} (last {len(commits)} commits):\n" + for commit in commits[:limit]: + sha = commit.get("sha", "")[:7] + author = commit.get("commit", {}).get("author", {}).get("name", "unknown") + date = commit.get("commit", {}).get("author", {}).get("date", "")[:10] + message = commit.get("commit", {}).get("message", "").split("\n")[0][:80] + result += f" {sha} ({date}) - {author}: {message}\n" + + return result + + async def get_related_issues(self, query_text: str) -> str: + """Search for related issues in the repository. + + Uses /search/issues?q={query_text}+repo:{owner}/{repo} endpoint. + """ + if not query_text or not query_text.strip(): + return "[ERROR: empty query]" + + query_text = query_text.strip() + client = await self._get_client() + + search_query = f"{query_text} repo:{self.owner}/{self.repo}" + url = f"{GITHUB_API_BASE}/search/issues" + + try: + async with _semaphore: + resp = await client.get( + url, + headers=_get_headers(), + params={"q": search_query, "per_page": 10}, + timeout=30.0, + ) + except Exception: + return "[ERROR: issue search failed]" + + if _is_rate_limited(resp): + return "[ERROR: rate limited]" + if resp.status_code != 200: + return f"[ERROR: {resp.status_code}]" + + data = resp.json() + items = data.get("items", []) + + if not items: + return f"[ERROR: no issues found for '{query_text}']" + + result = f"Related issues for '{query_text}' ({len(items)} found):\n" + for item in items[:10]: + number = item.get("number", "") + title = item.get("title", "")[:80] + state = item.get("state", "") + result += f" #{number} [{state}] {title}\n" + + return result + def format_source(self, path: str, content: str | None = None, needle: str | None = None) -> str: """Format a source citation as repo@sha:path#Lx-Ly.""" line_range = "" @@ -345,14 +1113,38 @@ def format_source(self, path: str, content: str | None = None, needle: str | Non TOOL_DESCRIPTIONS = """ AVAILABLE TOOLS (use via Python in REPL): + +BASIC TOOLS: - fetch_file(path: str) -> str: Fetch any file from the repo. Returns content or error stub. - list_directory(path: str = "") -> list[dict]: List {path, type, size} entries. - search_code(query: str) -> list[dict]: Search for patterns. Returns {path, fragment}. +CODE SEARCH & ANALYSIS: +- get_symbol_definition(symbol: str, context_file: str = "") -> str: Find function/class definition. Uses text search; context_file narrows to that directory. +- find_usages(symbol: str, scope_path: str = ".") -> str: Find all usages of a symbol. scope_path narrows search to a directory. +- get_type_hierarchy(class_name: str) -> str: Get parent classes for a class. Resolves one level of parents. +- get_call_graph(func_name: str, depth: int = 1) -> str: Get callers and outgoing calls for a function. depth controls outgoing edge resolution. + +GITHUB CONTEXT: +- get_pr_comments(pr_number: int | None = None) -> str: Get PR reviews and comments. 
+- get_blame(path: str, line_range: str = "") -> str: Get blame info for a file. line_range filters to commits touching those lines. +- get_commit_history(path: str, limit: int = 5) -> str: Get commit history for a file. +- get_related_issues(query_text: str) -> str: Search for related issues. + +NOTE: Code search tools use text-pattern matching, not AST analysis. Results may include matches in comments or strings. + TOOL USAGE RULES: 1. Fetch the minimum: prefer 1–3 files; don't traverse the repo. 2. If analysis depends on unchanged code, use fetch_file. 3. Use search_code to find paths; then fetch_file to read. 4. Files > 200KB return a stub—avoid large/generated files. 5. Use list_directory to understand structure first. +6. Use get_symbol_definition to find where a symbol is defined. +7. Use find_usages to understand impact of changes. +8. Use get_type_hierarchy to understand class relationships. +9. Use get_call_graph to trace function dependencies. +10. Use get_pr_comments to understand review feedback. +11. Use get_blame to find who changed what and when. +12. Use get_commit_history to understand evolution of a file. +13. Use get_related_issues to find context about bugs/features. """ diff --git a/npx/python/cli/virtual_runner.py b/npx/python/cli/virtual_runner.py index 89d51e4..6dee2d2 100644 --- a/npx/python/cli/virtual_runner.py +++ b/npx/python/cli/virtual_runner.py @@ -93,11 +93,12 @@ def _sync_call(self, coro): def _create_tool_functions(self): """Create sync tool wrapper functions for DSPy RLM. - Returns a dict of three sync tool functions as closures that capture + Returns a dict of {name: func} sync tool functions as closures that capture self by reference (so self._repo_tools can change between review calls). + DSPy RLM expects tools as dict[str, Callable[..., str]]. Returns: - Dict mapping tool name to function: {fetch_file, list_dir, search_code} + Dict mapping tool names to callable tool functions. """ runner = self @@ -179,7 +180,135 @@ def search_code(query: str) -> str: lines.append(f"{path}") return "\n".join(lines) - return {"fetch_file": fetch_file, "list_dir": list_dir, "search_code": search_code} + def get_symbol_definition(symbol: str, context_file: str = "") -> str: + """Find the definition of a symbol (function or class). + + Searches for 'def {symbol}' or 'class {symbol}' patterns in the repository. + Returns the file path and code snippet where the symbol is defined. + + Args: + symbol: Symbol name to search for (e.g., 'MyClass' or 'my_function') + context_file: Optional file path — narrows search to that file's directory + + Returns: + File path and code snippet, or error message + """ + return runner._sync_call(runner._repo_tools.get_symbol_definition(symbol, context_file)) + + def find_usages(symbol: str, scope_path: str = ".") -> str: + """Find all usages of a symbol in the repository. + + Searches for references to the given symbol across all files. + Returns a list of files where the symbol is used. + + Args: + symbol: Symbol name to search for + scope_path: Narrows search to files under this path + + Returns: + List of files containing usages, or error message + """ + return runner._sync_call(runner._repo_tools.find_usages(symbol, scope_path)) + + def get_type_hierarchy(class_name: str) -> str: + """Get the type hierarchy (parent classes) for a class. + + Searches for the class definition and extracts parent class information. + Returns the inheritance chain for the given class. 
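+
+            Example (illustrative; class names are assumed):
+                get_type_hierarchy("Child")  # -> e.g. "Child extends: Parent", plus one resolved parent level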
+ + Args: + class_name: Name of the class to analyze + + Returns: + Type hierarchy information, or error message + """ + return runner._sync_call(runner._repo_tools.get_type_hierarchy(class_name)) + + def get_call_graph(func_name: str, depth: int = 1) -> str: + """Get the call graph for a function. + + Searches for all locations where the function is called. + Returns a list of files and locations that call the function. + + Args: + func_name: Name of the function to analyze + depth: Controls outgoing edge resolution (what the function calls) + + Returns: + Call graph information, or error message + """ + return runner._sync_call(runner._repo_tools.get_call_graph(func_name, depth)) + + def get_pr_comments(pr_number: int = 0) -> str: + """Get all comments and reviews from a PR. + + Fetches reviews and comments from the PR associated with this review. + Returns formatted list of reviews and comments. + + Args: + pr_number: PR number (optional, uses PR from current review if not provided) + + Returns: + Formatted list of PR comments and reviews, or error message + """ + return runner._sync_call(runner._repo_tools.get_pr_comments(pr_number if pr_number else None)) + + def get_blame(path: str, line_range: str = "") -> str: + """Get blame information for a file or line range. + + Returns commit information for the specified file or line range, + showing who changed what and when. line_range filters commits to those touching the specified lines. + + Args: + path: File path to get blame for + line_range: Optional line range (e.g., '10-20') — filters to commits touching those lines + + Returns: + Blame information with commit details, or error message + """ + return runner._sync_call(runner._repo_tools.get_blame(path, line_range)) + + def get_commit_history(path: str, limit: int = 5) -> str: + """Get commit history for a file. + + Returns the recent commits that modified the specified file. + + Args: + path: File path to get history for + limit: Maximum number of commits to return (default: 5) + + Returns: + Commit history with dates and messages, or error message + """ + return runner._sync_call(runner._repo_tools.get_commit_history(path, limit)) + + def get_related_issues(query_text: str) -> str: + """Search for related issues in the repository. + + Searches for issues matching the given query text. + Returns a list of related issues. 
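+
+            Example (illustrative; the query and result text are assumed):
+                get_related_issues("loader bug")
+                # -> "Related issues for 'loader bug' (1 found):\n  #12 [open] Fix loader bug\n"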
+ + Args: + query_text: Search query for finding related issues + + Returns: + List of related issues, or error message + """ + return runner._sync_call(runner._repo_tools.get_related_issues(query_text)) + + return { + "fetch_file": fetch_file, + "list_dir": list_dir, + "search_code": search_code, + "get_symbol_definition": get_symbol_definition, + "find_usages": find_usages, + "get_type_hierarchy": get_type_hierarchy, + "get_call_graph": get_call_graph, + "get_pr_comments": get_pr_comments, + "get_blame": get_blame, + "get_commit_history": get_commit_history, + "get_related_issues": get_related_issues, + } def _ensure_configured(self): """Configure DSPy and RLM on first use.""" @@ -244,8 +373,9 @@ async def review(self, url: str, question: str) -> tuple[str, list[str], dict]: # Get head SHA for PR (for consistent file reads) head_sha = data.get("head_sha", "HEAD") - # Create repo tools for this review - self._repo_tools = RepoTools(owner, repo, head_sha) + # Create repo tools for this review, passing pr_number if this is a PR + pr_number = number if url_type == "pr" else None + self._repo_tools = RepoTools(owner, repo, head_sha, pr_number=pr_number) # Build context from PR data context = build_review_context(data) diff --git a/npx/python/tests/test_get_type_hierarchy.py b/npx/python/tests/test_get_type_hierarchy.py new file mode 100644 index 0000000..e36032c --- /dev/null +++ b/npx/python/tests/test_get_type_hierarchy.py @@ -0,0 +1,97 @@ +"""Tests for get_type_hierarchy function in repo_tools and local_repo_tools.""" + +import asyncio +import os +import tempfile +import pytest +from pathlib import Path + +from cli.repo_tools import RepoTools +from cli.local_repo_tools import LocalRepoTools + + +class TestLocalRepoToolsTypeHierarchy: + """Tests for LocalRepoTools.get_type_hierarchy with multi-line class definitions.""" + + @pytest.fixture + def temp_repo(self): + """Create a temporary repository with test classes.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a test Python file with multi-line class definition + test_file = Path(tmpdir) / "test_classes.py" + test_file.write_text(""" +class GrandParent: + pass + +class Parent( + GrandParent +): + pass + +class Child( + Parent, + object +): + pass + +class MultilineParent( + GrandParent, + object +): + pass +""") + yield tmpdir + + @pytest.mark.asyncio + async def test_single_line_class(self, temp_repo): + """Test parsing single-line class definition.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("GrandParent") + assert "GrandParent" in result + assert "no parents" in result.lower() + + @pytest.mark.asyncio + async def test_multiline_class_definition(self, temp_repo): + """Test parsing multi-line class definition.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("Parent") + assert "Parent" in result + assert "GrandParent" in result + assert "extends" in result + + @pytest.mark.asyncio + async def test_multiline_with_multiple_parents(self, temp_repo): + """Test parsing multi-line class with multiple parents.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("Child") + assert "Child" in result + assert "Parent" in result + assert "object" in result + + @pytest.mark.asyncio + async def test_parent_resolution(self, temp_repo): + """Test one level of parent resolution.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("Child") + # Should show Child's parents and their parents + assert "Parent details:" 
in result + assert "Parent extends:" in result or "Parent (no parents found)" in result + + @pytest.mark.asyncio + async def test_empty_class_name(self, temp_repo): + """Test error handling for empty class name.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("") + assert "ERROR" in result + + @pytest.mark.asyncio + async def test_nonexistent_class(self, temp_repo): + """Test error handling for non-existent class.""" + tools = LocalRepoTools(temp_repo) + result = await tools.get_type_hierarchy("NonExistentClass") + assert "ERROR" in result or "not found" in result.lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + diff --git a/npx/python/tests/test_repo_tooling_helpers.py b/npx/python/tests/test_repo_tooling_helpers.py new file mode 100644 index 0000000..eb3d3c6 --- /dev/null +++ b/npx/python/tests/test_repo_tooling_helpers.py @@ -0,0 +1,223 @@ +"""Unit tests for RepoTools and LocalRepoTools helper methods.""" + +import asyncio +import tempfile +from pathlib import Path + +from cli.local_repo_tools import LocalRepoTools +from cli.repo_tools import RepoTools + + +class FakeResponse: + def __init__(self, status_code=200, json_data=None, text="", headers=None): + self.status_code = status_code + self._json_data = json_data if json_data is not None else {} + self.text = text + self.headers = headers if headers is not None else {} + + def json(self): + return self._json_data + + +class FakeAsyncClient: + def __init__(self, responses): + self._responses = list(responses) + self.calls = [] + + async def get(self, url, headers=None, params=None, timeout=None): + self.calls.append({"url": url, "params": params or {}, "headers": headers or {}, "timeout": timeout}) + if not self._responses: + return FakeResponse(200, {"items": []}) + return self._responses.pop(0) + + +def test_repo_get_symbol_definition_uses_non_regex_queries(): + async def _run(): + tools = RepoTools("owner", "repo", "abcdef123456") + fake_client = FakeAsyncClient( + [ + FakeResponse(200, {"items": []}), + FakeResponse( + 200, + { + "items": [ + { + "path": "src/example.py", + "text_matches": [{"fragment": "class MySymbol(BaseClass):"}], + } + ] + }, + ), + ] + ) + + async def fake_get_client(): + return fake_client + + tools._get_client = fake_get_client # type: ignore[method-assign] + + result = await tools.get_symbol_definition("MySymbol", "src/current_file.py") + + assert "Found in: src/example.py" in result + queries = [call["params"].get("q", "") for call in fake_client.calls] + assert any("def MySymbol" in q for q in queries) + assert any("class MySymbol" in q for q in queries) + assert all("|" not in q for q in queries) + + asyncio.run(_run()) + + +def test_repo_find_usages_formats_results(): + async def _run(): + tools = RepoTools("owner", "repo", "abcdef123456") + fake_client = FakeAsyncClient( + [ + FakeResponse( + 200, + { + "items": [ + {"path": "src/a.py"}, + {"path": "src/b.py"}, + ] + }, + ) + ] + ) + + async def fake_get_client(): + return fake_client + + tools._get_client = fake_get_client # type: ignore[method-assign] + + result = await tools.find_usages("my_symbol", "src") + + assert "Found 2 usages of 'my_symbol'" in result + assert "src/a.py" in result + assert "src/b.py" in result + + asyncio.run(_run()) + + +def test_repo_get_call_graph_uses_literal_call_query(): + async def _run(): + tools = RepoTools("owner", "repo", "abcdef123456") + fake_client = FakeAsyncClient([FakeResponse(200, {"items": []})]) + + async def fake_get_client(): + return fake_client 
+ + async def fake_outgoing(_func_name): + return [] + + tools._get_client = fake_get_client # type: ignore[method-assign] + tools._get_outgoing_calls = fake_outgoing # type: ignore[method-assign] + + result = await tools.get_call_graph("my_func", depth=1) + + assert "Call graph for 'my_func'" in result + query = fake_client.calls[0]["params"].get("q", "") + assert "\"my_func(\"" in query + assert "repo:owner/repo" in query + + asyncio.run(_run()) + + +def test_repo_get_blame_filters_line_range(): + async def _run(): + tools = RepoTools("owner", "repo", "abcdef123456") + commits_resp = FakeResponse( + 200, + [ + {"sha": "abc1111", "commit": {"author": {"name": "A"}, "message": "touch lines"}}, + {"sha": "def2222", "commit": {"author": {"name": "B"}, "message": "other lines"}}, + ], + ) + commit_a_resp = FakeResponse( + 200, + {"files": [{"filename": "src/file.py", "patch": "@@ -1,2 +10,5 @@\n+line"}]}, + ) + commit_b_resp = FakeResponse( + 200, + {"files": [{"filename": "src/file.py", "patch": "@@ -1,2 +50,2 @@\n+line"}]}, + ) + fake_client = FakeAsyncClient([commits_resp, commit_a_resp, commit_b_resp]) + + async def fake_get_client(): + return fake_client + + tools._get_client = fake_get_client # type: ignore[method-assign] + + result = await tools.get_blame("src/file.py", "10-12") + + assert "Blame for src/file.py (lines 10-12)" in result + assert "abc1111"[:7] in result + assert "def2222"[:7] not in result + + asyncio.run(_run()) + + +def test_repo_get_commit_history_and_related_issues(): + async def _run(): + tools = RepoTools("owner", "repo", "abcdef123456") + fake_client = FakeAsyncClient( + [ + FakeResponse( + 200, + [ + { + "sha": "111aaaa", + "commit": { + "author": {"name": "Dev", "date": "2026-01-01T10:00:00Z"}, + "message": "Refactor loader\n\nextra", + }, + } + ], + ), + FakeResponse( + 200, + { + "items": [ + {"number": 12, "title": "Fix loader bug", "state": "open"}, + ] + }, + ), + ] + ) + + async def fake_get_client(): + return fake_client + + tools._get_client = fake_get_client # type: ignore[method-assign] + + history = await tools.get_commit_history("src/file.py", limit=1) + issues = await tools.get_related_issues("loader") + + assert "Commit history for src/file.py" in history + assert "111aaaa"[:7] in history + assert "Related issues for 'loader'" in issues + assert "#12 [open] Fix loader bug" in issues + + asyncio.run(_run()) + + +def test_local_get_symbol_definition_matches_class_and_def(): + async def _run(): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "pkg").mkdir() + (root / "pkg" / "sample.py").write_text( + "class MyClass:\n" + " pass\n\n" + "def my_function():\n" + " return 1\n", + encoding="utf-8", + ) + + tools = LocalRepoTools(str(root)) + class_result = await tools.get_symbol_definition("MyClass") + func_result = await tools.get_symbol_definition("my_function", context_file="pkg/other.py") + + assert "local:pkg/sample.py" in class_result + assert "local:pkg/sample.py" in func_result + + asyncio.run(_run()) diff --git a/pyproject.toml b/pyproject.toml index 2c9075f..3086405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,14 +31,13 @@ dev = [ [project.scripts] cr = "cr.cli:main" -asyncreview = "cli.main:main" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] -packages = ["cr", "cli"] +packages = ["cr"] [tool.ruff] line-length = 100
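
The line-range filter in get_blame rests entirely on the hunk-overlap test in
_patch_touches_lines. Below is a minimal standalone sketch of that logic, using
the same regex and interval check as the patch above; the helper name "touches"
is ours, and the sample patches mirror the fixtures in
test_repo_tooling_helpers.py:

    import re

    def touches(patch: str, start: int, end: int) -> bool:
        # The "+start,count" half of each @@ header is the new-file line range;
        # a missing count defaults to 1, per the unified diff format.
        for m in re.finditer(r"@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@", patch):
            hunk_start = int(m.group(1))
            hunk_end = hunk_start + (int(m.group(2)) if m.group(2) else 1) - 1
            if hunk_start <= end and hunk_end >= start:
                return True
        return False

    assert touches("@@ -1,2 +10,5 @@\n+line", 10, 12)      # lines 10-14 overlap 10-12
    assert not touches("@@ -1,2 +50,2 @@\n+line", 10, 12)  # lines 50-51 do not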