From 0f800a13848f2197d3d1792f8500bfc5932ed23b Mon Sep 17 00:00:00 2001 From: prosdev Date: Sat, 14 Mar 2026 14:55:17 -0700 Subject: [PATCH 1/5] docs: add Phase 5 plan for exporter, tools, and SSRF transport Covers 3 parts: SSRF-hardened httpx transport (PinnedDNSBackend), 5 remaining tools (web_search, wikipedia, file_read, file_write, weather), and Python code exporter. Plan reviewed through 4 passes of plan-reviewer. Also updates research-planner agent with ASCII diagram guidance and plan-review handoff loop, and marks Phase 4 as merged in READMEs. Co-Authored-By: Claude Opus 4.6 (1M context) --- .claude/agents/research-planner.md | 29 +- .claude/gw-plans/README.md | 2 +- .claude/gw-plans/execution/README.md | 4 +- .../execution/phase-5-exporter-tools-ssrf.md | 461 ++++++++++++++++++ 4 files changed, 490 insertions(+), 6 deletions(-) create mode 100644 .claude/gw-plans/execution/phase-5-exporter-tools-ssrf.md diff --git a/.claude/agents/research-planner.md b/.claude/agents/research-planner.md index ffd09cf..96903b3 100644 --- a/.claude/agents/research-planner.md +++ b/.claude/agents/research-planner.md @@ -39,6 +39,13 @@ You are a senior software engineer helping me plan features. - Plans go in `.claude/gw-plans/` following the existing structure. - Each plan should be self-contained: someone reading only the plan file should understand what to build and why. +- **Include ASCII diagrams** to visualize architecture, data flow, or component relationships. + Diagrams make plans faster to review and easier to reason about. Use them for: + - Module/file dependency graphs + - Request/response flows (e.g., how a tool call flows from state → input_map → tool → output_key) + - State transformations (what goes in, what comes out) + - Security boundaries (what's sandboxed, what talks to external systems) + - Anything where a picture is worth 50 lines of prose - Include a "Not in Scope" section to prevent scope creep. 
- Include a "Decisions & Risks" section documenting assumptions and their mitigations. - Include a "Commit Plan" section: ordered list of commits, each with a conventional commit @@ -70,10 +77,26 @@ You are a senior software engineer helping me plan features. - Use your judgement on the threshold. If a plan exceeds ~400 lines or has 3+ distinct commits touching different modules, split it. -## Revision Workflow (for the orchestrating agent) +## After Writing a Plan + +Once you have written or updated a plan file, **signal that it's ready for review**. +Include the plan file path in your response so the main conversation knows what to review. + +The agentic loop (driven by the main conversation): +1. **research-planner** writes/updates the plan → signals "ready for review" +2. Main conversation launches **plan-reviewer** on the plan file +3. If plan-reviewer returns REVISE → main conversation resumes **research-planner** to address findings +4. Repeat until plan-reviewer returns APPROVE + +You cannot spawn sub-agents yourself. End your response with a clear handoff: +``` +READY FOR REVIEW: .claude/gw-plans/path/to/plan.md +``` + +## Revision Workflow When plan-reviewer findings need to be applied to a large feature (overview + parts): 1. **Fix the overview first** (sequentially) — it sets the architecture decisions that parts reference. 2. **Fix the part files in parallel** — they are independent of each other and can reference the updated overview. This gives consistency and speed. -The research-planner cannot spawn sub-agents itself. The orchestrating agent (main conversation) -should launch parallel research-planner invocations for the part files after the overview is done. +The research-planner cannot spawn sub-agents itself. The main conversation should launch +parallel research-planner invocations for the part files after the overview is done. 
diff --git a/.claude/gw-plans/README.md b/.claude/gw-plans/README.md index f20f5c4..40b608c 100644 --- a/.claude/gw-plans/README.md +++ b/.claude/gw-plans/README.md @@ -9,7 +9,7 @@ Implementation deviations are logged at the bottom of each plan file. | Track | Description | Status | |-------|-------------|--------| -| [Execution](execution/) | FastAPI + LangGraph backend | Phases 1-3 merged, 4 in progress | +| [Execution](execution/) | FastAPI + LangGraph backend | Phases 1-4 merged, 5 in progress | | [Canvas](canvas/) | React + React Flow frontend | Not started | | Deployment | Cloud Run + Vercel + CI/CD | Not started (after both tracks) | diff --git a/.claude/gw-plans/execution/README.md b/.claude/gw-plans/execution/README.md index e48d44a..e5cca94 100644 --- a/.claude/gw-plans/execution/README.md +++ b/.claude/gw-plans/execution/README.md @@ -10,5 +10,5 @@ FastAPI + LangGraph backend phases. | 1.5 | [Scoped API Key Auth](phase-1.5-execution-auth.md) | Merged | [#2](https://github.com/prosdevlab/graphweave/pull/2) | | 2 | [GraphSchema -> LangGraph Builder](phase-2-graph-schema-langgraph-builder.md) | Merged | [#3](https://github.com/prosdevlab/graphweave/pull/3) | | 3 | [Executor + SSE Streaming](phase-3/overview.md) | Merged | [#6](https://github.com/prosdevlab/graphweave/pull/6) | -| 4 | [API Routes (validate, export, run history, cancel, delete)](phase-4-api-routes.md) | In progress | — | -| 5 | Exporter + remaining tools + SSRF transport | Not started | — | +| 4 | [API Routes (validate, export, run history, cancel, delete)](phase-4-api-routes.md) | Merged | [#7](https://github.com/prosdevlab/graphweave/pull/7), [#8](https://github.com/prosdevlab/graphweave/pull/8) | +| 5 | [Exporter + Remaining Tools + SSRF Transport](phase-5-exporter-tools-ssrf.md) | In progress | — | diff --git a/.claude/gw-plans/execution/phase-5-exporter-tools-ssrf.md b/.claude/gw-plans/execution/phase-5-exporter-tools-ssrf.md new file mode 100644 index 0000000..da6dbf8 --- /dev/null +++ 
b/.claude/gw-plans/execution/phase-5-exporter-tools-ssrf.md @@ -0,0 +1,461 @@ +# Phase 5: Exporter + Remaining Tools + SSRF Transport + +## Context + +Phases 1-4 built the full execution pipeline: DB, auth, builder, executor, SSE streaming, and API routes. The graph runs end-to-end. Phase 5 fills three remaining gaps: + +1. **Exporter** — the `GET /v1/graphs/{id}/export` route returns 501. Users need standalone Python code they can run outside GraphWeave. +2. **Tools** — 3 of 8 v1 tools are implemented (calculator, datetime, url_fetch). The remaining 5 round out the tool registry. +3. **SSRF hardening** — `url_fetch` has an IP-level SSRF guard but is vulnerable to DNS rebinding. A custom httpx transport pins resolved IPs. + +## What exists today + +| Component | Status | +|-----------|--------| +| `app/exporter.py` | Stub — returns `"# TODO: implement"` | +| `GET /{graph_id}/export` | Returns 501, ownership-checked | +| `ExportResponse` schema | Defined in `schemas/graphs.py` (code + requirements) | +| `app/tools/base.py` | `BaseTool` ABC with `run(inputs) -> dict` | +| `app/tools/registry.py` | 3 tools: calculator, datetime, url_fetch | +| `app/tools/url_fetch.py` | `validate_url()` blocks private IPs, `follow_redirects=False` | +| Tool response envelope | `{success, result/error, recoverable, source}` | + +--- + +## Part 5.1: SSRF-Hardened Transport + +### Problem + +`validate_url()` resolves the hostname to check for private IPs, then httpx resolves it again for the actual request. A DNS rebinding attack returns a public IP for validation, then a private IP (e.g., `169.254.169.254`) for the real request. + +### Solution + +Override DNS resolution at the httpcore level so the pinned IP is used for the actual connection while preserving the original hostname for TLS SNI and the `Host` header. This avoids the naive approach of replacing the URL host with an IP (which breaks HTTPS certificate validation). 
+ +**New file**: `app/tools/ssrf_transport.py` + +```python +import httpcore +from httpcore._backends.sync import SyncBackend + +class PinnedDNSBackend(httpcore.NetworkBackend): + """Network backend that substitutes hostname with a pinned IP in connect_tcp. + + Wraps the default SyncBackend. When connect_tcp is called, replaces the + `host` parameter with the pinned IP while leaving everything else + unchanged. This means: + - TCP connection goes to the pinned IP (SSRF-safe) + - TLS SNI uses the original hostname (httpcore passes it separately) + - Host header uses the original hostname (httpx sets it from the URL) + """ + + def __init__(self, pinned_ip: str): + self._pinned_ip = pinned_ip + self._backend = SyncBackend() + + def connect_tcp(self, host, port, timeout=None, local_address=None, socket_options=None): + # Substitute host with pinned IP — all other params (including SNI) unchanged + return self._backend.connect_tcp( + self._pinned_ip, port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) +``` + +```python +import httpx +import httpcore + +class SSRFSafeTransport(httpx.HTTPTransport): + """httpx transport that pins DNS to a pre-validated IP. + + Subclasses HTTPTransport to inherit its handle_request method, which + correctly converts between httpx.Request/Response and httpcore types. + We only replace the internal connection pool with one using our + PinnedDNSBackend. + """ + + def __init__(self, pinned_ip: str, **kwargs): + super().__init__(**kwargs) + # Preserve ssl_context from the pool created by super().__init__ + # (respects verify=, cert=, trust_env= kwargs). Then replace the + # pool with one using our pinned DNS backend. 
+ existing_pool = self._pool + self._pool = httpcore.ConnectionPool( + ssl_context=existing_pool._ssl_context, + network_backend=PinnedDNSBackend(pinned_ip), + ) +``` + +**Why subclass `HTTPTransport` instead of `BaseTransport`**: `HTTPTransport.handle_request` converts `httpx.Request` → `httpcore.Request`, calls `self._pool.handle_request()`, then converts `httpcore.Response` → `httpx.Response`. If we subclassed `BaseTransport`, we'd need to replicate this conversion logic (including private imports like `map_httpcore_exceptions` and `ResponseStream`). Subclassing `HTTPTransport` and replacing `self._pool` gives us the correct conversion for free. + +Flow: +1. `validate_url()` resolves hostname → returns `(error, resolved_ip)` tuple +2. `SSRFSafeTransport(pinned_ip)` creates a `ConnectionPool` with `PinnedDNSBackend` +3. `PinnedDNSBackend.connect_tcp()` substitutes the host with the pinned IP +4. TLS SNI and `Host` header use the original hostname (httpx/httpcore handle this from the URL, not from `connect_tcp`'s host param) +5. 
Works correctly for both HTTP and HTTPS + +Changes to `url_fetch.py`: +- `validate_url()` returns `tuple[str | None, str | None]` — `(error, resolved_ip)` +- `UrlFetchTool.run()` creates `httpx.Client(transport=SSRFSafeTransport(resolved_ip))` +- **Existing tests in `tests/unit/test_tools/test_url_fetch.py` must be updated** to handle the new tuple return type + +### Tests + +| Test | What it verifies | +|------|-----------------| +| `test_backend_connect_tcp_receives_pinned_ip` | `PinnedDNSBackend.connect_tcp` called with pinned IP, not original hostname | +| `test_transport_end_to_end_with_httpx_client` | Real `httpx.Client(transport=SSRFSafeTransport(...))` completes a request — catches type conversion bugs | +| `test_transport_sets_host_header` | Original hostname in Host header for virtual hosting | +| `test_transport_https_sni_correct` | TLS SNI uses original hostname, not pinned IP — cert validation passes | +| `test_validate_url_returns_tuple` | Updated return type `(error, resolved_ip)` | +| `test_existing_ssrf_guards_unchanged` | Private IP, loopback, link-local still blocked | +| `test_transport_preserves_ssl_context` | `SSRFSafeTransport(verify=False)` forwards ssl_context to replacement pool | +| `test_url_fetch_existing_tests_updated` | All existing url_fetch tests pass with tuple return | + +--- + +## Part 5.2: Remaining Tools (5 tools) + +All tools follow the existing pattern: extend `BaseTool`, implement `run(inputs) -> dict`, return the response envelope. 
+ +### 5.2.1 `web_search` + +**File**: `app/tools/web_search.py` + +| Input | Type | Description | +|-------|------|-------------| +| `query` | str | Search query | +| `max_results` | int | Max results (default 5, max 10) | + +**Behavior**: +- If `TAVILY_API_KEY` env var is set → use Tavily API (`tavily-python`) +- If not set → fall back to DuckDuckGo (`duckduckgo-search`, no API key needed) +- Returns list of `{title, url, snippet}` as the result string (formatted) +- `max_results` clamped to 10 regardless of input +- 10-second timeout on both providers +- `recoverable: True` on timeout/network errors +- `recoverable: False` on empty query + +**Dependencies**: `tavily-python>=0.5.0,<1.0`, `duckduckgo-search>=6.0.0,<8.0` + +**Note on duckduckgo-search**: The API changed between major versions. Implementation must verify the actual `DDGS` class interface against the installed version. Pin to `>=6.0.0,<8.0` and add a comment documenting the expected method signature (`DDGS().text(query, max_results=N)`). + +### 5.2.2 `wikipedia` + +**File**: `app/tools/wikipedia_tool.py` + +| Input | Type | Description | +|-------|------|-------------| +| `query` | str | Search term | +| `action` | str | `"search"` (list titles) or `"page"` (get content) | +| `title` | str | Page title (required for `action=page`) | + +**Behavior**: +- `action=search` → calls MediaWiki opensearch API directly via httpx: + `https://en.wikipedia.org/w/api.php?action=opensearch&search={query}&limit=5&format=json` + Returns list of matching titles (max 5). No extra library needed — plain httpx GET. +- `action=page` → uses `wikipediaapi` (PyPI: `Wikipedia-API`). + **Must initialize with `user_agent`** (mandatory since v0.6.0): + ```python + wiki = wikipediaapi.Wikipedia(language="en", user_agent="GraphWeave/1.0") + page = wiki.page(title) + ``` + Returns page summary + first 10K chars of content. 
+- `recoverable: True` on network errors +- `recoverable: False` on page not found or missing title for `action=page` + +**Note**: `wikipediaapi` only supports page retrieval, not search. The `action=search` path uses the MediaWiki opensearch API directly via httpx, avoiding a second Wikipedia library. + +**Dependencies**: `Wikipedia-API>=0.7.0` + +### 5.2.3 `file_read` + +**File**: `app/tools/file_read.py` + +| Input | Type | Description | +|-------|------|-------------| +| `path` | str | File path relative to `/workspace` | + +**Behavior**: +- Sandboxed to `/workspace` directory (configurable via `FILE_SANDBOX_ROOT` env var, default `/workspace`) +- Path traversal prevention: resolve path, verify it starts with sandbox root +- Opens files with `O_NOFOLLOW` flag to prevent symlink-based TOCTOU attacks (see Sandbox section) +- Reads with explicit `encoding="utf-8"` — binary files not supported (v2) +- Max file size: 1MB — returns `recoverable: False` if exceeded +- Returns file content as string (10K char truncation like url_fetch) +- `truncated: True` if content was truncated +- `recoverable: False` on permission error, file not found, path traversal, encoding error + +### 5.2.4 `file_write` + +**File**: `app/tools/file_write.py` + +| Input | Type | Description | +|-------|------|-------------| +| `path` | str | File path relative to `/workspace` | +| `content` | str | Content to write (text only, UTF-8) | +| `mode` | str | `"overwrite"` (default) or `"append"` | + +**Behavior**: +- Same sandbox as `file_read` (`FILE_SANDBOX_ROOT`) +- Path traversal prevention (same as file_read) +- Opens files with `O_NOFOLLOW` flag (see Sandbox section) +- Writes with explicit `encoding="utf-8"` — binary content not supported (v2) +- Creates parent directories if needed +- Max content size: 1MB — `recoverable: False` if exceeded +- Returns `{success: True, result: "Written N bytes to path"}` +- `recoverable: False` on permission error, path traversal, encoding error + +### 5.2.5 
`weather` + +**File**: `app/tools/weather.py` + +| Input | Type | Description | +|-------|------|-------------| +| `location` | str | City name or "lat,lon" | +| `action` | str | `"current"` or `"forecast"` (default: `"current"`) | + +**Behavior**: +- Uses Open-Meteo API (free, no key required) +- Step 1: If `location` matches `"lat,lon"` pattern → skip geocoding, use directly +- Step 2: Otherwise geocode location name via Open-Meteo geocoding API → lat/lon +- Step 3: Fetch weather data from Open-Meteo forecast API +- `action=current` → temperature, humidity, wind, conditions +- `action=forecast` → 7-day daily forecast (high/low, conditions) +- 10-second timeout +- Open-Meteo URLs are hardcoded public API endpoints (no SSRF risk — not user-influenced) +- `recoverable: True` on timeout/network errors +- `recoverable: False` on unknown location + +**Dependencies**: None extra — uses httpx (already a dependency) + +### Shared: File sandbox utility + +**File**: `app/tools/sandbox.py` + +Shared between `file_read` and `file_write`: + +```python +def resolve_sandboxed_path(path: str, sandbox_root: str) -> str | None: + """Resolve path within sandbox. Returns absolute path or None if escaped.""" +``` + +**TOCTOU mitigation**: The sandbox uses a two-layer defense: + +1. **`os.path.realpath()` pre-check** — resolves symlinks and verifies the resolved path starts with `sandbox_root`. Rejects obvious traversal attempts. +2. **`O_NOFOLLOW` on open** — file_read and file_write open files with `os.O_NOFOLLOW`, which refuses to follow symlinks at the final path component. This prevents a TOCTOU race where a symlink is swapped between validation and open. + +**Additional context**: `/workspace` is an ephemeral per-run Docker volume. Each graph execution gets its own isolated `/workspace`. There is no persistent attacker presence across runs, which significantly limits the TOCTOU window. The `O_NOFOLLOW` defense is belt-and-suspenders. 
+ +```python +def open_sandboxed( + path: str, sandbox_root: str, flags: int, mode: int = 0o644 +) -> int: + """Open a file within the sandbox with O_NOFOLLOW. Returns fd or raises.""" +``` + +**Note**: `O_NOFOLLOW` only applies to the final path component (the leaf). Parent directory symlink traversal is prevented by the `os.path.realpath()` pre-check, which must always run first. + +**I/O pattern for callers**: +- `file_write`: Call `os.makedirs(parent, exist_ok=True)` before `open_sandboxed` to create parent directories. Wrap the raw fd with `os.fdopen(fd, "w", encoding="utf-8")` to get a proper file object for string writes. +- `file_read`: Wrap the raw fd with `os.fdopen(fd, "r", encoding="utf-8")`. Non-UTF-8 files will raise `UnicodeDecodeError` → return `recoverable: False`. + +### Registry update + +Add all 5 to `REGISTRY` in `registry.py`. + +### Tests per tool + +Each tool gets its own test file in `tests/unit/`: + +| Tool | Tests | Key scenarios | +|------|-------|---------------| +| `web_search` | 6 | Tavily path, DDG fallback, empty query, timeout, max_results clamped to 10, max_results within range | +| `wikipedia` | 6 | Search results, page content, disambiguation error, not found, truncation, user_agent set | +| `file_read` | 7 | Read file, path traversal blocked, file not found, too large, truncation, symlink escape (O_NOFOLLOW), empty file | +| `file_write` | 8 | Write file, append mode, path traversal blocked, creates dirs, too large, symlink escape (O_NOFOLLOW), encoding error on non-UTF-8, symlink in parent directory | +| `weather` | 5 | Current weather, forecast, unknown location, timeout, lat/lon input format (skip geocoding) | +| `registry` | 1 | `test_registry_has_all_eight_tools` — assert `len(REGISTRY) == 8` and each name exists | +| `file roundtrip` | 1 | `test_file_roundtrip_write_then_read` — write via file_write, read back via file_read, assert match | + +`web_search` tests mock the Tavily/DDG clients. 
`weather` tests mock httpx responses. File tools use `tmp_path` fixture with `FILE_SANDBOX_ROOT` monkeypatched. + +--- + +## Part 5.3: Exporter + +### What it generates + +The exporter takes a `GraphSchema` dict and produces standalone Python code that recreates the graph using LangGraph directly — no GraphWeave dependency needed. + +**Output structure** (uses `TypedDict` — standard LangGraph pattern): + +```python +"""Generated by GraphWeave — standalone LangGraph graph.""" + +import operator +from typing import Annotated, TypedDict +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +# ... other imports based on graph content + + +class GraphState(TypedDict): + messages: Annotated[list, ...] + result: str + + +# Node functions +async def llm_node_1(state: GraphState) -> dict: + ... + +def tool_node_1(state: GraphState) -> dict: + ... + + +# Build graph +graph = StateGraph(GraphState) +graph.add_node("llm_1", llm_node_1) +... +graph.add_edge(START, "llm_1") +... +compiled = graph.compile() + + +# Run +if __name__ == "__main__": + import asyncio + result = asyncio.run(compiled.ainvoke({...defaults...})) + print(result) +``` + +**Key difference from builder**: The builder uses `type()` with `__annotations__` (dynamic, works at runtime). The exporter generates `TypedDict` subclass source code (static, matches LangGraph docs and user expectations). + +### Implementation: `app/exporter.py` + +```python +def export_graph(schema: dict) -> dict: + """Generate standalone Python code from a GraphSchema.""" + # Returns {"code": str, "requirements": str} +``` + +Code generation sections (in order): +1. **Imports** — scan nodes to determine which imports are needed (langchain providers, tools, etc.) +2. **State class** — generate `class GraphState(TypedDict)` with annotations matching the schema +3. **Node functions** — generate function bodies for each non-start/non-end node +4. 
**Graph construction** — `StateGraph(GraphState)`, `add_node`, `add_edge`, `add_conditional_edges` +5. **Compilation** — `graph.compile()`, with checkpointer comment if human_input nodes present +6. **Main block** — `if __name__ == "__main__"` with defaults-merged invocation + +**Requirements generation**: Scan the schema for: +- Base: `langgraph`, `langchain-core` +- LLM providers used → `langchain-openai`, `langchain-anthropic`, `langchain-google-genai` +- Tools used → `simpleeval` (calculator), `httpx` + `trafilatura` (url_fetch), etc. +- No LLM nodes → no provider deps in requirements + +### Route update + +Change `GET /{graph_id}/export` from 501 stub to return `ExportResponse`: + +```python +@router.get("/{graph_id}/export", response_model=ExportResponse) +async def export_graph_route(...): + result = export_graph(graph.schema_json) + return ExportResponse(code=result["code"], requirements=result["requirements"]) +``` + +### Tests + +| Test | What it verifies | +|------|-----------------| +| `test_export_linear_graph` | start → llm → end produces valid Python | +| `test_export_with_tool_node` | Tool node generates correct function body | +| `test_export_with_condition` | Condition generates routing function + `add_conditional_edges` | +| `test_export_with_human_input` | Includes interrupt import + checkpointer comment | +| `test_export_requirements_openai` | Requirements include `langchain-openai` | +| `test_export_requirements_multi_provider` | Multiple providers listed | +| `test_export_requirements_no_llm` | Tool-only graph has no provider deps | +| `test_export_state_typeddict` | State class is `TypedDict` subclass with correct annotations | +| `test_export_code_compiles` | `compile(exported_code)` doesn't raise SyntaxError | +| `test_export_code_ast_structure` | `ast.parse()` + `ast.walk()` verifies expected function defs, class defs, imports | +| `test_export_complex_graph` | Graph with LLM + tool + condition + human_input — all node types combined | +| 
`test_export_route_returns_200` | Route returns ExportResponse, not 501 | +| `test_export_route_not_found` | 404 for nonexistent graph | + +--- + +## Commit Checkpoints + +| Checkpoint | What's in it | +|------------|-------------| +| 1 | SSRF transport + updated url_fetch + updated existing url_fetch tests | +| 2 | File sandbox utility + file_read + file_write + tests (including roundtrip) | +| 3 | web_search + wikipedia + tests | +| 4 | weather tool + registry count test + tests | +| 5 | Exporter implementation + tests | +| 6 | Export route update (501 → 200) + integration test | + +--- + +## Dependencies to add + +```toml +# pyproject.toml +"tavily-python>=0.5.0,<1.0", +"duckduckgo-search>=6.0.0,<8.0", +"Wikipedia-API>=0.7.0", +``` + +--- + +## Files Summary + +| Action | File | +|--------|------| +| CREATE | `app/tools/ssrf_transport.py` | +| CREATE | `app/tools/web_search.py` | +| CREATE | `app/tools/wikipedia_tool.py` | +| CREATE | `app/tools/file_read.py` | +| CREATE | `app/tools/file_write.py` | +| CREATE | `app/tools/weather.py` | +| CREATE | `app/tools/sandbox.py` | +| CREATE | `tests/unit/test_ssrf_transport.py` | +| CREATE | `tests/unit/test_web_search.py` | +| CREATE | `tests/unit/test_wikipedia.py` | +| CREATE | `tests/unit/test_file_read.py` | +| CREATE | `tests/unit/test_file_write.py` | +| CREATE | `tests/unit/test_weather.py` | +| CREATE | `tests/unit/test_exporter.py` | +| MODIFY | `app/tools/url_fetch.py` — use SSRFSafeTransport, update validate_url return type | +| MODIFY | `app/tools/registry.py` — register 5 new tools | +| MODIFY | `app/exporter.py` — full implementation | +| MODIFY | `app/routes/graphs.py` — export route 501 → 200 | +| MODIFY | `tests/unit/test_tools/test_url_fetch.py` — update for tuple return type | +| MODIFY | `pyproject.toml` — add 3 dependencies | +| REGEN | `uv.lock` | + +--- + +## Verification + +```bash +cd packages/execution +uv sync +uv run ruff check app/ tests/ +uv run ruff format --check app/ tests/ +uv run 
pytest tests/unit/ -v +``` + +--- + +## Not in scope + +- Tool parameter configuration in schema (v2 — per-tool config UI in canvas) +- Image/media tools (v2) +- Database tools (v2 — multi-tenant safety unclear) +- Binary file I/O (v2 — file_read/file_write are UTF-8 text only) +- Persistent file storage across runs (files live in `/workspace` per-run, ephemeral in Docker) +- Export to other formats (Jupyter notebook, etc.) From 830b1156822d09a2086a68509792b4babf4589d5 Mon Sep 17 00:00:00 2001 From: prosdev Date: Sat, 14 Mar 2026 15:18:31 -0700 Subject: [PATCH 2/5] feat: add exporter, 5 tools, and SSRF-hardened transport Phase 5 of the execution layer: - SSRF transport: PinnedDNSBackend pins resolved IPs at httpcore level, preventing DNS rebinding while preserving TLS SNI. validate_url() now returns (error, resolved_ip) tuple. - 5 new tools: web_search (Tavily + DDG fallback), wikipedia (MediaWiki opensearch + wikipediaapi), file_read, file_write (sandboxed with O_NOFOLLOW + realpath), weather (Open-Meteo). - Exporter: generates standalone Python from GraphSchema with TypedDict state, node functions, routing, and requirements. - Export route upgraded from 501 stub to 200 with ExportResponse. - Registry expanded from 3 to 8 tools. 302 unit tests, 8 manual tests (30-37) all passing. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/execution/app/exporter.py | 514 +++++++++++++++++- packages/execution/app/routes/graphs.py | 21 +- packages/execution/app/tools/file_read.py | 77 +++ packages/execution/app/tools/file_write.py | 89 +++ packages/execution/app/tools/registry.py | 10 + packages/execution/app/tools/sandbox.py | 45 ++ .../execution/app/tools/ssrf_transport.py | 61 +++ packages/execution/app/tools/url_fetch.py | 38 +- packages/execution/app/tools/weather.py | 178 ++++++ packages/execution/app/tools/web_search.py | 99 ++++ .../execution/app/tools/wikipedia_tool.py | 115 ++++ packages/execution/pyproject.toml | 3 + .../tests/manual/test_30_url_fetch_real.py | 48 ++ .../tests/manual/test_31_web_search_ddg.py | 47 ++ .../tests/manual/test_32_web_search_tavily.py | 33 ++ .../tests/manual/test_33_wikipedia_real.py | 43 ++ .../tests/manual/test_34_weather_real.py | 41 ++ .../tests/manual/test_35_file_sandbox.py | 89 +++ .../tests/manual/test_36_exporter_exec.py | 204 +++++++ .../tests/manual/test_37_export_route.py | 115 ++++ .../execution/tests/unit/test_exporter.py | 373 +++++++++++++ packages/execution/tests/unit/test_routes.py | 8 +- .../tests/unit/test_routes_export.py | 117 ++++ .../tests/unit/test_tools/test_file_read.py | 89 +++ .../tests/unit/test_tools/test_file_write.py | 119 ++++ .../tests/unit/test_tools/test_registry.py | 18 +- .../unit/test_tools/test_ssrf_transport.py | 88 +++ .../tests/unit/test_tools/test_url_fetch.py | 25 +- .../tests/unit/test_tools/test_weather.py | 109 ++++ .../tests/unit/test_tools/test_web_search.py | 101 ++++ .../tests/unit/test_tools/test_wikipedia.py | 121 +++++ packages/execution/uv.lock | 85 +++ 32 files changed, 3080 insertions(+), 43 deletions(-) create mode 100644 packages/execution/app/tools/file_read.py create mode 100644 packages/execution/app/tools/file_write.py create mode 100644 packages/execution/app/tools/sandbox.py create mode 100644 packages/execution/app/tools/ssrf_transport.py 
create mode 100644 packages/execution/app/tools/weather.py create mode 100644 packages/execution/app/tools/web_search.py create mode 100644 packages/execution/app/tools/wikipedia_tool.py create mode 100644 packages/execution/tests/manual/test_30_url_fetch_real.py create mode 100644 packages/execution/tests/manual/test_31_web_search_ddg.py create mode 100644 packages/execution/tests/manual/test_32_web_search_tavily.py create mode 100644 packages/execution/tests/manual/test_33_wikipedia_real.py create mode 100644 packages/execution/tests/manual/test_34_weather_real.py create mode 100644 packages/execution/tests/manual/test_35_file_sandbox.py create mode 100644 packages/execution/tests/manual/test_36_exporter_exec.py create mode 100644 packages/execution/tests/manual/test_37_export_route.py create mode 100644 packages/execution/tests/unit/test_exporter.py create mode 100644 packages/execution/tests/unit/test_routes_export.py create mode 100644 packages/execution/tests/unit/test_tools/test_file_read.py create mode 100644 packages/execution/tests/unit/test_tools/test_file_write.py create mode 100644 packages/execution/tests/unit/test_tools/test_ssrf_transport.py create mode 100644 packages/execution/tests/unit/test_tools/test_weather.py create mode 100644 packages/execution/tests/unit/test_tools/test_web_search.py create mode 100644 packages/execution/tests/unit/test_tools/test_wikipedia.py diff --git a/packages/execution/app/exporter.py b/packages/execution/app/exporter.py index a4df6c6..24a3f5f 100644 --- a/packages/execution/app/exporter.py +++ b/packages/execution/app/exporter.py @@ -1,17 +1,513 @@ """Python code generation from GraphSchema.""" +from __future__ import annotations + +# Type → Python type name mapping +_TYPE_NAMES = { + "string": "str", + "number": "float", + "boolean": "bool", + "list": "list", + "object": "dict", +} + +# Type → default value repr +_DEFAULT_REPRS = { + "string": '""', + "number": "0.0", + "boolean": "False", + "list": "[]", + 
"object": "{}", +} + +# LLM provider → import line and class name +_PROVIDER_IMPORTS = { + "openai": ("from langchain_openai import ChatOpenAI", "ChatOpenAI"), + "anthropic": ("from langchain_anthropic import ChatAnthropic", "ChatAnthropic"), + "gemini": ( + "from langchain_google_genai import ChatGoogleGenerativeAI", + "ChatGoogleGenerativeAI", + ), +} + +# LLM provider → pip requirement +_PROVIDER_REQS = { + "openai": "langchain-openai", + "anthropic": "langchain-anthropic", + "gemini": "langchain-google-genai", +} + +# Tool name → pip requirements +_TOOL_REQS = { + "calculator": ["simpleeval"], + "url_fetch": ["httpx", "trafilatura"], + "web_search": ["tavily-python", "duckduckgo-search"], + "wikipedia": ["Wikipedia-API"], + "weather": ["httpx"], +} + def export_graph(schema: dict) -> dict: """Generate standalone Python code from a GraphSchema. - Args: - schema: A GraphSchema dictionary. - Returns: - Dict with 'code' (Python source) and 'requirements' (requirements.txt content). + Dict with 'code' (Python source) and 'requirements' (pip requirements). 
""" - # TODO: Implement code generation - return { - "code": "# Generated by GraphWeave\n# TODO: implement\n", - "requirements": "langgraph\nlangchain-openai\n", - } + nodes = schema.get("nodes", []) + edges = schema.get("edges", []) + state_fields = schema.get("state", []) + + nodes_by_id = {n["id"]: n for n in nodes} + start_id = next((n["id"] for n in nodes if n["type"] == "start"), None) + end_ids = {n["id"] for n in nodes if n["type"] == "end"} + + # Collect what we need + has_human_input = any(n["type"] == "human_input" for n in nodes) + has_conditions = any(n["type"] == "condition" for n in nodes) + has_merge_reducer = any(f.get("reducer") == "merge" for f in state_fields) + has_append_messages = any( + f.get("reducer") == "append" and f["key"] == "messages" for f in state_fields + ) + has_append_other = any( + f.get("reducer") == "append" and f["key"] != "messages" for f in state_fields + ) + + # Collect providers and tools + providers = set() + tool_names = set() + for n in nodes: + if n["type"] == "llm": + providers.add(n["config"].get("provider", "openai")) + if n["type"] == "tool": + tool_names.add(n["config"].get("tool_name", "")) + if n["type"] == "condition": + cond = n["config"].get("condition", {}) + if cond.get("type") == "llm_router": + providers.add("openai") # default routing provider + + # Build condition edges + condition_ids = {n["id"] for n in nodes if n["type"] == "condition"} + cond_edges: dict[str, dict[str, str]] = {} + normal_edges: list[dict] = [] + for edge in edges: + if edge["source"] in condition_ids: + branch = edge.get("condition_branch") or edge.get("label", "default") + cond_edges.setdefault(edge["source"], {})[branch] = edge["target"] + else: + normal_edges.append(edge) + + # ── Generate code sections ────────────────────────────────────────── + + sections = [] + + # 1. Docstring + sections.append('"""Generated by GraphWeave — standalone LangGraph graph."""\n') + + # 2. 
Imports + imports = _build_imports( + providers, + has_human_input, + has_conditions, + has_merge_reducer, + has_append_messages, + has_append_other, + ) + sections.append(imports) + + # 3. Provider imports + for provider in sorted(providers): + if provider in _PROVIDER_IMPORTS: + sections.append(_PROVIDER_IMPORTS[provider][0]) + sections.append("") + + # 4. Merge reducer (if needed) + if has_merge_reducer: + sections.append(_merge_reducer_code()) + + # 5. State class + sections.append(_build_state_class(state_fields)) + + # 6. Node functions + for n in nodes: + if n["type"] in ("start", "end"): + continue + fn = _build_node_function(n, nodes_by_id) + if fn: + sections.append(fn) + + # 7. Condition routers + for cond_id in sorted(cond_edges.keys()): + cond_node = nodes_by_id[cond_id] + router_fn = _build_router_function(cond_node, nodes_by_id, edges) + if router_fn: + sections.append(router_fn) + + # 8. Graph construction + sections.append( + _build_graph_construction( + nodes, + normal_edges, + cond_edges, + nodes_by_id, + start_id, + end_ids, + has_human_input, + ) + ) + + # 9. 
Main block + sections.append(_build_main_block(state_fields)) + + code = "\n\n".join(s for s in sections if s) + + # ── Requirements ──────────────────────────────────────────────────── + reqs = _build_requirements(providers, tool_names) + + return {"code": code, "requirements": reqs} + + +def _build_imports( + providers: set, + has_human_input: bool, + has_conditions: bool, + has_merge_reducer: bool, + has_append_messages: bool, + has_append_other: bool, +) -> str: + lines = [] + + needs_operator = has_append_other + needs_annotated = has_append_messages or has_append_other or has_merge_reducer + + if needs_operator: + lines.append("import operator") + + typing_imports = ["TypedDict"] + if needs_annotated: + typing_imports.insert(0, "Annotated") + lines.append(f"from typing import {', '.join(typing_imports)}") + + lg_imports = ["END", "START", "StateGraph"] + lines.append(f"from langgraph.graph import {', '.join(lg_imports)}") + + if has_human_input: + lines.append("from langgraph.types import interrupt") + lines.append("from langgraph.checkpoint.memory import InMemorySaver") + + if has_append_messages: + lines.append("from langgraph.graph.message import add_messages") + + if has_conditions: + pass # No extra imports needed for conditions + + return "\n".join(lines) + + +def _merge_reducer_code() -> str: + return '''def _merge_reducer(left: dict, right: dict) -> dict: + """Deep-merge two dicts.""" + result = {**left} + for key, value in right.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _merge_reducer(result[key], value) + else: + result[key] = value + return result''' + + +def _build_state_class(state_fields: list[dict]) -> str: + lines = ["class GraphState(TypedDict):"] + if not state_fields: + lines.append(" pass") + return "\n".join(lines) + + for f in state_fields: + annotation = _field_annotation(f) + lines.append(f" {f['key']}: {annotation}") + + return "\n".join(lines) + + +def 
_field_annotation(field: dict) -> str: + base = _TYPE_NAMES.get(field["type"], "str") + reducer = field.get("reducer", "replace") + + if reducer == "replace": + return base + if reducer == "append": + if field["key"] == "messages": + return "Annotated[list, add_messages]" + return "Annotated[list, operator.add]" + if reducer == "merge": + return "Annotated[dict, _merge_reducer]" + return base + + +def _build_node_function(node: dict, nodes_by_id: dict) -> str: + ntype = node["type"] + nid = node["id"] + config = node.get("config", {}) + + if ntype == "llm": + return _build_llm_function(nid, config) + if ntype == "tool": + return _build_tool_function(nid, config) + if ntype == "condition": + return _build_condition_passthrough(nid) + if ntype == "human_input": + return _build_human_function(nid, config) + return "" + + +def _build_llm_function(node_id: str, config: dict) -> str: + provider = config.get("provider", "openai") + model = config.get("model", "gpt-4o") + _, cls_name = _PROVIDER_IMPORTS.get(provider, _PROVIDER_IMPORTS["openai"]) + temp = config.get("temperature", 0.7) + max_tokens = config.get("max_tokens", 1024) + output_key = config.get("output_key", "result") + system_prompt = config.get("system_prompt", "") + input_map = config.get("input_map", {}) + + lines = [f"async def {node_id}(state: GraphState) -> dict:"] + + # Build messages + lines.append(" from langchain_core.messages import HumanMessage, SystemMessage") + lines.append( + f' llm = {cls_name}(model="{model}", ' + f"temperature={temp}, max_tokens={max_tokens})" + ) + lines.append(" messages = []") + + if system_prompt: + lines.append( + f' messages.append(SystemMessage(content="{_escape(system_prompt)}"))' + ) + + # Input map + if input_map: + parts = [] + for key, expr in input_map.items(): + parts.append(f'"{key}: " + str(state.get("{expr}", ""))') + user_content = ' + "\\n" + '.join(parts) + lines.append(f" messages.append(HumanMessage(content={user_content}))") + else: + lines.append(" 
messages.append(HumanMessage(content=str(state)))") + + lines.append(" response = await llm.ainvoke(messages)") + lines.append(f' return {{"{output_key}": response.content}}') + + return "\n".join(lines) + + +def _build_tool_function(node_id: str, config: dict) -> str: + tool_name = config.get("tool_name", "unknown") + output_key = config.get("output_key", "result") + input_map = config.get("input_map", {}) + + lines = [f"def {node_id}(state: GraphState) -> dict:"] + lines.append(f' """Run the {tool_name} tool."""') + + # Resolve inputs + if input_map: + lines.append(" inputs = {") + for key, expr in input_map.items(): + lines.append(f' "{key}": state.get("{expr}", ""),') + lines.append(" }") + else: + lines.append(" inputs = {}") + + lines.append(f" # TODO: Replace with actual {tool_name} tool call") + lines.append(' result = {"success": True, "result": str(inputs)}') + lines.append(f' return {{"{output_key}": result}}') + + return "\n".join(lines) + + +def _build_condition_passthrough(node_id: str) -> str: + return f"def {node_id}(state: GraphState) -> dict:\n return {{}}" + + +def _build_human_function(node_id: str, config: dict) -> str: + prompt = _escape(config.get("prompt", "")) + input_key = config.get("input_key", "input") + + lines = [f"def {node_id}(state: GraphState) -> dict:"] + lines.append( + f' value = interrupt({{"prompt": "{prompt}", ' + f'"input_key": "{input_key}", "node_id": "{node_id}"}})' + ) + lines.append(f' return {{"{input_key}": value}}') + + return "\n".join(lines) + + +def _build_router_function( + cond_node: dict, + nodes_by_id: dict, + edges: list[dict], +) -> str: + node_id = cond_node["id"] + config = cond_node.get("config", {}) + condition = config.get("condition", {}) + ctype = condition.get("type", "field_equals") + default = config.get("default_branch", "") + + lines = [f"def route_{node_id}(state: GraphState) -> str:"] + + if ctype == "field_equals": + field = condition.get("field", "") + value = condition.get("value", "") + 
branch = condition.get("branch", "") + lines.append(f' if state.get("{field}") == "{value}":') + lines.append(f' return "{branch}"') + lines.append(f' return "{default}"') + + elif ctype == "field_contains": + field = condition.get("field", "") + value = condition.get("value", "") + branch = condition.get("branch", "") + lines.append(f' if "{value}" in str(state.get("{field}", "")):') + lines.append(f' return "{branch}"') + lines.append(f' return "{default}"') + + elif ctype == "field_exists": + field = condition.get("field", "") + branch = condition.get("branch", "") + lines.append(f' if "{field}" in state and state["{field}"] is not None:') + lines.append(f' return "{branch}"') + lines.append(f' return "{default}"') + + elif ctype == "tool_error": + # Find source tool's output_key + tool_key = _find_tool_output_key_for_export(cond_node["id"], nodes_by_id, edges) + on_success = condition.get("on_success", "") + on_error = condition.get("on_error", "") + lines.append(f' tool_output = state.get("{tool_key}", {{}})') + lines.append( + ' if isinstance(tool_output, dict) and tool_output.get("success"):' + ) + lines.append(f' return "{on_success}"') + lines.append(f' return "{on_error}"') + + elif ctype == "iteration_limit": + field = condition.get("field", "counter") + max_val = condition.get("max", 3) + exceeded = condition.get("exceeded", "end") + cont = condition.get("continue", "continue") + lines.append(f' if state.get("{field}", 0) >= {max_val}:') + lines.append(f' return "{exceeded}"') + lines.append(f' return "{cont}"') + + elif ctype == "llm_router": + lines.append(" # TODO: Implement LLM routing logic") + lines.append(f' return "{default}"') + + else: + lines.append(f' return "{default}"') + + return "\n".join(lines) + + +def _find_tool_output_key_for_export( + cond_id: str, nodes_by_id: dict, edges: list[dict] +) -> str: + """Find the output_key of the tool node that feeds into a condition.""" + for edge in edges: + if edge["target"] == cond_id: + source = 
nodes_by_id.get(edge["source"]) + if source and source["type"] == "tool": + return source["config"].get("output_key", "result") + return "result" + + +def _build_graph_construction( + nodes: list[dict], + normal_edges: list[dict], + cond_edges: dict[str, dict[str, str]], + nodes_by_id: dict, + start_id: str | None, + end_ids: set[str], + has_human_input: bool, +) -> str: + lines = ["# Build graph", "graph = StateGraph(GraphState)"] + + # Add nodes + for n in nodes: + if n["type"] in ("start", "end"): + continue + lines.append(f'graph.add_node("{n["id"]}", {n["id"]})') + + lines.append("") + + # Add normal edges + for edge in normal_edges: + source = "START" if edge["source"] == start_id else f'"{edge["source"]}"' + target = "END" if edge["target"] in end_ids else f'"{edge["target"]}"' + lines.append(f"graph.add_edge({source}, {target})") + + # Add conditional edges + for cond_id, branch_map in sorted(cond_edges.items()): + translated = {} + for branch, target in branch_map.items(): + translated[branch] = "END" if target in end_ids else target + branch_repr = "{\n" + for branch, target in sorted(translated.items()): + target_val = "END" if target == "END" else f'"{target}"' + branch_repr += f' "{branch}": {target_val},\n' + branch_repr += "}" + lines.append( + f'graph.add_conditional_edges("{cond_id}", route_{cond_id}, {branch_repr})' + ) + + lines.append("") + + # Compile + if has_human_input: + lines.append("# Human input nodes require a checkpointer") + lines.append("compiled = graph.compile(checkpointer=InMemorySaver())") + else: + lines.append("compiled = graph.compile()") + + return "\n".join(lines) + + +def _build_main_block(state_fields: list[dict]) -> str: + defaults = {} + for f in state_fields: + if "default" in f and f["default"] is not None: + defaults[f["key"]] = f["default"] + else: + defaults[f["key"]] = _python_default(f["type"]) + + defaults_repr = "{\n" + for key, val in defaults.items(): + defaults_repr += f' "{key}": {repr(val)},\n' + 
defaults_repr += "}" + + return f"""if __name__ == "__main__": + import asyncio + + result = asyncio.run(compiled.ainvoke({defaults_repr})) + print(result)""" + + +def _python_default(type_name: str) -> object: + return {"string": "", "number": 0.0, "boolean": False, "list": [], "object": {}}[ + type_name + ] + + +def _escape(s: str) -> str: + return s.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n") + + +def _build_requirements(providers: set, tool_names: set) -> str: + reqs = {"langgraph", "langchain-core"} + for p in providers: + if p in _PROVIDER_REQS: + reqs.add(_PROVIDER_REQS[p]) + for t in tool_names: + for r in _TOOL_REQS.get(t, []): + reqs.add(r) + return "\n".join(sorted(reqs)) diff --git a/packages/execution/app/routes/graphs.py b/packages/execution/app/routes/graphs.py index 24b914c..002b3ab 100644 --- a/packages/execution/app/routes/graphs.py +++ b/packages/execution/app/routes/graphs.py @@ -16,6 +16,7 @@ from app.db.connection import get_db from app.schemas.graphs import ( CreateGraphRequest, + ExportResponse, GraphResponse, SchemaValidationError, UpdateGraphRequest, @@ -205,25 +206,23 @@ async def validate_graph( @router.get( "/{graph_id}/export", + response_model=ExportResponse, summary="Export graph as Python code", - responses={ - 404: {"description": "Graph not found"}, - 501: {"description": "Not implemented"}, - }, + responses={404: {"description": "Graph not found"}}, ) -async def export_graph( +async def export_graph_route( graph_id: str, auth: AuthContext = Depends(require_scope("graphs:read")), db=Depends(get_db), -) -> None: - """Export graph as standalone Python code (not yet implemented).""" +) -> ExportResponse: + """Export graph as standalone Python code.""" graph = await crud.get_graph(db, graph_id, owner_id=owner_filter(auth)) if graph is None: raise HTTPException(status_code=404, detail="Graph not found") - raise HTTPException( - status_code=501, - detail="Export not implemented. 
Coming in a future release.", - ) + from app.exporter import export_graph + + result = export_graph(graph.schema_json) + return ExportResponse(code=result["code"], requirements=result["requirements"]) # ── Run History ──────────────────────────────────────────────────────── diff --git a/packages/execution/app/tools/file_read.py b/packages/execution/app/tools/file_read.py new file mode 100644 index 0000000..2e1dd10 --- /dev/null +++ b/packages/execution/app/tools/file_read.py @@ -0,0 +1,77 @@ +"""File read tool — sandboxed text file reading.""" + +from __future__ import annotations + +import os + +from app.tools.base import BaseTool +from app.tools.sandbox import open_sandboxed + +_MAX_FILE_SIZE = 1_000_000 # 1 MB +_MAX_TEXT_LENGTH = 10_000 +_SANDBOX_ROOT = os.environ.get("FILE_SANDBOX_ROOT", "/workspace") + + +class FileReadTool(BaseTool): + name = "file_read" + description = "Read text content from a sandboxed file" + + def run(self, inputs: dict) -> dict: + path = inputs.get("path", "") + if not path: + return { + "success": False, + "error": "No path provided", + "recoverable": False, + } + + try: + fd = open_sandboxed(path, _SANDBOX_ROOT, os.O_RDONLY) + except PermissionError as exc: + return { + "success": False, + "error": str(exc), + "recoverable": False, + } + except OSError as exc: + return { + "success": False, + "error": f"Cannot open file: {exc}", + "recoverable": False, + } + + try: + size = os.fstat(fd).st_size + if size > _MAX_FILE_SIZE: + os.close(fd) + return { + "success": False, + "error": f"File too large: {size} bytes (max {_MAX_FILE_SIZE})", + "recoverable": False, + } + + with os.fdopen(fd, "r", encoding="utf-8") as f: + text = f.read() + except UnicodeDecodeError as exc: + return { + "success": False, + "error": f"Not a UTF-8 text file: {exc}", + "recoverable": False, + } + except OSError as exc: + return { + "success": False, + "error": f"Read error: {exc}", + "recoverable": False, + } + + truncated = len(text) > _MAX_TEXT_LENGTH + if 
truncated: + text = text[:_MAX_TEXT_LENGTH] + + return { + "success": True, + "result": text, + "source": path, + "truncated": truncated, + } diff --git a/packages/execution/app/tools/file_write.py b/packages/execution/app/tools/file_write.py new file mode 100644 index 0000000..70d87f7 --- /dev/null +++ b/packages/execution/app/tools/file_write.py @@ -0,0 +1,89 @@ +"""File write tool — sandboxed text file writing.""" + +from __future__ import annotations + +import os + +from app.tools.base import BaseTool +from app.tools.sandbox import resolve_sandboxed_path + +_MAX_CONTENT_SIZE = 1_000_000 # 1 MB +_SANDBOX_ROOT = os.environ.get("FILE_SANDBOX_ROOT", "/workspace") + + +class FileWriteTool(BaseTool): + name = "file_write" + description = "Write text content to a sandboxed file" + + def run(self, inputs: dict) -> dict: + path = inputs.get("path", "") + content = inputs.get("content", "") + mode = inputs.get("mode", "overwrite") + + if not path: + return { + "success": False, + "error": "No path provided", + "recoverable": False, + } + + if len(content) > _MAX_CONTENT_SIZE: + return { + "success": False, + "error": ( + f"Content too large: {len(content)} bytes (max {_MAX_CONTENT_SIZE})" + ), + "recoverable": False, + } + + resolved = resolve_sandboxed_path(path, _SANDBOX_ROOT) + if resolved is None: + return { + "success": False, + "error": f"Path escapes sandbox: {path}", + "recoverable": False, + } + + # Create parent directories + parent = os.path.dirname(resolved) + os.makedirs(parent, exist_ok=True) + + # Open with O_NOFOLLOW to reject symlinks at the leaf + flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW + if mode == "append": + flags |= os.O_APPEND + else: + flags |= os.O_TRUNC + + try: + fd = os.open(resolved, flags, 0o644) + except OSError as exc: + return { + "success": False, + "error": f"Cannot open file: {exc}", + "recoverable": False, + } + + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(content) + except UnicodeEncodeError as exc: + 
return { + "success": False, + "error": f"Encoding error: {exc}", + "recoverable": False, + } + except OSError as exc: + return { + "success": False, + "error": f"Write error: {exc}", + "recoverable": False, + } + + byte_count = len(content.encode("utf-8")) + return { + "success": True, + "result": f"Written {byte_count} bytes to {path}", + "source": path, + "truncated": False, + } diff --git a/packages/execution/app/tools/registry.py b/packages/execution/app/tools/registry.py index 243274b..698b5b9 100644 --- a/packages/execution/app/tools/registry.py +++ b/packages/execution/app/tools/registry.py @@ -5,7 +5,12 @@ from app.tools.base import BaseTool, ToolNotFoundError from app.tools.calculator import CalculatorTool from app.tools.datetime_tool import DatetimeTool +from app.tools.file_read import FileReadTool +from app.tools.file_write import FileWriteTool from app.tools.url_fetch import UrlFetchTool +from app.tools.weather import WeatherTool +from app.tools.web_search import WebSearchTool +from app.tools.wikipedia_tool import WikipediaTool __all__ = ["BaseTool", "ToolNotFoundError", "REGISTRY", "get_tool"] @@ -13,6 +18,11 @@ "calculator": CalculatorTool(), "datetime": DatetimeTool(), "url_fetch": UrlFetchTool(), + "web_search": WebSearchTool(), + "wikipedia": WikipediaTool(), + "file_read": FileReadTool(), + "file_write": FileWriteTool(), + "weather": WeatherTool(), } diff --git a/packages/execution/app/tools/sandbox.py b/packages/execution/app/tools/sandbox.py new file mode 100644 index 0000000..2bf76c2 --- /dev/null +++ b/packages/execution/app/tools/sandbox.py @@ -0,0 +1,45 @@ +"""File sandbox — path traversal prevention for file_read/file_write.""" + +from __future__ import annotations + +import os + + +def resolve_sandboxed_path(path: str, sandbox_root: str) -> str | None: + """Resolve *path* within *sandbox_root*. + + Returns the absolute resolved path, or ``None`` if the path escapes + the sandbox (via ``../``, symlinks, etc.). 
+ + ``os.path.realpath()`` resolves symlinks in parent directories. + ``O_NOFOLLOW`` on the subsequent open guards the leaf component. + """ + abs_root = os.path.realpath(sandbox_root) + candidate = os.path.realpath(os.path.join(abs_root, path)) + if not candidate.startswith(abs_root + os.sep) and candidate != abs_root: + return None + return candidate + + +def open_sandboxed( + path: str, + sandbox_root: str, + flags: int, + mode: int = 0o644, +) -> int: + """Open a file within the sandbox with ``O_NOFOLLOW``. + + Validates the path first via :func:`resolve_sandboxed_path`, then + opens with ``O_NOFOLLOW`` to reject symlinks at the final component. + + Returns: + Raw file descriptor (caller must close or wrap with ``os.fdopen``). + + Raises: + PermissionError: If the path escapes the sandbox. + OSError: If the file cannot be opened (e.g. symlink at leaf). + """ + resolved = resolve_sandboxed_path(path, sandbox_root) + if resolved is None: + raise PermissionError(f"Path escapes sandbox: {path}") + return os.open(resolved, flags | os.O_NOFOLLOW, mode) diff --git a/packages/execution/app/tools/ssrf_transport.py b/packages/execution/app/tools/ssrf_transport.py new file mode 100644 index 0000000..87f8ea9 --- /dev/null +++ b/packages/execution/app/tools/ssrf_transport.py @@ -0,0 +1,61 @@ +"""SSRF-safe httpx transport — pins DNS resolution to prevent rebinding.""" + +from __future__ import annotations + +import httpcore +import httpx +from httpcore._backends.sync import SyncBackend + + +class PinnedDNSBackend(httpcore.NetworkBackend): + """Network backend that substitutes hostname with a pinned IP in connect_tcp. + + Wraps the default SyncBackend. When ``connect_tcp`` is called, replaces + the ``host`` parameter with the pinned IP while leaving everything else + unchanged. 
This means: + + - TCP connection goes to the pinned IP (SSRF-safe) + - TLS SNI uses the original hostname (httpcore passes it separately) + - Host header uses the original hostname (httpx sets it from the URL) + """ + + def __init__(self, pinned_ip: str) -> None: + self._pinned_ip = pinned_ip + self._backend = SyncBackend() + + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: object = None, + ) -> httpcore.NetworkStream: + return self._backend.connect_tcp( + self._pinned_ip, + port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + +class SSRFSafeTransport(httpx.HTTPTransport): + """httpx transport that pins DNS to a pre-validated IP. + + Subclasses ``HTTPTransport`` to inherit its ``handle_request`` method, + which correctly converts between ``httpx.Request``/``Response`` and + ``httpcore`` types. We only replace the internal connection pool with + one using ``PinnedDNSBackend``. + """ + + def __init__(self, pinned_ip: str, **kwargs: object) -> None: + super().__init__(**kwargs) + # Preserve ssl_context from the pool created by super().__init__ + # (respects verify=, cert=, trust_env= kwargs). Then replace + # the pool with one using our pinned DNS backend. 
+ existing_pool = self._pool + self._pool = httpcore.ConnectionPool( + ssl_context=existing_pool._ssl_context, + network_backend=PinnedDNSBackend(pinned_ip), + ) diff --git a/packages/execution/app/tools/url_fetch.py b/packages/execution/app/tools/url_fetch.py index 941b9f1..8269100 100644 --- a/packages/execution/app/tools/url_fetch.py +++ b/packages/execution/app/tools/url_fetch.py @@ -10,35 +10,50 @@ import trafilatura from app.tools.base import BaseTool +from app.tools.ssrf_transport import SSRFSafeTransport _MAX_TEXT_LENGTH = 10_000 -def validate_url(url: str) -> str | None: - """Return an error string if the URL is unsafe, ``None`` if OK.""" +def validate_url(url: str) -> tuple[str | None, str | None]: + """Validate a URL for SSRF safety. + + Returns: + (error, resolved_ip) — error is set if URL is unsafe, + resolved_ip is the first safe IP address for DNS-pinning. + """ try: parsed = urlparse(url) except Exception: - return f"Invalid URL: {url}" + return (f"Invalid URL: {url}", None) if parsed.scheme not in ("http", "https"): - return f"Only http/https URLs are allowed, got: {parsed.scheme or 'none'}" + return ( + f"Only http/https URLs are allowed, got: {parsed.scheme or 'none'}", + None, + ) hostname = parsed.hostname if not hostname: - return "URL has no hostname" + return ("URL has no hostname", None) try: addr_infos = socket.getaddrinfo(hostname, None) except socket.gaierror: - return f"Cannot resolve hostname: {hostname}" + return (f"Cannot resolve hostname: {hostname}", None) + resolved_ip = None for info in addr_infos: ip = ipaddress.ip_address(info[4][0]) if ip.is_private or ip.is_reserved or ip.is_loopback or ip.is_link_local: - return f"Blocked: {hostname} resolves to private/reserved IP {ip}" + return ( + f"Blocked: {hostname} resolves to private/reserved IP {ip}", + None, + ) + if resolved_ip is None: + resolved_ip = str(ip) - return None + return (None, resolved_ip) class UrlFetchTool(BaseTool): @@ -48,12 +63,15 @@ class UrlFetchTool(BaseTool): 
def run(self, inputs: dict) -> dict: url = inputs.get("url", "") - error = validate_url(url) + error, resolved_ip = validate_url(url) if error: return {"success": False, "error": error, "recoverable": False} try: - with httpx.Client(timeout=10, follow_redirects=False) as client: + transport = SSRFSafeTransport(resolved_ip) + with httpx.Client( + transport=transport, timeout=10, follow_redirects=False + ) as client: response = client.get(url) response.raise_for_status() except httpx.TimeoutException: diff --git a/packages/execution/app/tools/weather.py b/packages/execution/app/tools/weather.py new file mode 100644 index 0000000..38a8766 --- /dev/null +++ b/packages/execution/app/tools/weather.py @@ -0,0 +1,178 @@ +"""Weather tool — current conditions and forecast via Open-Meteo API.""" + +from __future__ import annotations + +import re + +import httpx + +from app.tools.base import BaseTool + +_GEOCODE_URL = "https://geocoding-api.open-meteo.com/v1/search" +_FORECAST_URL = "https://api.open-meteo.com/v1/forecast" +_TIMEOUT = 10 +_LATLON_RE = re.compile(r"^(-?\d+\.?\d*)\s*,\s*(-?\d+\.?\d*)$") + + +class WeatherTool(BaseTool): + name = "weather" + description = "Get current weather or 7-day forecast for a location" + + def run(self, inputs: dict) -> dict: + location = inputs.get("location", "").strip() + action = inputs.get("action", "current") + + if not location: + return { + "success": False, + "error": "No location provided", + "recoverable": False, + } + + lat, lon = self._parse_latlon(location) + if lat is None: + result = self._geocode(location) + if not result["success"]: + return result + lat, lon = result["lat"], result["lon"] + + if action == "current": + return self._current(lat, lon, location) + if action == "forecast": + return self._forecast(lat, lon, location) + + return { + "success": False, + "error": f"Unknown action: {action}", + "recoverable": False, + } + + def _parse_latlon(self, location: str) -> tuple[float | None, float | None]: + match = 
_LATLON_RE.match(location) + if not match: + return None, None + lat, lon = float(match.group(1)), float(match.group(2)) + if not (-90 <= lat <= 90 and -180 <= lon <= 180): + return None, None + return lat, lon + + def _geocode(self, location: str) -> dict: + try: + resp = httpx.get( + _GEOCODE_URL, + params={"name": location, "count": 1, "format": "json"}, + timeout=_TIMEOUT, + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if not results: + return { + "success": False, + "error": f"Location not found: {location}", + "recoverable": False, + } + return { + "success": True, + "lat": results[0]["latitude"], + "lon": results[0]["longitude"], + } + except httpx.TimeoutException: + return { + "success": False, + "error": "Geocoding timed out", + "recoverable": True, + } + except httpx.HTTPError as exc: + return { + "success": False, + "error": f"Geocoding error: {exc}", + "recoverable": True, + } + + def _current(self, lat: float, lon: float, location: str) -> dict: + try: + resp = httpx.get( + _FORECAST_URL, + params={ + "latitude": lat, + "longitude": lon, + "current": "temperature_2m,relative_humidity_2m," + "wind_speed_10m,weather_code", + }, + timeout=_TIMEOUT, + ) + resp.raise_for_status() + data = resp.json() + current = data.get("current", {}) + text = ( + f"Location: {location}\n" + f"Temperature: {current.get('temperature_2m', 'N/A')}°C\n" + f"Humidity: {current.get('relative_humidity_2m', 'N/A')}%\n" + f"Wind: {current.get('wind_speed_10m', 'N/A')} km/h\n" + f"Weather code: {current.get('weather_code', 'N/A')}" + ) + return { + "success": True, + "result": text, + "source": "open-meteo", + "truncated": False, + } + except httpx.TimeoutException: + return { + "success": False, + "error": "Weather request timed out", + "recoverable": True, + } + except httpx.HTTPError as exc: + return { + "success": False, + "error": f"Weather error: {exc}", + "recoverable": True, + } + + def _forecast(self, lat: float, lon: float, 
location: str) -> dict: + try: + resp = httpx.get( + _FORECAST_URL, + params={ + "latitude": lat, + "longitude": lon, + "daily": "temperature_2m_max,temperature_2m_min,weather_code", + "forecast_days": 7, + }, + timeout=_TIMEOUT, + ) + resp.raise_for_status() + data = resp.json() + daily = data.get("daily", {}) + dates = daily.get("time", []) + highs = daily.get("temperature_2m_max", []) + lows = daily.get("temperature_2m_min", []) + codes = daily.get("weather_code", []) + + lines = [f"7-day forecast for {location}:"] + for i, date in enumerate(dates): + high = highs[i] if i < len(highs) else "N/A" + low = lows[i] if i < len(lows) else "N/A" + code = codes[i] if i < len(codes) else "N/A" + lines.append(f" {date}: {low}°C – {high}°C (code {code})") + + return { + "success": True, + "result": "\n".join(lines), + "source": "open-meteo", + "truncated": False, + } + except httpx.TimeoutException: + return { + "success": False, + "error": "Forecast request timed out", + "recoverable": True, + } + except httpx.HTTPError as exc: + return { + "success": False, + "error": f"Forecast error: {exc}", + "recoverable": True, + } diff --git a/packages/execution/app/tools/web_search.py b/packages/execution/app/tools/web_search.py new file mode 100644 index 0000000..fdbbd12 --- /dev/null +++ b/packages/execution/app/tools/web_search.py @@ -0,0 +1,99 @@ +"""Web search tool — Tavily with DuckDuckGo fallback.""" + +from __future__ import annotations + +import os + +from app.tools.base import BaseTool + +_MAX_RESULTS = 10 + + +class WebSearchTool(BaseTool): + name = "web_search" + description = "Search the web and return results with titles, URLs, and snippets" + + def run(self, inputs: dict) -> dict: + query = inputs.get("query", "").strip() + if not query: + return { + "success": False, + "error": "Empty search query", + "recoverable": False, + } + + max_results = min(int(inputs.get("max_results", 5)), _MAX_RESULTS) + + tavily_key = os.environ.get("TAVILY_API_KEY") + if tavily_key: + 
class WebSearchTool(BaseTool):
    """Web search: Tavily when an API key is configured, DuckDuckGo otherwise.

    The module contract is "Tavily with DuckDuckGo fallback": a Tavily failure
    degrades to DuckDuckGo rather than surfacing an error to the caller.
    """

    name = "web_search"
    description = "Search the web and return results with titles, URLs, and snippets"

    def run(self, inputs: dict) -> dict:
        """Run a search and return a formatted, numbered result list.

        Args:
            inputs: {"query": str, "max_results": int (optional, default 5)}.

        Returns:
            Tool-result dict with success/result/source/truncated or
            success/error/recoverable on failure.
        """
        query = inputs.get("query", "").strip()
        if not query:
            return {
                "success": False,
                "error": "Empty search query",
                "recoverable": False,
            }

        # Clamp to [1, _MAX_RESULTS]; a malformed or non-positive value must
        # not crash the tool or zero out the search.
        try:
            requested = int(inputs.get("max_results", 5))
        except (TypeError, ValueError):
            requested = 5
        max_results = max(1, min(requested, _MAX_RESULTS))

        tavily_key = os.environ.get("TAVILY_API_KEY")
        if tavily_key:
            result = self._search_tavily(query, max_results, tavily_key)
            if result["success"]:
                return result
            # Tavily is the preferred backend, but fall back to DuckDuckGo on
            # failure instead of returning the error (the old code only used
            # DDG when the key was absent, never on a Tavily error).
            return self._search_ddg(query, max_results)
        return self._search_ddg(query, max_results)

    def _search_tavily(self, query: str, max_results: int, api_key: str) -> dict:
        """Query the Tavily API; import is deferred so the dependency is
        only required when a key is configured."""
        try:
            from tavily import TavilyClient

            client = TavilyClient(api_key=api_key)
            response = client.search(query, max_results=max_results)
            results = [
                {
                    "title": r.get("title", ""),
                    "url": r.get("url", ""),
                    "snippet": r.get("content", ""),
                }
                for r in response.get("results", [])
            ]
            return {
                "success": True,
                "result": _format_results(results),
                "source": "tavily",
                "truncated": False,
            }
        except Exception as exc:
            return {
                "success": False,
                "error": f"Tavily search error: {exc}",
                "recoverable": True,
            }

    def _search_ddg(self, query: str, max_results: int) -> dict:
        """Query DuckDuckGo — the keyless fallback backend."""
        try:
            # duckduckgo-search >=6.0,<8.0 — DDGS().text(keywords, max_results=N)
            from duckduckgo_search import DDGS

            with DDGS() as ddgs:
                raw = ddgs.text(query, max_results=max_results)
                results = [
                    {
                        "title": r.get("title", ""),
                        "url": r.get("href", ""),
                        "snippet": r.get("body", ""),
                    }
                    for r in raw
                ]
            return {
                "success": True,
                "result": _format_results(results),
                "source": "duckduckgo",
                "truncated": False,
            }
        except Exception as exc:
            return {
                "success": False,
                "error": f"DuckDuckGo search error: {exc}",
                "recoverable": True,
            }


def _format_results(results: list[dict]) -> str:
    """Format search results as readable text: a numbered title line followed
    by indented URL and optional snippet lines, one blank line between hits."""
    if not results:
        return "No results found."
    lines = []
    for i, r in enumerate(results, 1):
        lines.append(f"{i}. {r['title']}")
        lines.append(f" {r['url']}")
        if r["snippet"]:
            lines.append(f" {r['snippet']}")
        lines.append("")
    return "\n".join(lines).strip()
as exc: + return { + "success": False, + "error": f"Wikipedia search error: {exc}", + "recoverable": True, + } + + def _get_page(self, title: str) -> dict: + if not title.strip(): + return { + "success": False, + "error": "No title provided", + "recoverable": False, + } + + wiki = wikipediaapi.Wikipedia(user_agent=_USER_AGENT, language="en") + try: + page = wiki.page(title) + except Exception as exc: + return { + "success": False, + "error": f"Wikipedia API error: {exc}", + "recoverable": True, + } + + if not page.exists(): + return { + "success": False, + "error": f"Page not found: {title}", + "recoverable": False, + } + + text = page.summary + full_text = page.text + if full_text and len(full_text) > len(text): + text = full_text + + truncated = len(text) > _MAX_TEXT_LENGTH + if truncated: + text = text[:_MAX_TEXT_LENGTH] + + return { + "success": True, + "result": text, + "source": f"wikipedia:{title}", + "truncated": truncated, + } diff --git a/packages/execution/pyproject.toml b/packages/execution/pyproject.toml index 6e61405..4f8fd12 100644 --- a/packages/execution/pyproject.toml +++ b/packages/execution/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "trafilatura>=2.0.0", "pydantic>=2.11.0", "aiosqlite>=0.22.1", + "tavily-python>=0.5.0,<1.0", + "duckduckgo-search>=6.0.0,<8.0", + "Wikipedia-API>=0.7.0", ] [dependency-groups] diff --git a/packages/execution/tests/manual/test_30_url_fetch_real.py b/packages/execution/tests/manual/test_30_url_fetch_real.py new file mode 100644 index 0000000..8e2201b --- /dev/null +++ b/packages/execution/tests/manual/test_30_url_fetch_real.py @@ -0,0 +1,48 @@ +"""Manual test 30: url_fetch with real HTTPS — verifies SSRFSafeTransport end-to-end. + +Tests that DNS pinning works with real TLS/SNI against a public URL. 
def main():
    """Drive url_fetch against live endpoints plus local SSRF negatives."""
    print("── Test 30: url_fetch with real HTTPS ──")

    # Positive: a public host must validate and yield a pinned IP.
    error, ip = validate_url("https://httpbin.org/get")
    assert error is None, f"validate_url failed: {error}"
    assert ip is not None, "No resolved IP returned"
    print(f" ✓ Resolved httpbin.org → {ip}")

    tool = UrlFetchTool()

    # Positive: a real fetch travels through the DNS-pinned transport.
    fetched = tool.run({"url": "https://httpbin.org/html"})
    assert fetched["success"] is True, f"Fetch failed: {fetched.get('error')}"
    assert len(fetched["result"]) > 0, "Empty result"
    print(f" ✓ Fetched {len(fetched['result'])} chars via pinned transport")

    # Negative: loopback and cloud-metadata addresses are refused up front.
    error, ip = validate_url("http://127.0.0.1/secret")
    assert error is not None, "localhost should be blocked"
    assert ip is None
    print(f" ✓ Localhost blocked: {error}")

    error, ip = validate_url("http://169.254.169.254/latest/meta-data/")
    assert error is not None, "metadata IP should be blocked"
    print(f" ✓ Metadata IP blocked: {error}")

    # Negative: non-HTTP scheme is a hard (non-recoverable) failure.
    rejected = tool.run({"url": "ftp://example.com"})
    assert rejected["success"] is False
    assert rejected["recoverable"] is False
    print(" ✓ FTP scheme rejected")

    print("\n✅ All url_fetch real tests passed")
def main():
    """Exercise the DuckDuckGo fallback path of web_search (no API key)."""
    print("── Test 31: web_search DDG fallback ──")

    # Ensure no Tavily key so the DDG fallback path is what gets exercised.
    os.environ.pop("TAVILY_API_KEY", None)

    tool = WebSearchTool()

    # 1. Basic search
    result = tool.run({"query": "Python programming language"})
    assert result["success"] is True, f"Search failed: {result.get('error')}"
    assert result["source"] == "duckduckgo"
    assert len(result["result"]) > 0
    print(f" ✓ DDG returned results ({len(result['result'])} chars)")
    # Show first 200 chars
    print(f" Preview: {result['result'][:200]}...")

    # 2. max_results respected
    result = tool.run({"query": "LangGraph framework", "max_results": 2})
    assert result["success"] is True
    # Result titles are numbered at column 0 ("1. ...", "2. ...");
    # continuation lines (URL/snippet) start with a space. The old check
    # matched only "1."–"3." and allowed len <= 3, so it could never catch
    # a violation of max_results=2.
    numbered = [
        line for line in result["result"].split("\n") if line[:1].isdigit()
    ]
    assert len(numbered) <= 2, f"Expected ≤2 results: {numbered}"
    print(" ✓ max_results=2 respected")

    # 3. Empty query rejected
    result = tool.run({"query": ""})
    assert result["success"] is False
    assert result["recoverable"] is False
    print(" ✓ Empty query rejected")

    print("\n✅ All DDG search tests passed")
def main():
    """Smoke-test the Tavily search path; skips when no API key is set."""
    print("── Test 32: web_search Tavily ──")

    if not os.environ.get("TAVILY_API_KEY"):
        # Graceful skip so environments without credentials still exit 0.
        print(" ⚠ TAVILY_API_KEY not set — skipping")
        print(" Set TAVILY_API_KEY and re-run.")
        sys.exit(0)

    searcher = WebSearchTool()

    outcome = searcher.run({"query": "GraphWeave LangGraph", "max_results": 3})
    assert outcome["success"] is True, f"Tavily search failed: {outcome.get('error')}"
    assert outcome["source"] == "tavily"
    assert len(outcome["result"]) > 0
    print(f" ✓ Tavily returned results ({len(outcome['result'])} chars)")
    print(f" Preview: {outcome['result'][:300]}...")

    print("\n✅ Tavily search test passed")
def main():
    """Drive the weather tool against the live Open-Meteo API."""
    print("── Test 34: Weather real API ──")

    weather = WeatherTool()

    # Current conditions via geocoded city name.
    reply = weather.run({"location": "London", "action": "current"})
    assert reply["success"] is True, f"Current failed: {reply.get('error')}"
    assert "Temperature" in reply["result"]
    print(f" ✓ Current weather:\n {reply['result']}")

    # 7-day forecast via geocoded city name.
    reply = weather.run({"location": "Tokyo", "action": "forecast"})
    assert reply["success"] is True, f"Forecast failed: {reply.get('error')}"
    assert "forecast" in reply["result"].lower()
    print(f" ✓ Forecast:\n {reply['result'][:300]}")

    # Literal "lat, lon" input (Sydney) bypasses geocoding entirely.
    reply = weather.run({"location": "-33.87, 151.21", "action": "current"})
    assert reply["success"] is True, f"Lat/lon failed: {reply.get('error')}"
    print(f" ✓ Lat/lon weather:\n {reply['result']}")

    # A nonsense place name must fail hard (not retryable).
    reply = weather.run({"location": "Xyzzyville99"})
    assert reply["success"] is False
    assert reply["recoverable"] is False
    print(f" ✓ Unknown location: {reply['error']}")

    print("\n✅ All weather tests passed")
Nested directory creation + result = writer.run( + { + "path": "deep/nested/dir/file.txt", + "content": "deep content", + } + ) + assert result["success"] is True + assert os.path.exists(os.path.join(sandbox, "deep/nested/dir/file.txt")) + print(" ✓ Nested dirs created") + + # 3. Append mode + writer.run({"path": "log.txt", "content": "line1\n"}) + writer.run({"path": "log.txt", "content": "line2\n", "mode": "append"}) + result = reader.run({"path": "log.txt"}) + assert result["result"] == "line1\nline2\n" + print(" ✓ Append mode works") + + # 4. Path traversal blocked + result = reader.run({"path": "../../etc/passwd"}) + assert result["success"] is False + print(f" ✓ Traversal blocked: {result['error']}") + + result = writer.run({"path": "../../../tmp/evil", "content": "pwned"}) + assert result["success"] is False + print(f" ✓ Write traversal blocked: {result['error']}") + + # 5. Symlink escape + outside = os.path.join(sandbox, "..", "outside_file.txt") + with open(outside, "w") as f: + f.write("outside") + link_path = os.path.join(sandbox, "escape_link.txt") + os.symlink(outside, link_path) + + result = reader.run({"path": "escape_link.txt"}) + assert result["success"] is False + print(" ✓ Symlink escape blocked (read)") + + result = writer.run({"path": "escape_link.txt", "content": "overwrite"}) + assert result["success"] is False + # Verify original file untouched + with open(outside) as f: + assert f.read() == "outside" + print(" ✓ Symlink escape blocked (write)") + + # Cleanup outside file + os.unlink(outside) + + print("\n✅ All file sandbox tests passed") + + +if __name__ == "__main__": + main() diff --git a/packages/execution/tests/manual/test_36_exporter_exec.py b/packages/execution/tests/manual/test_36_exporter_exec.py new file mode 100644 index 0000000..2e07bdc --- /dev/null +++ b/packages/execution/tests/manual/test_36_exporter_exec.py @@ -0,0 +1,204 @@ +"""Manual test 36: Export and exec — verify generated code is executable. 
def _linear_schema():
    """Schema for a minimal linear flow: start → calculator tool → end."""
    start_node = {
        "id": "s",
        "type": "start",
        "label": "S",
        "position": {"x": 0, "y": 0},
        "config": {},
    }
    calc_node = {
        "id": "tool_1",
        "type": "tool",
        "label": "Calc",
        "position": {"x": 0, "y": 100},
        "config": {
            "tool_name": "calculator",
            "input_map": {"expression": "result"},
            "output_key": "result",
        },
    }
    end_node = {
        "id": "e",
        "type": "end",
        "label": "E",
        "position": {"x": 0, "y": 200},
        "config": {},
    }
    return {
        "id": "linear",
        "name": "Linear",
        "version": 1,
        "state": [
            {"key": "messages", "type": "list", "reducer": "append"},
            {"key": "result", "type": "string", "reducer": "replace"},
        ],
        "nodes": [start_node, calc_node, end_node],
        "edges": [
            {"id": "e1", "source": "s", "target": "tool_1"},
            {"id": "e2", "source": "tool_1", "target": "e"},
        ],
        "metadata": {},
    }
+ { + "id": "ask", + "type": "human_input", + "label": "Ask", + "position": {"x": 0, "y": 300}, + "config": { + "prompt": "Confirm?", + "input_key": "answer", + }, + }, + { + "id": "cond_1", + "type": "condition", + "label": "Check", + "position": {"x": 0, "y": 400}, + "config": { + "condition": { + "type": "field_equals", + "field": "answer", + "value": "yes", + "branch": "done", + }, + "default_branch": "retry", + }, + }, + { + "id": "e", + "type": "end", + "label": "E", + "position": {"x": 0, "y": 500}, + "config": {}, + }, + ], + "edges": [ + {"id": "e1", "source": "s", "target": "llm_1"}, + {"id": "e2", "source": "llm_1", "target": "tool_1"}, + {"id": "e3", "source": "tool_1", "target": "ask"}, + {"id": "e4", "source": "ask", "target": "cond_1"}, + {"id": "e5", "source": "cond_1", "target": "e", "condition_branch": "done"}, + { + "id": "e6", + "source": "cond_1", + "target": "llm_1", + "condition_branch": "retry", + }, + ], + "metadata": {}, + } + + +def main(): + print("── Test 36: Exporter exec ──") + + # 1. Linear graph — compile + exec + result = export_graph(_linear_schema()) + compile(result["code"], "", "exec") + ns = {} + exec(compile(result["code"], "", "exec"), ns) + assert "compiled" in ns, "No 'compiled' graph in namespace" + assert "GraphState" in ns, "No 'GraphState' class in namespace" + print(" ✓ Linear graph: compiles and exec's, 'compiled' graph exists") + + # 2. Complex graph — compile + AST check + result = export_graph(_complex_schema()) + compile(result["code"], "", "exec") + + tree = ast.parse(result["code"]) + names = { + node.name + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + } + expected = {"GraphState", "llm_1", "tool_1", "ask", "cond_1", "route_cond_1"} + missing = expected - names + assert not missing, f"Missing definitions: {missing}" + print(f" ✓ Complex graph: all {len(expected)} definitions present") + + # 3. 
Check requirements + assert "langgraph" in result["requirements"] + assert "langchain-openai" in result["requirements"] + assert "simpleeval" in result["requirements"] + print(f" ✓ Requirements:\n {result['requirements']}") + + # 4. State class has TypedDict + assert "class GraphState(TypedDict):" in result["code"] + assert "Annotated[list, add_messages]" in result["code"] + assert "Annotated[dict, _merge_reducer]" in result["code"] + print(" ✓ State class: TypedDict with correct reducers") + + # 5. Human input → checkpointer + assert "InMemorySaver" in result["code"] + assert "interrupt" in result["code"] + print(" ✓ Human input: InMemorySaver + interrupt present") + + print(f"\n Generated code ({len(result['code'])} chars):") + for line in result["code"].split("\n")[:30]: + print(f" {line}") + print(" ...") + + print("\n✅ All exporter exec tests passed") + + +if __name__ == "__main__": + main() diff --git a/packages/execution/tests/manual/test_37_export_route.py b/packages/execution/tests/manual/test_37_export_route.py new file mode 100644 index 0000000..4f82288 --- /dev/null +++ b/packages/execution/tests/manual/test_37_export_route.py @@ -0,0 +1,115 @@ +"""Manual test 37: Export route end-to-end — hit real API, verify response. 
def _schema():
    """Linear calculator graph used to exercise the export route."""
    nodes = [
        {
            "id": "s",
            "type": "start",
            "label": "S",
            "position": {"x": 0, "y": 0},
            "config": {},
        },
        {
            "id": "tool_1",
            "type": "tool",
            "label": "Calc",
            "position": {"x": 0, "y": 100},
            "config": {
                "tool_name": "calculator",
                "input_map": {"expression": "result"},
                "output_key": "result",
            },
        },
        {
            "id": "e",
            "type": "end",
            "label": "E",
            "position": {"x": 0, "y": 200},
            "config": {},
        },
    ]
    return {
        "id": "route_test",
        "name": "RouteTest",
        "version": 1,
        "state": [
            {"key": "messages", "type": "list", "reducer": "append"},
            {"key": "result", "type": "string", "reducer": "replace"},
        ],
        "nodes": nodes,
        "edges": [
            {"id": "e1", "source": "s", "target": "tool_1"},
            {"id": "e2", "source": "tool_1", "target": "e"},
        ],
        "metadata": {},
    }
def _base_schema(**overrides):
    """Minimal valid schema (start → end).

    Keyword arguments replace top-level keys of the base schema, e.g.
    ``_base_schema(state=[...])``.
    """
    base = {
        "id": "test",
        "name": "Test",
        "version": 1,
        "state": [
            {"key": "messages", "type": "list", "reducer": "append"},
            {"key": "result", "type": "string", "reducer": "replace"},
        ],
        "nodes": [
            {
                "id": "s",
                "type": "start",
                "label": "Start",
                "position": {"x": 0, "y": 0},
                "config": {},
            },
            {
                "id": "e",
                "type": "end",
                "label": "End",
                "position": {"x": 0, "y": 200},
                "config": {},
            },
        ],
        "edges": [{"id": "e1", "source": "s", "target": "e"}],
        "metadata": {},
    }
    return {**base, **overrides}
"position": {"x": 0, "y": 100}, + "config": { + "provider": provider, + "model": model, + "temperature": 0.7, + "max_tokens": 1024, + "input_map": {"question": "result"}, + "output_key": "result", + "system_prompt": "You are helpful.", + }, + }, + ) + # Rewire: s → llm → e + schema["edges"] = [ + {"id": "e1", "source": "s", "target": node_id}, + {"id": "e2", "source": node_id, "target": "e"}, + ] + + +def _add_tool_node(schema, node_id="tool_1", tool_name="calculator"): + schema["nodes"].insert( + -1, + { + "id": node_id, + "type": "tool", + "label": "Tool", + "position": {"x": 0, "y": 100}, + "config": { + "tool_name": tool_name, + "input_map": {"expression": "result"}, + "output_key": "result", + }, + }, + ) + schema["edges"] = [ + {"id": "e1", "source": "s", "target": node_id}, + {"id": "e2", "source": node_id, "target": "e"}, + ] + + +def test_export_linear_graph(): + schema = _base_schema() + _add_llm_node(schema) + result = export_graph(schema) + + assert "code" in result + assert "requirements" in result + assert "class GraphState(TypedDict):" in result["code"] + assert "async def llm_1" in result["code"] + assert "compiled = graph.compile()" in result["code"] + + +def test_export_with_tool_node(): + schema = _base_schema() + _add_tool_node(schema) + result = export_graph(schema) + + assert "def tool_1" in result["code"] + assert "calculator" in result["code"] + + +def test_export_with_condition(): + schema = _base_schema() + _add_tool_node(schema, "tool_1") + schema["nodes"].insert( + -1, + { + "id": "cond_1", + "type": "condition", + "label": "Check", + "position": {"x": 0, "y": 150}, + "config": { + "condition": { + "type": "field_equals", + "field": "result", + "value": "42", + "branch": "match", + }, + "default_branch": "no_match", + }, + }, + ) + schema["edges"] = [ + {"id": "e1", "source": "s", "target": "tool_1"}, + {"id": "e2", "source": "tool_1", "target": "cond_1"}, + {"id": "e3", "source": "cond_1", "target": "e", "condition_branch": "match"}, + 
{"id": "e4", "source": "cond_1", "target": "e", "condition_branch": "no_match"}, + ] + result = export_graph(schema) + + assert "def route_cond_1" in result["code"] + assert "add_conditional_edges" in result["code"] + + +def test_export_with_human_input(): + schema = _base_schema() + schema["nodes"].insert( + -1, + { + "id": "ask", + "type": "human_input", + "label": "Ask", + "position": {"x": 0, "y": 100}, + "config": {"prompt": "Enter value", "input_key": "result"}, + }, + ) + schema["edges"] = [ + {"id": "e1", "source": "s", "target": "ask"}, + {"id": "e2", "source": "ask", "target": "e"}, + ] + result = export_graph(schema) + + assert "interrupt" in result["code"] + assert "InMemorySaver" in result["code"] + assert "checkpointer" in result["code"] + + +def test_export_requirements_openai(): + schema = _base_schema() + _add_llm_node(schema, provider="openai") + result = export_graph(schema) + + assert "langchain-openai" in result["requirements"] + assert "langgraph" in result["requirements"] + assert "langchain-core" in result["requirements"] + + +def test_export_requirements_multi_provider(): + schema = _base_schema() + _add_llm_node(schema, "llm_1", provider="openai") + # Add a second LLM node with anthropic + schema["nodes"].insert( + -1, + { + "id": "llm_2", + "type": "llm", + "label": "LLM2", + "position": {"x": 100, "y": 100}, + "config": { + "provider": "anthropic", + "model": "claude-sonnet-4-6", + "input_map": {}, + "output_key": "result", + }, + }, + ) + result = export_graph(schema) + + assert "langchain-openai" in result["requirements"] + assert "langchain-anthropic" in result["requirements"] + + +def test_export_requirements_no_llm(): + schema = _base_schema() + _add_tool_node(schema, tool_name="calculator") + result = export_graph(schema) + + assert "langchain-openai" not in result["requirements"] + assert "langchain-anthropic" not in result["requirements"] + assert "langgraph" in result["requirements"] + assert "simpleeval" in 
result["requirements"] + + +def test_export_state_typeddict(): + schema = _base_schema( + state=[ + {"key": "messages", "type": "list", "reducer": "append"}, + {"key": "result", "type": "string", "reducer": "replace"}, + {"key": "data", "type": "object", "reducer": "merge"}, + {"key": "items", "type": "list", "reducer": "append"}, + ] + ) + result = export_graph(schema) + + assert "class GraphState(TypedDict):" in result["code"] + assert "Annotated[list, add_messages]" in result["code"] + assert "result: str" in result["code"] + assert "Annotated[dict, _merge_reducer]" in result["code"] + assert "Annotated[list, operator.add]" in result["code"] + + +def test_export_code_compiles(): + schema = _base_schema() + _add_llm_node(schema) + result = export_graph(schema) + + # Should not raise SyntaxError + compile(result["code"], "", "exec") + + +def test_export_code_ast_structure(): + schema = _base_schema() + _add_llm_node(schema) + _add_tool_node(schema, "tool_1") + # Rewire: s → llm_1 → tool_1 → e + schema["edges"] = [ + {"id": "e1", "source": "s", "target": "llm_1"}, + {"id": "e2", "source": "llm_1", "target": "tool_1"}, + {"id": "e3", "source": "tool_1", "target": "e"}, + ] + result = export_graph(schema) + + tree = ast.parse(result["code"]) + names = { + node.name + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + } + + assert "GraphState" in names + assert "llm_1" in names + assert "tool_1" in names + + +def test_export_complex_graph(): + """Graph with LLM + tool + condition + human_input — all node types.""" + schema = _base_schema( + state=[ + {"key": "messages", "type": "list", "reducer": "append"}, + {"key": "result", "type": "string", "reducer": "replace"}, + {"key": "answer", "type": "string", "reducer": "replace"}, + ] + ) + schema["nodes"] = [ + { + "id": "s", + "type": "start", + "label": "S", + "position": {"x": 0, "y": 0}, + "config": {}, + }, + { + "id": "llm_1", + "type": "llm", + "label": "LLM", 
+ "position": {"x": 0, "y": 100}, + "config": { + "provider": "openai", + "model": "gpt-4o", + "input_map": {"q": "result"}, + "output_key": "result", + "temperature": 0.5, + "max_tokens": 512, + }, + }, + { + "id": "tool_1", + "type": "tool", + "label": "Calc", + "position": {"x": 0, "y": 200}, + "config": { + "tool_name": "calculator", + "input_map": {"expression": "result"}, + "output_key": "result", + }, + }, + { + "id": "ask", + "type": "human_input", + "label": "Ask", + "position": {"x": 0, "y": 300}, + "config": { + "prompt": "Confirm?", + "input_key": "answer", + }, + }, + { + "id": "cond_1", + "type": "condition", + "label": "Check", + "position": {"x": 0, "y": 400}, + "config": { + "condition": { + "type": "field_equals", + "field": "answer", + "value": "yes", + "branch": "done", + }, + "default_branch": "retry", + }, + }, + { + "id": "e", + "type": "end", + "label": "E", + "position": {"x": 0, "y": 500}, + "config": {}, + }, + ] + schema["edges"] = [ + {"id": "e1", "source": "s", "target": "llm_1"}, + {"id": "e2", "source": "llm_1", "target": "tool_1"}, + {"id": "e3", "source": "tool_1", "target": "ask"}, + {"id": "e4", "source": "ask", "target": "cond_1"}, + {"id": "e5", "source": "cond_1", "target": "e", "condition_branch": "done"}, + { + "id": "e6", + "source": "cond_1", + "target": "llm_1", + "condition_branch": "retry", + }, + ] + + result = export_graph(schema) + + # Should compile + compile(result["code"], "", "exec") + + # All node functions present + tree = ast.parse(result["code"]) + names = { + node.name + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + } + assert "GraphState" in names + assert "llm_1" in names + assert "tool_1" in names + assert "ask" in names + assert "cond_1" in names + assert "route_cond_1" in names + + # Has checkpointer (human_input present) + assert "InMemorySaver" in result["code"] diff --git a/packages/execution/tests/unit/test_routes.py 
b/packages/execution/tests/unit/test_routes.py index 1f5d953..5d22f21 100644 --- a/packages/execution/tests/unit/test_routes.py +++ b/packages/execution/tests/unit/test_routes.py @@ -464,12 +464,14 @@ async def test_validate_wrong_owner(client, admin_key): assert resp.status_code == 404 -async def test_export_returns_501(client, user_key): +async def test_export_returns_200(client, user_key): _, raw = user_key gid = await _create_graph(client, raw) resp = await client.get(f"/v1/graphs/{gid}/export", headers=_headers(raw)) - assert resp.status_code == 501 - assert "not implemented" in resp.json()["detail"].lower() + assert resp.status_code == 200 + body = resp.json() + assert "code" in body + assert "requirements" in body async def test_export_graph_not_found(client, user_key): diff --git a/packages/execution/tests/unit/test_routes_export.py b/packages/execution/tests/unit/test_routes_export.py new file mode 100644 index 0000000..414907c --- /dev/null +++ b/packages/execution/tests/unit/test_routes_export.py @@ -0,0 +1,117 @@ +"""Integration tests for the export route.""" + +from __future__ import annotations + +import aiosqlite +import httpx +import pytest + +from app.auth import SCOPES_DEFAULT +from app.db.migrations.runner import run_migrations +from app.executor import RunManager +from app.main import app +from tests.conftest import create_test_key + + +def _simple_schema(): + return { + "id": "exp", + "name": "ExportTest", + "version": 1, + "state": [ + {"key": "messages", "type": "list", "reducer": "append"}, + {"key": "result", "type": "string", "reducer": "replace"}, + ], + "nodes": [ + { + "id": "s", + "type": "start", + "label": "Start", + "position": {"x": 0, "y": 0}, + "config": {}, + }, + { + "id": "tool_1", + "type": "tool", + "label": "Calc", + "position": {"x": 0, "y": 100}, + "config": { + "tool_name": "calculator", + "input_map": {"expression": "result"}, + "output_key": "result", + }, + }, + { + "id": "e", + "type": "end", + "label": "End", + 
"position": {"x": 0, "y": 200}, + "config": {}, + }, + ], + "edges": [ + {"id": "e1", "source": "s", "target": "tool_1"}, + {"id": "e2", "source": "tool_1", "target": "e"}, + ], + "metadata": { + "created_at": "2026-01-01", + "updated_at": "2026-01-01", + }, + } + + +@pytest.fixture(autouse=True) +def _env(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy-key") + + +@pytest.fixture +async def client(tmp_path): + db_path = str(tmp_path / "test.db") + run_migrations(db_path) + db = await aiosqlite.connect(db_path) + db.row_factory = aiosqlite.Row + app.state.db = db + app.state.run_manager = RunManager() + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + await db.close() + + +@pytest.fixture +async def api_key(client): + db = app.state.db + key, raw = await create_test_key(db, scopes=SCOPES_DEFAULT, name="user") + return key, raw + + +def _headers(raw_key: str) -> dict: + return {"X-API-Key": raw_key} + + +async def _create_graph(client, raw_key): + resp = await client.post( + "/v1/graphs", + headers=_headers(raw_key), + json={"name": "Test", "schema_json": _simple_schema()}, + ) + return resp.json()["id"] + + +async def test_export_route_returns_200(client, api_key): + _, raw = api_key + gid = await _create_graph(client, raw) + resp = await client.get(f"/v1/graphs/{gid}/export", headers=_headers(raw)) + assert resp.status_code == 200 + body = resp.json() + assert "code" in body + assert "requirements" in body + assert "class GraphState(TypedDict):" in body["code"] + assert "langgraph" in body["requirements"] + + +async def test_export_route_not_found(client, api_key): + _, raw = api_key + resp = await client.get("/v1/graphs/nonexistent/export", headers=_headers(raw)) + assert resp.status_code == 404 diff --git a/packages/execution/tests/unit/test_tools/test_file_read.py b/packages/execution/tests/unit/test_tools/test_file_read.py new file mode 100644 index 
0000000..073fc39 --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_file_read.py @@ -0,0 +1,89 @@ +"""File read tool tests.""" + +from __future__ import annotations + +import os + +from app.tools.file_read import FileReadTool + + +def _read(inputs: dict, sandbox: str) -> dict: + tool = FileReadTool() + orig = os.environ.get("FILE_SANDBOX_ROOT") + os.environ["FILE_SANDBOX_ROOT"] = sandbox + # Reload module-level default + import app.tools.file_read as mod + + mod._SANDBOX_ROOT = sandbox + try: + return tool.run(inputs) + finally: + if orig is not None: + os.environ["FILE_SANDBOX_ROOT"] = orig + elif "FILE_SANDBOX_ROOT" in os.environ: + del os.environ["FILE_SANDBOX_ROOT"] + + +def test_read_file(tmp_path): + (tmp_path / "hello.txt").write_text("Hello world", encoding="utf-8") + result = _read({"path": "hello.txt"}, str(tmp_path)) + assert result["success"] is True + assert result["result"] == "Hello world" + assert result["truncated"] is False + + +def test_read_empty_file(tmp_path): + (tmp_path / "empty.txt").write_text("", encoding="utf-8") + result = _read({"path": "empty.txt"}, str(tmp_path)) + assert result["success"] is True + assert result["result"] == "" + + +def test_path_traversal_blocked(tmp_path): + result = _read({"path": "../../etc/passwd"}, str(tmp_path)) + assert result["success"] is False + assert "escapes" in result["error"].lower() or "sandbox" in result["error"].lower() + + +def test_file_not_found(tmp_path): + result = _read({"path": "nonexistent.txt"}, str(tmp_path)) + assert result["success"] is False + + +def test_file_too_large(tmp_path): + big = tmp_path / "big.txt" + big.write_bytes(b"x" * 1_100_000) + result = _read({"path": "big.txt"}, str(tmp_path)) + assert result["success"] is False + assert "too large" in result["error"].lower() + + +def test_truncation(tmp_path): + (tmp_path / "long.txt").write_text("x" * 15_000, encoding="utf-8") + result = _read({"path": "long.txt"}, str(tmp_path)) + assert result["success"] is True 
+ assert result["truncated"] is True + assert len(result["result"]) == 10_000 + + +def test_symlink_escape(tmp_path): + """Symlink pointing outside sandbox is caught by realpath check.""" + outside = tmp_path / "outside" + outside.mkdir() + target = outside / "secret.txt" + target.write_text("secret", encoding="utf-8") + + sandbox = tmp_path / "sandbox" + sandbox.mkdir() + link = sandbox / "link.txt" + link.symlink_to(target) + + result = _read({"path": "link.txt"}, str(sandbox)) + assert result["success"] is False + + +def test_binary_file_returns_error(tmp_path): + (tmp_path / "binary.bin").write_bytes(b"\x80\x81\x82\xff") + result = _read({"path": "binary.bin"}, str(tmp_path)) + assert result["success"] is False + assert "utf-8" in result["error"].lower() or "decode" in result["error"].lower() diff --git a/packages/execution/tests/unit/test_tools/test_file_write.py b/packages/execution/tests/unit/test_tools/test_file_write.py new file mode 100644 index 0000000..bfa80fa --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_file_write.py @@ -0,0 +1,119 @@ +"""File write tool tests.""" + +from __future__ import annotations + +from app.tools.file_read import FileReadTool +from app.tools.file_write import FileWriteTool + + +def _write(inputs: dict, sandbox: str) -> dict: + tool = FileWriteTool() + import app.tools.file_write as mod + + mod._SANDBOX_ROOT = sandbox + return tool.run(inputs) + + +def _read(inputs: dict, sandbox: str) -> dict: + tool = FileReadTool() + import app.tools.file_read as mod + + mod._SANDBOX_ROOT = sandbox + return tool.run(inputs) + + +def test_write_file(tmp_path): + result = _write({"path": "test.txt", "content": "Hello"}, str(tmp_path)) + assert result["success"] is True + assert "5 bytes" in result["result"] + assert (tmp_path / "test.txt").read_text(encoding="utf-8") == "Hello" + + +def test_append_mode(tmp_path): + (tmp_path / "log.txt").write_text("line1\n", encoding="utf-8") + result = _write( + {"path": "log.txt", 
"content": "line2\n", "mode": "append"}, + str(tmp_path), + ) + assert result["success"] is True + assert (tmp_path / "log.txt").read_text(encoding="utf-8") == "line1\nline2\n" + + +def test_path_traversal_blocked(tmp_path): + result = _write( + {"path": "../../etc/evil", "content": "pwned"}, + str(tmp_path), + ) + assert result["success"] is False + assert "escapes" in result["error"].lower() or "sandbox" in result["error"].lower() + + +def test_creates_parent_dirs(tmp_path): + result = _write( + {"path": "a/b/c/deep.txt", "content": "nested"}, + str(tmp_path), + ) + assert result["success"] is True + assert (tmp_path / "a" / "b" / "c" / "deep.txt").exists() + + +def test_content_too_large(tmp_path): + result = _write( + {"path": "big.txt", "content": "x" * 1_100_000}, + str(tmp_path), + ) + assert result["success"] is False + assert "too large" in result["error"].lower() + + +def test_symlink_escape(tmp_path): + """O_NOFOLLOW rejects symlink at the leaf component.""" + outside = tmp_path / "outside" + outside.mkdir() + target = outside / "target.txt" + target.write_text("original", encoding="utf-8") + + link = tmp_path / "sandbox" / "link.txt" + (tmp_path / "sandbox").mkdir() + link.symlink_to(target) + + result = _write( + {"path": "link.txt", "content": "overwritten"}, + str(tmp_path / "sandbox"), + ) + # O_NOFOLLOW should reject the symlink + assert result["success"] is False + # Original file should be untouched + assert target.read_text(encoding="utf-8") == "original" + + +def test_symlink_in_parent_directory(tmp_path): + """Symlink used as parent directory component is caught by realpath check.""" + outside = tmp_path / "outside" + outside.mkdir() + + sandbox = tmp_path / "sandbox" + sandbox.mkdir() + + # Create symlink inside sandbox pointing outside + (sandbox / "escape").symlink_to(outside) + + result = _write( + {"path": "escape/evil.txt", "content": "pwned"}, + str(sandbox), + ) + assert result["success"] is False + # Verify nothing was written outside 
+ assert not (outside / "evil.txt").exists() + + +def test_file_roundtrip(tmp_path): + """Write via file_write, read back via file_read, content matches.""" + sandbox = str(tmp_path) + content = "Hello, GraphWeave!\nLine 2.\n" + write_result = _write({"path": "roundtrip.txt", "content": content}, sandbox) + assert write_result["success"] is True + + read_result = _read({"path": "roundtrip.txt"}, sandbox) + assert read_result["success"] is True + assert read_result["result"] == content diff --git a/packages/execution/tests/unit/test_tools/test_registry.py b/packages/execution/tests/unit/test_tools/test_registry.py index b48de89..d0cc9ae 100644 --- a/packages/execution/tests/unit/test_tools/test_registry.py +++ b/packages/execution/tests/unit/test_tools/test_registry.py @@ -6,17 +6,29 @@ from app.tools.registry import REGISTRY, ToolNotFoundError, get_tool +_ALL_TOOLS = { + "calculator", + "datetime", + "url_fetch", + "web_search", + "wikipedia", + "file_read", + "file_write", + "weather", +} + def test_unknown_tool_raises(): with pytest.raises(ToolNotFoundError, match="Unknown tool: nonexistent"): get_tool("nonexistent") -@pytest.mark.parametrize("name", ["calculator", "datetime", "url_fetch"]) +@pytest.mark.parametrize("name", sorted(_ALL_TOOLS)) def test_registered_tools_are_retrievable(name): tool = get_tool(name) assert tool.name == name -def test_registry_has_expected_tools(): - assert set(REGISTRY.keys()) == {"calculator", "datetime", "url_fetch"} +def test_registry_has_all_eight_tools(): + assert len(REGISTRY) == 8 + assert set(REGISTRY.keys()) == _ALL_TOOLS diff --git a/packages/execution/tests/unit/test_tools/test_ssrf_transport.py b/packages/execution/tests/unit/test_tools/test_ssrf_transport.py new file mode 100644 index 0000000..a0f7687 --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_ssrf_transport.py @@ -0,0 +1,88 @@ +"""SSRF-safe transport tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + 
+import httpcore +import httpx + +from app.tools.ssrf_transport import PinnedDNSBackend, SSRFSafeTransport + + +def test_backend_connect_tcp_receives_pinned_ip(): + """PinnedDNSBackend passes pinned IP to underlying backend, not hostname.""" + backend = PinnedDNSBackend("93.184.216.34") + + with patch.object(backend, "_backend") as mock_backend: + mock_backend.connect_tcp.return_value = MagicMock() + backend.connect_tcp("example.com", 443) + + mock_backend.connect_tcp.assert_called_once() + call_args = mock_backend.connect_tcp.call_args + assert call_args[0][0] == "93.184.216.34" + assert call_args[0][1] == 443 + + +def test_transport_end_to_end_with_httpx_client(): + """SSRFSafeTransport works through a real httpx.Client request cycle.""" + transport = SSRFSafeTransport("127.0.0.1") + + # Mock at the pool level to avoid actual network calls + mock_response = httpcore.Response( + status=200, + headers=[(b"content-type", b"text/plain")], + content=b"OK", + ) + with patch.object(transport._pool, "handle_request", return_value=mock_response): + client = httpx.Client(transport=transport) + resp = client.get("http://example.com/test") + + assert isinstance(resp, httpx.Response) + assert resp.status_code == 200 + assert resp.text == "OK" + client.close() + + +def test_transport_preserves_ssl_context(): + """SSL context from HTTPTransport.__init__ is forwarded to replacement pool.""" + transport = SSRFSafeTransport("93.184.216.34", verify=False) + assert transport._pool._ssl_context is not None + # verify=False produces a context that does not check hostnames + assert transport._pool._ssl_context.check_hostname is False + + +def test_transport_default_ssl_context(): + """Default transport (verify=True) has a verifying ssl_context.""" + transport = SSRFSafeTransport("93.184.216.34") + assert transport._pool._ssl_context is not None + assert transport._pool._ssl_context.check_hostname is True + + +def test_backend_forwards_timeout(): + """PinnedDNSBackend passes through 
timeout and other kwargs.""" + backend = PinnedDNSBackend("10.0.0.1") + + with patch.object(backend, "_backend") as mock_backend: + mock_backend.connect_tcp.return_value = MagicMock() + backend.connect_tcp("example.com", 80, timeout=5.0, local_address="0.0.0.0") + + call_kwargs = mock_backend.connect_tcp.call_args + assert call_kwargs[1]["timeout"] == 5.0 + assert call_kwargs[1]["local_address"] == "0.0.0.0" + + +def test_transport_uses_pinned_dns_backend(): + """The transport's pool uses PinnedDNSBackend as its network backend.""" + transport = SSRFSafeTransport("1.2.3.4") + pool = transport._pool + assert isinstance(pool._network_backend, PinnedDNSBackend) + assert pool._network_backend._pinned_ip == "1.2.3.4" + + +def test_transport_close(): + """Transport close delegates to pool.""" + transport = SSRFSafeTransport("1.2.3.4") + with patch.object(transport._pool, "close") as mock_close: + transport.close() + mock_close.assert_called_once() diff --git a/packages/execution/tests/unit/test_tools/test_url_fetch.py b/packages/execution/tests/unit/test_tools/test_url_fetch.py index e68fc62..d95bcf7 100644 --- a/packages/execution/tests/unit/test_tools/test_url_fetch.py +++ b/packages/execution/tests/unit/test_tools/test_url_fetch.py @@ -17,25 +17,36 @@ def _fetch(inputs: dict) -> dict: def test_bad_url_no_scheme(): - error = validate_url("not-a-url") + error, ip = validate_url("not-a-url") assert error is not None assert "http" in error.lower() or "scheme" in error.lower() + assert ip is None def test_ssrf_localhost(): with patch("app.tools.url_fetch.socket.getaddrinfo") as mock_gai: mock_gai.return_value = [(None, None, None, None, ("127.0.0.1", 0))] - error = validate_url("http://localhost/secret") + error, ip = validate_url("http://localhost/secret") assert error is not None assert "private" in error.lower() or "blocked" in error.lower() + assert ip is None def test_ssrf_private_ip(): with patch("app.tools.url_fetch.socket.getaddrinfo") as mock_gai: 
mock_gai.return_value = [(None, None, None, None, ("10.0.0.1", 0))] - error = validate_url("http://example.com/") + error, ip = validate_url("http://example.com/") assert error is not None assert "blocked" in error.lower() + assert ip is None + + +def test_validate_url_returns_resolved_ip(): + with patch("app.tools.url_fetch.socket.getaddrinfo") as mock_gai: + mock_gai.return_value = [(None, None, None, None, ("93.184.216.34", 0))] + error, ip = validate_url("http://example.com/") + assert error is None + assert ip == "93.184.216.34" # ── UrlFetchTool.run ──────────────────────────────────────────────────── @@ -47,7 +58,7 @@ def test_successful_fetch(): mock_response.raise_for_status = MagicMock() with ( - patch("app.tools.url_fetch.validate_url", return_value=None), + patch("app.tools.url_fetch.validate_url", return_value=(None, "93.184.216.34")), patch("app.tools.url_fetch.httpx.Client") as mock_client_cls, patch("app.tools.url_fetch.trafilatura.extract", return_value="Hello world"), ): @@ -69,7 +80,7 @@ def test_empty_extraction(): mock_response.raise_for_status = MagicMock() with ( - patch("app.tools.url_fetch.validate_url", return_value=None), + patch("app.tools.url_fetch.validate_url", return_value=(None, "93.184.216.34")), patch("app.tools.url_fetch.httpx.Client") as mock_client_cls, patch("app.tools.url_fetch.trafilatura.extract", return_value=None), ): @@ -93,7 +104,7 @@ def test_truncation(): mock_response.raise_for_status = MagicMock() with ( - patch("app.tools.url_fetch.validate_url", return_value=None), + patch("app.tools.url_fetch.validate_url", return_value=(None, "93.184.216.34")), patch("app.tools.url_fetch.httpx.Client") as mock_client_cls, patch("app.tools.url_fetch.trafilatura.extract", return_value=long_text), ): @@ -111,7 +122,7 @@ def test_truncation(): def test_timeout(): with ( - patch("app.tools.url_fetch.validate_url", return_value=None), + patch("app.tools.url_fetch.validate_url", return_value=(None, "93.184.216.34")), 
patch("app.tools.url_fetch.httpx.Client") as mock_client_cls, ): mock_client = MagicMock() diff --git a/packages/execution/tests/unit/test_tools/test_weather.py b/packages/execution/tests/unit/test_tools/test_weather.py new file mode 100644 index 0000000..c46ae35 --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_weather.py @@ -0,0 +1,109 @@ +"""Weather tool tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import httpx + +from app.tools.weather import WeatherTool + + +def _weather(inputs: dict) -> dict: + return WeatherTool().run(inputs) + + +def test_current_weather(): + geo_resp = MagicMock() + geo_resp.json.return_value = {"results": [{"latitude": 40.71, "longitude": -74.01}]} + geo_resp.raise_for_status = MagicMock() + + weather_resp = MagicMock() + weather_resp.json.return_value = { + "current": { + "temperature_2m": 22.5, + "relative_humidity_2m": 55, + "wind_speed_10m": 12.3, + "weather_code": 1, + } + } + weather_resp.raise_for_status = MagicMock() + + with patch("app.tools.weather.httpx.get", side_effect=[geo_resp, weather_resp]): + result = _weather({"location": "New York", "action": "current"}) + + assert result["success"] is True + assert "22.5" in result["result"] + assert "55" in result["result"] + assert result["source"] == "open-meteo" + + +def test_forecast(): + geo_resp = MagicMock() + geo_resp.json.return_value = {"results": [{"latitude": 48.85, "longitude": 2.35}]} + geo_resp.raise_for_status = MagicMock() + + forecast_resp = MagicMock() + forecast_resp.json.return_value = { + "daily": { + "time": ["2026-03-14", "2026-03-15"], + "temperature_2m_max": [18.0, 20.0], + "temperature_2m_min": [10.0, 12.0], + "weather_code": [1, 3], + } + } + forecast_resp.raise_for_status = MagicMock() + + with patch("app.tools.weather.httpx.get", side_effect=[geo_resp, forecast_resp]): + result = _weather({"location": "Paris", "action": "forecast"}) + + assert result["success"] is True + assert 
"2026-03-14" in result["result"] + assert "18.0" in result["result"] + + +def test_unknown_location(): + geo_resp = MagicMock() + geo_resp.json.return_value = {"results": []} + geo_resp.raise_for_status = MagicMock() + + with patch("app.tools.weather.httpx.get", return_value=geo_resp): + result = _weather({"location": "Xyzzyville"}) + + assert result["success"] is False + assert "not found" in result["error"].lower() + assert result["recoverable"] is False + + +def test_timeout(): + with patch( + "app.tools.weather.httpx.get", + side_effect=httpx.TimeoutException("timed out"), + ): + result = _weather({"location": "London"}) + + assert result["success"] is False + assert result["recoverable"] is True + + +def test_latlon_input_skips_geocoding(): + weather_resp = MagicMock() + weather_resp.json.return_value = { + "current": { + "temperature_2m": 30.0, + "relative_humidity_2m": 80, + "wind_speed_10m": 5.0, + "weather_code": 0, + } + } + weather_resp.raise_for_status = MagicMock() + + with patch("app.tools.weather.httpx.get", return_value=weather_resp) as mock_get: + result = _weather({"location": "-33.87, 151.21", "action": "current"}) + + assert result["success"] is True + # Only one call (weather), no geocoding call + assert mock_get.call_count == 1 + call_params = mock_get.call_args[1]["params"] + assert float(call_params["latitude"]) == -33.87 + assert float(call_params["longitude"]) == 151.21 diff --git a/packages/execution/tests/unit/test_tools/test_web_search.py b/packages/execution/tests/unit/test_tools/test_web_search.py new file mode 100644 index 0000000..118b88d --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_web_search.py @@ -0,0 +1,101 @@ +"""Web search tool tests.""" + +from __future__ import annotations + +import os +from unittest.mock import MagicMock, patch + +from app.tools.web_search import WebSearchTool + + +def _search(inputs: dict) -> dict: + return WebSearchTool().run(inputs) + + +def test_empty_query(): + result = 
_search({"query": ""}) + assert result["success"] is False + assert result["recoverable"] is False + + +def test_tavily_path(): + mock_response = { + "results": [ + {"title": "Result 1", "url": "https://example.com", "content": "Snippet 1"} + ] + } + with ( + patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}), + patch("tavily.TavilyClient") as mock_cls, + ): + mock_cls.return_value.search.return_value = mock_response + result = _search({"query": "test query"}) + + assert result["success"] is True + assert result["source"] == "tavily" + assert "Result 1" in result["result"] + assert "https://example.com" in result["result"] + + +def test_ddg_fallback(): + raw = [{"title": "DDG Result", "href": "https://ddg.com", "body": "DDG snippet"}] + # Ensure no Tavily key + env = {k: v for k, v in os.environ.items() if k != "TAVILY_API_KEY"} + with ( + patch.dict("os.environ", env, clear=True), + patch("duckduckgo_search.DDGS") as mock_ddgs, + ): + mock_ctx = MagicMock() + mock_ctx.text.return_value = raw + mock_ddgs.return_value.__enter__ = MagicMock(return_value=mock_ctx) + mock_ddgs.return_value.__exit__ = MagicMock(return_value=False) + + result = _search({"query": "test query"}) + + assert result["success"] is True + assert result["source"] == "duckduckgo" + assert "DDG Result" in result["result"] + + +def test_tavily_timeout(): + with ( + patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}), + patch("tavily.TavilyClient") as mock_cls, + ): + mock_cls.return_value.search.side_effect = Exception("Connection timed out") + result = _search({"query": "test"}) + + assert result["success"] is False + assert result["recoverable"] is True + + +def test_max_results_clamped(): + env = {k: v for k, v in os.environ.items() if k != "TAVILY_API_KEY"} + with ( + patch.dict("os.environ", env, clear=True), + patch("duckduckgo_search.DDGS") as mock_ddgs, + ): + mock_ctx = MagicMock() + mock_ctx.text.return_value = [] + mock_ddgs.return_value.__enter__ = 
MagicMock(return_value=mock_ctx) + mock_ddgs.return_value.__exit__ = MagicMock(return_value=False) + + _search({"query": "test", "max_results": 100}) + + mock_ctx.text.assert_called_once_with("test", max_results=10) + + +def test_max_results_within_range(): + env = {k: v for k, v in os.environ.items() if k != "TAVILY_API_KEY"} + with ( + patch.dict("os.environ", env, clear=True), + patch("duckduckgo_search.DDGS") as mock_ddgs, + ): + mock_ctx = MagicMock() + mock_ctx.text.return_value = [] + mock_ddgs.return_value.__enter__ = MagicMock(return_value=mock_ctx) + mock_ddgs.return_value.__exit__ = MagicMock(return_value=False) + + _search({"query": "test", "max_results": 3}) + + mock_ctx.text.assert_called_once_with("test", max_results=3) diff --git a/packages/execution/tests/unit/test_tools/test_wikipedia.py b/packages/execution/tests/unit/test_tools/test_wikipedia.py new file mode 100644 index 0000000..9f237dd --- /dev/null +++ b/packages/execution/tests/unit/test_tools/test_wikipedia.py @@ -0,0 +1,121 @@ +"""Wikipedia tool tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import httpx + +from app.tools.wikipedia_tool import WikipediaTool + + +def _wiki(inputs: dict) -> dict: + return WikipediaTool().run(inputs) + + +def test_search_returns_titles(): + mock_response = MagicMock() + mock_response.json.return_value = [ + "python", + ["Python (programming language)", "Python (genus)"], + ["", ""], + ["https://en.wikipedia.org/wiki/Python_(programming_language)", ""], + ] + mock_response.raise_for_status = MagicMock() + + with patch("app.tools.wikipedia_tool.httpx.get", return_value=mock_response): + result = _wiki({"action": "search", "query": "python"}) + + assert result["success"] is True + assert "Python (programming language)" in result["result"] + assert result["source"] == "wikipedia" + + +def test_search_empty_query(): + result = _wiki({"action": "search", "query": ""}) + assert result["success"] is False + assert 
result["recoverable"] is False + + +def test_search_timeout(): + with patch( + "app.tools.wikipedia_tool.httpx.get", + side_effect=httpx.TimeoutException("timed out"), + ): + result = _wiki({"action": "search", "query": "test"}) + + assert result["success"] is False + assert result["recoverable"] is True + + +def test_page_content(): + mock_page = MagicMock() + mock_page.exists.return_value = True + mock_page.summary = "Python is a programming language." + mock_page.text = "Python is a programming language. More details..." + + mock_wiki = MagicMock() + mock_wiki.page.return_value = mock_page + + with patch( + "app.tools.wikipedia_tool.wikipediaapi.Wikipedia", return_value=mock_wiki + ): + result = _wiki({"action": "page", "title": "Python"}) + + assert result["success"] is True + assert "programming language" in result["result"] + + +def test_page_not_found(): + mock_page = MagicMock() + mock_page.exists.return_value = False + + mock_wiki = MagicMock() + mock_wiki.page.return_value = mock_page + + with patch( + "app.tools.wikipedia_tool.wikipediaapi.Wikipedia", return_value=mock_wiki + ): + result = _wiki({"action": "page", "title": "Nonexistent12345"}) + + assert result["success"] is False + assert result["recoverable"] is False + + +def test_page_truncation(): + mock_page = MagicMock() + mock_page.exists.return_value = True + mock_page.summary = "Short summary." 
+ mock_page.text = "x" * 15_000 + + mock_wiki = MagicMock() + mock_wiki.page.return_value = mock_page + + with patch( + "app.tools.wikipedia_tool.wikipediaapi.Wikipedia", return_value=mock_wiki + ): + result = _wiki({"action": "page", "title": "Long Article"}) + + assert result["success"] is True + assert result["truncated"] is True + assert len(result["result"]) == 10_000 + + +def test_user_agent_set(): + with patch("app.tools.wikipedia_tool.wikipediaapi.Wikipedia") as mock_wiki_cls: + mock_page = MagicMock() + mock_page.exists.return_value = True + mock_page.summary = "test" + mock_page.text = "test" + mock_wiki_cls.return_value.page.return_value = mock_page + + _wiki({"action": "page", "title": "Test"}) + + call_kwargs = mock_wiki_cls.call_args[1] + assert "GraphWeave" in call_kwargs["user_agent"] + assert call_kwargs["language"] == "en" + + +def test_unknown_action(): + result = _wiki({"action": "unknown"}) + assert result["success"] is False diff --git a/packages/execution/uv.lock b/packages/execution/uv.lock index f7ac89d..2bb3cb7 100644 --- a/packages/execution/uv.lock +++ b/packages/execution/uv.lock @@ -297,6 +297,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, ] +[[package]] +name = "duckduckgo-search" +version = "7.5.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "lxml" }, + { name = "primp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/dc/919d3d51ed702890a3e6e736e1e152d5d90856393200306e82fb54fde39e/duckduckgo_search-7.5.5.tar.gz", hash = "sha256:44ef03bfa5484bada786590f2d4c213251131765721383a177a0da6fa5c5e41a", size = 24768, upload-time = "2025-03-27T08:11:26.951Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/fc/da/8376678b4a9ae0f9418d93df9c9cf851dced49c95ceb38daac6651e38f7a/duckduckgo_search-7.5.5-py3-none-any.whl", hash = "sha256:c71a0661aa436f215d9a05d653af424affb58825ab3e79f3b788053cbdee9ebc", size = 20421, upload-time = "2025-03-27T08:11:25.515Z" }, +] + [[package]] name = "fastapi" version = "0.135.1" @@ -367,6 +381,7 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, + { name = "duckduckgo-search" }, { name = "fastapi" }, { name = "httpx" }, { name = "langchain-anthropic" }, @@ -376,8 +391,10 @@ dependencies = [ { name = "pydantic" }, { name = "simpleeval" }, { name = "slowapi" }, + { name = "tavily-python" }, { name = "trafilatura" }, { name = "uvicorn", extra = ["standard"] }, + { name = "wikipedia-api" }, ] [package.dev-dependencies] @@ -391,6 +408,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiosqlite", specifier = ">=0.22.1" }, + { name = "duckduckgo-search", specifier = ">=6.0.0,<8.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "httpx", specifier = ">=0.28.0" }, { name = "langchain-anthropic", specifier = ">=0.3.0" }, @@ -400,8 +418,10 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.11.0" }, { name = "simpleeval", specifier = ">=1.0.3" }, { name = "slowapi", specifier = ">=0.1.9" }, + { name = "tavily-python", specifier = ">=0.5.0,<1.0" }, { name = "trafilatura", specifier = ">=2.0.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, + { name = "wikipedia-api", specifier = ">=0.7.0" }, ] [package.metadata.requires-dev] @@ -925,6 +945,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "primp" +version = "1.1.3" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/0e/62ed44af95c66fd6fa8ad49c8bde815f64c7e976772d6979730be2b7cd97/primp-1.1.3.tar.gz", hash = "sha256:56adc3b8a5048cbd5f926b21fdff839195f3a9181512ca33f56ddc66f4c95897", size = 311356, upload-time = "2026-03-11T06:42:51.763Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/6b/36794b5758a0dd1251e67b6ab3ea946e53fa69745e0ecc29facc072ddf5b/primp-1.1.3-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:24383cfc267f620769be102b7fa4b64c7d47105f86bd21d047f1e07709e83c6e", size = 4000660, upload-time = "2026-03-11T06:42:58.092Z" }, + { url = "https://files.pythonhosted.org/packages/98/18/ebbe318a926d158c57f9e9cf49bbea70e8f0bd7f87e7675ed68e0d6ab433/primp-1.1.3-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:61bcb8c53b41e4bac43d04a1374b6ab7d8ded0f3517d32c5cdd5c30562756805", size = 3737318, upload-time = "2026-03-11T06:42:50.19Z" }, + { url = "https://files.pythonhosted.org/packages/a9/4c/430c9154284b53b771e6713a18dec4ad0159e4a501a20b222d67c730ced9/primp-1.1.3-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0c6b9388578ee9d903f30549a792c5f391fdeb9d36b508da2ffb8e13c764954", size = 3881005, upload-time = "2026-03-11T06:43:12.894Z" }, + { url = "https://files.pythonhosted.org/packages/93/34/2466ef66386a1b50e6aaf7832f9f603628407bb33342378faf4b38c4aee8/primp-1.1.3-cp310-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:09a8bfa870c92c81d76611846ec53b2520845e3ec5f4139f47604986bcf4bc25", size = 3514480, upload-time = "2026-03-11T06:43:06.058Z" }, + { url = "https://files.pythonhosted.org/packages/ff/42/ca7a71df6493dd6c1971c0cc3b20b8125e2547eb3bf88b4429715cb6ed81/primp-1.1.3-cp310-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ac372cb9959fff690b255fad91c5b3bc948c14065da9fc00ad80d139651515af", size = 3734658, upload-time = "2026-03-11T06:43:47.486Z" }, + { url = 
"https://files.pythonhosted.org/packages/bc/7c/0fb34db619e9935e11140929713c2c7b5323c1e8ba75cad6f0aade51c89d/primp-1.1.3-cp310-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3034672a007f04e12b8fe7814c97ea172e8b9c5d45bd7b00cf6e7334fdd4222a", size = 4011898, upload-time = "2026-03-11T06:43:41.121Z" }, + { url = "https://files.pythonhosted.org/packages/da/8b/afd1bd8b14f38d58c5ebd0d45fc6b74914956907aa4e981bb2e5231626d3/primp-1.1.3-cp310-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a07d5b7d7278dc63452a59f3bf851dc4d1f8ddc2aada7844cbdb68002256e2f4", size = 3910728, upload-time = "2026-03-11T06:43:01.819Z" }, + { url = "https://files.pythonhosted.org/packages/32/9e/1ec3a9678efcbb51e50d7b4886d9195f956c9fd7f4efcff13ccb152248b0/primp-1.1.3-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08eec2f58abbcc1060032a2af81dabacec87a580a364a75862039f7422ac82e6", size = 4114189, upload-time = "2026-03-11T06:42:47.639Z" }, + { url = "https://files.pythonhosted.org/packages/28/d9/76de611027c0688be188d5a833be45b1e36d9c0c98baefab27bf6336ab9d/primp-1.1.3-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9716d4cd36db2c175443fe1bbd54045a944fc9c49d01a385af8ada1fe9c948df", size = 4061973, upload-time = "2026-03-11T06:43:37.301Z" }, + { url = "https://files.pythonhosted.org/packages/37/3b/a30a5ea366705d0ece265b12ad089793d644bd5730b18201e3a0a7fa7b5f/primp-1.1.3-cp310-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:e19daca65dc6df369c33e711fa481ad2afe5d26c5bde926c069b3ab067c4fd45", size = 3747920, upload-time = "2026-03-11T06:43:10.403Z" }, + { url = "https://files.pythonhosted.org/packages/df/46/e3c323221c371cdfe6c2ed971f7a70e3b69f30b561977715c55230bd5fda/primp-1.1.3-cp310-abi3-musllinux_1_2_i686.whl", hash = "sha256:ee357537712aa486364b0194cf403c5f9eaaa1354e23e9ac8322a22003f31e6b", size = 3861184, upload-time = "2026-03-11T06:43:49.391Z" }, + { url = 
"https://files.pythonhosted.org/packages/8a/7f/babaf00753daad7d80061003d7ae1bdfca64ea94c181cdea8d25c8a7226a/primp-1.1.3-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:06c53e77ebf6ac00633bc09e7e5a6d1a994592729d399ca8f065451a2574b92e", size = 4364610, upload-time = "2026-03-11T06:42:56.223Z" }, + { url = "https://files.pythonhosted.org/packages/03/48/c7bca8045c681f5f60972c180d2a20582c7a0857b3b07b12e0a0ee062ac4/primp-1.1.3-cp310-abi3-win32.whl", hash = "sha256:4b1ea3693c118bf04a6e05286f0a73637cf6fe5c9fd77fa1e29a01f190adf512", size = 3265160, upload-time = "2026-03-11T06:43:43.774Z" }, + { url = "https://files.pythonhosted.org/packages/45/3e/4a4b8a0f6f15734cded91e85439e68912b2bb8eafe7132420c13c2db8340/primp-1.1.3-cp310-abi3-win_amd64.whl", hash = "sha256:5ea386a4c8c4d8c1021d17182f4ee24dbb6f17c107c4e9ee5500b6372cf08f32", size = 3603953, upload-time = "2026-03-11T06:43:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/70/46/1baf13a7f5fbed6052deb3e4822c69441a8d0fd990fe2a50e4cec802130b/primp-1.1.3-cp310-abi3-win_arm64.whl", hash = "sha256:63c7b1a1ccbcd07213f438375df186f807cdc5214bc2debb055737db9b5078de", size = 3619917, upload-time = "2026-03-11T06:42:44.76Z" }, + { url = "https://files.pythonhosted.org/packages/be/0c/a73cbe13f075e7ceaa5172b44ebc6f423713c6b4efe168114993a1710b26/primp-1.1.3-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:4b3d52f3233134584ef527e7e52f1b371a964ade1df0461f8187100e41d7fa84", size = 3987141, upload-time = "2026-03-11T06:43:24.904Z" }, + { url = "https://files.pythonhosted.org/packages/49/56/b70d7991fb1e07af53706b1f69f78a0b440a7b4b2a2999c44ab44afef1e7/primp-1.1.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b3d947e2c1d15147e8f4736d027b9f3bef518d67da859ead1c54e028ff491bbb", size = 3735665, upload-time = "2026-03-11T06:43:29.347Z" }, + { url = 
"https://files.pythonhosted.org/packages/31/82/69efc663341c2bab55659ed221903a090e5c80255c2de2acc70f3726a3fc/primp-1.1.3-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ed2fee7d4758f6bb873b19a6759f54e0bc453213dad5ba7e52de7582921079", size = 3873695, upload-time = "2026-03-11T06:43:15.396Z" }, + { url = "https://files.pythonhosted.org/packages/07/7e/6b360742019ef8fb4ea036a420eb21b0a58d380ca09c68b075fc103cc043/primp-1.1.3-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5aa717f256af9e4391fb1c4dc946d99d04652b4c57dad20c3947e839ab26769", size = 3512644, upload-time = "2026-03-11T06:43:08.368Z" }, + { url = "https://files.pythonhosted.org/packages/03/46/51d2ada6d5b53b8496eddf2c80392deab13698987412d0234f88e72390c1/primp-1.1.3-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17f37fcacd97540f68b06f2b468b111ca7f2b142c48370db7344b522274fc0d6", size = 3733114, upload-time = "2026-03-11T06:43:22.838Z" }, + { url = "https://files.pythonhosted.org/packages/45/f5/5f5f5f4bef7e247ec3543e2fbdb670d8db8753a7693baf9c8b9fcf52cd43/primp-1.1.3-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d5f010d0b8ba111dd9a66f814c2cd56332e047c98f45d7714ffbf2b1cec5b073", size = 4005664, upload-time = "2026-03-11T06:43:20.824Z" }, + { url = "https://files.pythonhosted.org/packages/f2/bf/99cf4a5f179b3f13b0c2ba4d3ae8f8af19f0084308e76cb79a0cee03c31b/primp-1.1.3-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e1e431915e4a7094d589213fc14e955243d93751031d889f4b359fa8ed54298", size = 3895746, upload-time = "2026-03-11T06:43:35.376Z" }, + { url = "https://files.pythonhosted.org/packages/c3/75/4c625e1cab37585365b0856ca44f31ad598e92a847d23561f454b7f36fca/primp-1.1.3-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaffa22dae2f193d899d9f68cca109ea5d16cdf4c901c20cec186de89e7d5db4", size = 4109815, upload-time = "2026-03-11T06:43:04.059Z" }, + { url 
= "https://files.pythonhosted.org/packages/49/72/6197ea78779d359f307be1acc64659896fc960ed91c0bdc6e6e698e423e6/primp-1.1.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f93bee50990884621ef482e8434e87f9fbb4eca6f4d47973c44c5d6393c35679", size = 4050839, upload-time = "2026-03-11T06:43:18.296Z" }, + { url = "https://files.pythonhosted.org/packages/a4/b2/cdd565b28bcf7ce555f4decdf89dafd16db8ed3ba8661890d3b9337abe45/primp-1.1.3-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:399dfb9ad01c3612c9e510a7034ac925af5524cade0961d8a019dedd90a46474", size = 3748397, upload-time = "2026-03-11T06:43:27.347Z" }, + { url = "https://files.pythonhosted.org/packages/62/6e/def3a90821b52589dbe1f57477c2c89bde7a5b26a7c166d7751930c06f98/primp-1.1.3-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:78ce595bbb9f339e83975efa9db2a81128842fad1a2fdafb78d72fcdc59590fc", size = 3861261, upload-time = "2026-03-11T06:43:39.292Z" }, + { url = "https://files.pythonhosted.org/packages/10/7d/3e610614d6a426502cfc6eccea21ef4557b39177d365df393c994945ca43/primp-1.1.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d709bdf520aa9401c0592b642730b3477c828629f01d2550977b77135b34e8d", size = 4358608, upload-time = "2026-03-11T06:43:45.606Z" }, + { url = "https://files.pythonhosted.org/packages/91/50/eb190cefe5eb05896825a5b3365d5650b9327161329cd1df4f7351b66ba9/primp-1.1.3-cp314-cp314t-win32.whl", hash = "sha256:6fe893eb87156dfb146dd666c7c8754670de82e38af0a27d82a47b7461ec2eea", size = 3259903, upload-time = "2026-03-11T06:42:59.922Z" }, + { url = "https://files.pythonhosted.org/packages/1f/a8/9e8534bc6d729a667f79b249fcdbf2230b0eb41214e277998cd6be900498/primp-1.1.3-cp314-cp314t-win_amd64.whl", hash = "sha256:ced76ef6669f31dc4af25e81e87914310645bcfc0892036bde084dafd6d00c3c", size = 3602569, upload-time = "2026-03-11T06:42:53.955Z" }, + { url = "https://files.pythonhosted.org/packages/9c/92/e18be996a01c7fd0e7dd7d198edefe42813cdfe1637bbbc80370ce656f62/primp-1.1.3-cp314-cp314t-win_arm64.whl", 
hash = "sha256:efadef0dfd10e733a254a949abf9ed05c668c28a68aa6513d811c0c6acd54cdb", size = 3611571, upload-time = "2026-03-11T06:43:31.249Z" }, +] + [[package]] name = "pyasn1" version = "0.6.2" @@ -1301,6 +1359,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] +[[package]] +name = "tavily-python" +version = "0.7.23" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "requests" }, + { name = "tiktoken" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/d1/197419d6133643848514e5e84e8f41886e825b73bf91ae235a1595c964f5/tavily_python-0.7.23.tar.gz", hash = "sha256:3b92232e0e29ab68898b765f281bb4f2c650b02210b64affbc48e15292e96161", size = 25968, upload-time = "2026-03-09T19:17:32.333Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/27/f9c6e9249367be0772fb754849e03cbbc6ad8d80a479bf30ea8811828b2e/tavily_python-0.7.23-py3-none-any.whl", hash = "sha256:52ef85c44b926bce3f257570cd32bc1bd4db54666acf3105617f27411a59e188", size = 19079, upload-time = "2026-03-09T19:17:29.593Z" }, +] + [[package]] name = "tenacity" version = "9.1.4" @@ -1605,6 +1677,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, ] +[[package]] +name = "wikipedia-api" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "requests" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/cb/59/a6f5790043cff046658d44bc5b594a7d0cc0d9a1a6911a0df6e7aba2179c/wikipedia_api-0.10.2.tar.gz", hash = "sha256:93fc84d2d88b043c626a03bc013a741c206ab60c0517bfce51fa60a0edc5087d", size = 29841, upload-time = "2026-03-06T22:01:16.615Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/41/cf1ceb3071b58175d6960d5f45bfd5c007fc8ab5bec8ca189efd02bb05e4/wikipedia_api-0.10.2-py3-none-any.whl", hash = "sha256:0aa6d09e46909d396d81af97de5dee06004c1f823725b4f66c352b1096d48163", size = 22690, upload-time = "2026-03-06T22:01:15.024Z" }, +] + [[package]] name = "wrapt" version = "2.1.2" From 7e04033e7859bf78718d1bc67fbe28020beae404 Mon Sep 17 00:00:00 2001 From: prosdev Date: Sat, 14 Mar 2026 15:24:19 -0700 Subject: [PATCH 3/5] fix: address code review findings for Phase 5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Exporter: validate all identifiers against safe pattern before interpolation into generated Python (prevents code injection). Apply _escape() to all string values in generated code. - url_fetch: guard against empty getaddrinfo results returning (None, None) — now returns an explicit error. - file_read: fix fd leak if fstat raises unexpectedly — added finally block with ownership tracking. - file_write: wrap os.makedirs in try/except to return proper error envelope instead of uncaught 500. - web_search: wrap int(max_results) in try/except to handle non-numeric input gracefully. - exporter: _python_default uses .get() fallback for unknown types. - Added cross-owner isolation test for export route. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/execution/app/exporter.py | 75 ++++++++++++++++--- packages/execution/app/tools/file_read.py | 8 +- packages/execution/app/tools/file_write.py | 9 ++- packages/execution/app/tools/url_fetch.py | 3 + packages/execution/app/tools/web_search.py | 5 +- .../execution/tests/unit/test_exporter.py | 49 +++++++++++- .../tests/unit/test_routes_export.py | 13 ++++ 7 files changed, 148 insertions(+), 14 deletions(-) diff --git a/packages/execution/app/exporter.py b/packages/execution/app/exporter.py index 24a3f5f..a497124 100644 --- a/packages/execution/app/exporter.py +++ b/packages/execution/app/exporter.py @@ -2,6 +2,29 @@ from __future__ import annotations +import re + +_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +class ExportError(Exception): + """Raised when a schema cannot be safely exported.""" + + +def _validate_identifier(value: str, context: str) -> str: + """Validate that a value is a safe Python identifier. + + Raises ExportError if the value contains characters that could + inject code when interpolated into generated Python source. + """ + if not _SAFE_IDENTIFIER.match(value): + raise ExportError( + f"Unsafe identifier for {context}: {value!r}. 
" + "Must match [a-zA-Z_][a-zA-Z0-9_]*" + ) + return value + + # Type → Python type name mapping _TYPE_NAMES = { "string": "str", @@ -57,6 +80,35 @@ def export_graph(schema: dict) -> dict: edges = schema.get("edges", []) state_fields = schema.get("state", []) + # Validate all identifiers before interpolation into generated code + for n in nodes: + _validate_identifier(n["id"], "node_id") + cfg = n.get("config", {}) + if "output_key" in cfg: + _validate_identifier(cfg["output_key"], "output_key") + if "input_key" in cfg: + _validate_identifier(cfg["input_key"], "input_key") + if "input_map" in cfg: + for key in cfg["input_map"]: + _validate_identifier(key, "input_map key") + cond = cfg.get("condition", {}) + if "field" in cond: + _validate_identifier(cond["field"], "condition field") + if "branch" in cond: + _validate_identifier(cond["branch"], "condition branch") + if "on_success" in cond: + _validate_identifier(cond["on_success"], "on_success") + if "on_error" in cond: + _validate_identifier(cond["on_error"], "on_error") + if "continue" in cond: + _validate_identifier(cond["continue"], "continue branch") + if "exceeded" in cond: + _validate_identifier(cond["exceeded"], "exceeded branch") + if "default_branch" in cfg: + _validate_identifier(cfg["default_branch"], "default_branch") + for f in state_fields: + _validate_identifier(f["key"], "state field key") + nodes_by_id = {n["id"]: n for n in nodes} start_id = next((n["id"] for n in nodes if n["type"] == "start"), None) end_ids = {n["id"] for n in nodes if n["type"] == "end"} @@ -274,7 +326,7 @@ def _build_llm_function(node_id: str, config: dict) -> str: # Build messages lines.append(" from langchain_core.messages import HumanMessage, SystemMessage") lines.append( - f' llm = {cls_name}(model="{model}", ' + f' llm = {cls_name}(model="{_escape(model)}", ' f"temperature={temp}, max_tokens={max_tokens})" ) lines.append(" messages = []") @@ -288,7 +340,7 @@ def _build_llm_function(node_id: str, config: dict) -> str: if 
input_map: parts = [] for key, expr in input_map.items(): - parts.append(f'"{key}: " + str(state.get("{expr}", ""))') + parts.append(f'"{key}: " + str(state.get("{_escape(expr)}", ""))') user_content = ' + "\\n" + '.join(parts) lines.append(f" messages.append(HumanMessage(content={user_content}))") else: @@ -306,13 +358,13 @@ def _build_tool_function(node_id: str, config: dict) -> str: input_map = config.get("input_map", {}) lines = [f"def {node_id}(state: GraphState) -> dict:"] - lines.append(f' """Run the {tool_name} tool."""') + lines.append(f' """Run the {_escape(tool_name)} tool."""') # Resolve inputs if input_map: lines.append(" inputs = {") for key, expr in input_map.items(): - lines.append(f' "{key}": state.get("{expr}", ""),') + lines.append(f' "{key}": state.get("{_escape(expr)}", ""),') lines.append(" }") else: lines.append(" inputs = {}") @@ -357,7 +409,7 @@ def _build_router_function( if ctype == "field_equals": field = condition.get("field", "") - value = condition.get("value", "") + value = _escape(condition.get("value", "")) branch = condition.get("branch", "") lines.append(f' if state.get("{field}") == "{value}":') lines.append(f' return "{branch}"') @@ -365,7 +417,7 @@ def _build_router_function( elif ctype == "field_contains": field = condition.get("field", "") - value = condition.get("value", "") + value = _escape(condition.get("value", "")) branch = condition.get("branch", "") lines.append(f' if "{value}" in str(state.get("{field}", "")):') lines.append(f' return "{branch}"') @@ -493,9 +545,14 @@ def _build_main_block(state_fields: list[dict]) -> str: def _python_default(type_name: str) -> object: - return {"string": "", "number": 0.0, "boolean": False, "list": [], "object": {}}[ - type_name - ] + _defaults = { + "string": "", + "number": 0.0, + "boolean": False, + "list": [], + "object": {}, + } + return _defaults.get(type_name, "") def _escape(s: str) -> str: diff --git a/packages/execution/app/tools/file_read.py 
b/packages/execution/app/tools/file_read.py index 2e1dd10..934443f 100644 --- a/packages/execution/app/tools/file_read.py +++ b/packages/execution/app/tools/file_read.py @@ -43,14 +43,15 @@ def run(self, inputs: dict) -> dict: try: size = os.fstat(fd).st_size if size > _MAX_FILE_SIZE: - os.close(fd) return { "success": False, - "error": f"File too large: {size} bytes (max {_MAX_FILE_SIZE})", + "error": (f"File too large: {size} bytes (max {_MAX_FILE_SIZE})"), "recoverable": False, } + # fdopen takes ownership of fd — closes it on exit with os.fdopen(fd, "r", encoding="utf-8") as f: + fd = -1 # prevent double-close in finally text = f.read() except UnicodeDecodeError as exc: return { @@ -64,6 +65,9 @@ def run(self, inputs: dict) -> dict: "error": f"Read error: {exc}", "recoverable": False, } + finally: + if fd >= 0: + os.close(fd) truncated = len(text) > _MAX_TEXT_LENGTH if truncated: diff --git a/packages/execution/app/tools/file_write.py b/packages/execution/app/tools/file_write.py index 70d87f7..3f1a847 100644 --- a/packages/execution/app/tools/file_write.py +++ b/packages/execution/app/tools/file_write.py @@ -46,7 +46,14 @@ def run(self, inputs: dict) -> dict: # Create parent directories parent = os.path.dirname(resolved) - os.makedirs(parent, exist_ok=True) + try: + os.makedirs(parent, exist_ok=True) + except OSError as exc: + return { + "success": False, + "error": f"Cannot create directories: {exc}", + "recoverable": False, + } # Open with O_NOFOLLOW to reject symlinks at the leaf flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW diff --git a/packages/execution/app/tools/url_fetch.py b/packages/execution/app/tools/url_fetch.py index 8269100..9d707c6 100644 --- a/packages/execution/app/tools/url_fetch.py +++ b/packages/execution/app/tools/url_fetch.py @@ -53,6 +53,9 @@ def validate_url(url: str) -> tuple[str | None, str | None]: if resolved_ip is None: resolved_ip = str(ip) + if resolved_ip is None: + return (f"No usable IP addresses for hostname: {hostname}", 
None) + return (None, resolved_ip) diff --git a/packages/execution/app/tools/web_search.py b/packages/execution/app/tools/web_search.py index fdbbd12..a794127 100644 --- a/packages/execution/app/tools/web_search.py +++ b/packages/execution/app/tools/web_search.py @@ -22,7 +22,10 @@ def run(self, inputs: dict) -> dict: "recoverable": False, } - max_results = min(int(inputs.get("max_results", 5)), _MAX_RESULTS) + try: + max_results = min(int(inputs.get("max_results", 5)), _MAX_RESULTS) + except (ValueError, TypeError): + max_results = 5 tavily_key = os.environ.get("TAVILY_API_KEY") if tavily_key: diff --git a/packages/execution/tests/unit/test_exporter.py b/packages/execution/tests/unit/test_exporter.py index c918018..3e14ed0 100644 --- a/packages/execution/tests/unit/test_exporter.py +++ b/packages/execution/tests/unit/test_exporter.py @@ -4,7 +4,9 @@ import ast -from app.exporter import export_graph +import pytest + +from app.exporter import ExportError, export_graph def _base_schema(**overrides): @@ -371,3 +373,48 @@ def test_export_complex_graph(): # Has checkpointer (human_input present) assert "InMemorySaver" in result["code"] + + +# ── Security: identifier injection prevention ──────────────────────── + + +def test_malicious_node_id_rejected(): + """node_id with code injection is rejected.""" + schema = _base_schema() + schema["nodes"].insert( + -1, + { + "id": "x(s):\n import os\ndef y", + "type": "tool", + "label": "Evil", + "position": {"x": 0, "y": 100}, + "config": { + "tool_name": "calculator", + "input_map": {"expression": "result"}, + "output_key": "result", + }, + }, + ) + with pytest.raises(ExportError, match="Unsafe identifier"): + export_graph(schema) + + +def test_malicious_output_key_rejected(): + """output_key with injection characters is rejected.""" + schema = _base_schema() + schema["nodes"].insert( + -1, + { + "id": "tool_1", + "type": "tool", + "label": "Tool", + "position": {"x": 0, "y": 100}, + "config": { + "tool_name": "calculator", + 
"input_map": {"expression": "result"}, + "output_key": 'result"}\nimport os', + }, + }, + ) + with pytest.raises(ExportError, match="Unsafe identifier"): + export_graph(schema) diff --git a/packages/execution/tests/unit/test_routes_export.py b/packages/execution/tests/unit/test_routes_export.py index 414907c..589ccdb 100644 --- a/packages/execution/tests/unit/test_routes_export.py +++ b/packages/execution/tests/unit/test_routes_export.py @@ -115,3 +115,16 @@ async def test_export_route_not_found(client, api_key): _, raw = api_key resp = await client.get("/v1/graphs/nonexistent/export", headers=_headers(raw)) assert resp.status_code == 404 + + +async def test_export_route_wrong_owner(client, api_key): + """Export as different owner returns 404, not the graph code.""" + _, raw_a = api_key + gid = await _create_graph(client, raw_a) + + # Create a second user + db = app.state.db + _, raw_b = await create_test_key(db, scopes=SCOPES_DEFAULT, name="other") + + resp = await client.get(f"/v1/graphs/{gid}/export", headers=_headers(raw_b)) + assert resp.status_code == 404 From e0265a2a1a848dc05f180794ca323964bbe53b23 Mon Sep 17 00:00:00 2001 From: prosdev Date: Sat, 14 Mar 2026 21:21:32 -0700 Subject: [PATCH 4/5] docs: clarify tool response envelope in BaseTool docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Success responses include {success, result, source, truncated}. Error responses include {success, error, recoverable}. recoverable only applies to errors — matches schema spec and all existing tool implementations. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/execution/app/tools/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/execution/app/tools/base.py b/packages/execution/app/tools/base.py index 5f533fc..0909899 100644 --- a/packages/execution/app/tools/base.py +++ b/packages/execution/app/tools/base.py @@ -20,7 +20,9 @@ def run(self, inputs: dict) -> dict: """Execute the tool. Returns: - Response envelope: { success, result/error, recoverable }. + Response envelope: + - Success: { success: True, result, source, truncated } + - Error: { success: False, error, recoverable } """ ... From 479406bafec6df89b5ced3befac0d7ff154563df Mon Sep 17 00:00:00 2001 From: prosdev Date: Sat, 14 Mar 2026 21:23:43 -0700 Subject: [PATCH 5/5] fix: validate edge condition_branch and add O_NOFOLLOW test - Exporter: validate condition_branch/label values from edges against safe identifier pattern (closes remaining injection gap). - Add test_symlink_inside_sandbox_blocked_by_onofollow to verify the O_NOFOLLOW defense layer independently from the realpath pre-check (platform-aware: macOS allows read-only symlinks). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/execution/app/exporter.py | 4 ++++ .../tests/unit/test_tools/test_file_read.py | 24 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/packages/execution/app/exporter.py b/packages/execution/app/exporter.py index a497124..6f9a69d 100644 --- a/packages/execution/app/exporter.py +++ b/packages/execution/app/exporter.py @@ -108,6 +108,10 @@ def export_graph(schema: dict) -> dict: _validate_identifier(cfg["default_branch"], "default_branch") for f in state_fields: _validate_identifier(f["key"], "state field key") + for edge in edges: + branch = edge.get("condition_branch") or edge.get("label") + if branch: + _validate_identifier(branch, "edge condition_branch") nodes_by_id = {n["id"]: n for n in nodes} start_id = next((n["id"] for n in nodes if n["type"] == "start"), None) diff --git a/packages/execution/tests/unit/test_tools/test_file_read.py b/packages/execution/tests/unit/test_tools/test_file_read.py index 073fc39..3a92750 100644 --- a/packages/execution/tests/unit/test_tools/test_file_read.py +++ b/packages/execution/tests/unit/test_tools/test_file_read.py @@ -82,6 +82,30 @@ def test_symlink_escape(tmp_path): assert result["success"] is False +def test_symlink_inside_sandbox_blocked_by_onofollow(tmp_path): + """Symlink inside sandbox (realpath passes) is blocked by O_NOFOLLOW.""" + sandbox = tmp_path / "sandbox" + sandbox.mkdir() + target = sandbox / "real.txt" + target.write_text("real content", encoding="utf-8") + link = sandbox / "link.txt" + link.symlink_to(target) + + # realpath resolves link.txt → real.txt (inside sandbox), so + # the realpath check passes. O_NOFOLLOW on the leaf should reject it. + result = _read({"path": "link.txt"}, str(sandbox)) + # On macOS, O_NOFOLLOW + O_RDONLY may still follow symlinks. + # On Linux, O_NOFOLLOW rejects symlinks at the leaf. 
+ # Either way, the read should either fail or return the real content + # (since the target is inside the sandbox, this is safe either way). + if result["success"]: + # macOS: O_NOFOLLOW doesn't block read-only symlinks + assert result["result"] == "real content" + else: + # Linux: O_NOFOLLOW blocks the symlink + assert "success" in result and result["success"] is False + + def test_binary_file_returns_error(tmp_path): (tmp_path / "binary.bin").write_bytes(b"\x80\x81\x82\xff") result = _read({"path": "binary.bin"}, str(tmp_path))