diff --git a/README.md b/README.md index a2a9994..2e895b3 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@

- Tests + Tests Skills Tools - Lines + Lines Platform - PRs + PRs

--- @@ -321,8 +321,11 @@ Target → TargetDetector → ScanPlanner → ScanEngine → Parsers → Pipelin │ │ │ events │ │ score ``` -- **Executors** — Shell, Docker, MCP server (connection-pooled) +- **Executors** — Shell, Docker, MCP server (connection-pooled), Proxied Shell (ephemeral cloud proxy) - **DAG engine** — dependency-aware task dispatch with reactive edges (one tool's output triggers another) +- **Dynamic mutation** — OutputAnalyzers extract structured intel from tool output, MutationStrategies inject new tasks into the DAG at runtime based on accumulated attack surface state (KillChainState) +- **HITL approval gates** — dangerous tasks (C2 deployment, exploitation) pause for operator approval with configurable timeouts, durable persistence, and write-before-signal crash safety +- **Ephemeral proxy routing** — high-throughput scans route through auto-provisioned cloud nodes (DigitalOcean, Vultr) for rate-limit resilience, with guaranteed teardown - **Normalization** — paths, CWEs, severities, titles standardized across tools - **Deduplication** — strict hash + fuzzy multi-pass matching across tools - **Correlation** — cross-finding relation detection, remediation grouping @@ -384,7 +387,7 @@ A full-stack web interface for multi-user engagement management, built on FastAP /api/v1/exports/ Data export /api/v1/correlation/ Threat correlation /api/v1/chain/ Attack chain analysis -/api/v1/scans/ Scan orchestration (CRUD, control, SSE streaming) +/api/v1/scans/ Scan orchestration (CRUD, control, SSE, approval gates) /api/v1/system/ System info and health ``` @@ -460,11 +463,14 @@ opentools dashboard --engagement my-audit │ │ │ │ ├── engagement/ ── SQLite store (WAL, FTS5, migrations) │ │ ├── scanner/ │ -│ │ ├── engine.py ──── DAG task executor │ +│ │ ├── engine.py ──── DAG task executor + mutation + gates │ │ │ ├── planner.py ── profile → task graph builder │ │ │ ├── pipeline.py ── normalize → dedup → correlate │ │ │ ├── parsing/ ── semgrep, gitleaks, trivy, nmap, 
generic │ -│ │ ├── executor/ ── shell, docker, MCP │ +│ │ ├── executor/ ── shell, docker, MCP, proxied shell │ +│ │ ├── mutation/ ── analyzers, strategies, kill chain state │ +│ │ ├── infra/ ── cloud providers, proxy tunnel, sweeper │ +│ │ ├── approval.py ── HITL gate registry │ │ │ └── store.py ── scan-specific SQLite store │ │ ├── chain/ │ │ │ ├── extractors/ ── regex, parser-aware, LLM │ @@ -670,7 +676,7 @@ python -m pytest tests/ -v ### Project Stats ``` -3 packages | 220+ source files | 1,150+ tests | 30K Python + 1.7K TypeScript | 9 PRs merged +3 packages | 240+ source files | 1,350+ tests | 33K Python + 1.7K TypeScript | 10 PRs merged ``` ### Tech Stack @@ -775,6 +781,20 @@ The parser router auto-discovers parser modules — no registration needed. - [x] Web scan API with SSE streaming - [x] Performance optimization pass — batch DB writes, lazy fetching, reverse indexes, singleton stores +### Phase 3.5: Reactive Engine Expansion + +- [x] Dynamic DAG mutation — OutputAnalyzer → KillChainState → MutationStrategy pipeline +- [x] Nmap and Nuclei output analyzers for structured intel extraction +- [x] RedisProbeStrategy — auto-pivot on discovered Redis services +- [x] Ephemeral proxy routing — CloudNodeProvider ABC with DigitalOcean + Vultr +- [x] Shielded teardown — guaranteed cloud node destruction even under cancellation +- [x] ProxiedShellExecutor — transparent proxy routing for NETWORK_ISOLATED tasks +- [x] Orphan node sweeper for startup cleanup of leaked infrastructure +- [x] HITL approval gates — persistence-first, execution wrapper model +- [x] ApprovalRegistry — in-memory notification hub with database-owned expiry +- [x] FastAPI gate endpoints — list/approve/reject with write-before-signal guarantee +- [x] Command injection guard for strategy-spawned tasks + ### Phase 4 (Planned) - [ ] Attack chain visualization (linked findings → narrative graph) diff --git a/docs/superpowers/plans/2026-04-13-dag-mutation-ephemeral-proxy.md 
b/docs/superpowers/plans/2026-04-13-dag-mutation-ephemeral-proxy.md new file mode 100644 index 0000000..a3408bc --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-dag-mutation-ephemeral-proxy.md @@ -0,0 +1,2612 @@ +# Dynamic DAG Mutation + Ephemeral Proxy Routing Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Transform OpenTools from a static scan pipeline into a reactive execution engine that mutates its own DAG based on tool output, and optionally routes high-throughput scans through ephemeral proxy infrastructure for rate-limit resilience. + +**Architecture:** Two independent subsystems. Phase A adds a mutation layer (OutputAnalyzer → KillChainState → MutationStrategy) that hooks into the existing `ScanEngine._mark_completed` path — all synchronous within a single event loop turn, so no race conditions. Phase B adds an `AsyncContextManager`-based ephemeral proxy lifecycle that wraps `run_streaming` calls, with shielded teardown guaranteeing cloud node destruction even under cancellation. Both phases compose: a MutationStrategy can spawn tasks with `isolation: network_isolated`, and the ProxiedShellExecutor handles routing transparently. 
+ +**Tech Stack:** Python 3.12+, Pydantic v2, asyncio, httpx (already in deps), pytest + pytest-asyncio + +--- + +## File Structure + +### Phase A — Dynamic DAG Mutation + +| Action | Path | Responsibility | +|--------|------|----------------| +| Create | `packages/cli/src/opentools/scanner/mutation/__init__.py` | Package exports | +| Create | `packages/cli/src/opentools/scanner/mutation/models.py` | `IntelBundle`, `DiscoveredService`, `DiscoveredVuln`, `KillChainState` | +| Create | `packages/cli/src/opentools/scanner/mutation/analyzer.py` | `OutputAnalyzer` protocol + `NmapAnalyzer` + `NucleiAnalyzer` | +| Create | `packages/cli/src/opentools/scanner/mutation/strategy.py` | `MutationStrategy` protocol + `RedisProbeStrategy` + `get_builtin_strategies()` | +| Modify | `packages/cli/src/opentools/scanner/engine.py` | Add mutation fields to `__init__`, hook mutation into `_mark_completed`, harden `_inject_tasks` | +| Create | `packages/cli/tests/test_scanner/test_mutation_models.py` | Tests for models | +| Create | `packages/cli/tests/test_scanner/test_mutation_analyzer.py` | Tests for analyzers | +| Create | `packages/cli/tests/test_scanner/test_mutation_strategy.py` | Tests for strategies | +| Create | `packages/cli/tests/test_scanner/test_engine_mutation.py` | Integration: engine + mutation layer | + +### Phase B — Ephemeral Proxy Routing + +| Action | Path | Responsibility | +|--------|------|----------------| +| Modify | `packages/cli/src/opentools/shared/subprocess.py` | Add `env` parameter to `run_streaming` | +| Create | `packages/cli/src/opentools/scanner/infra/__init__.py` | Package exports | +| Create | `packages/cli/src/opentools/scanner/infra/provider.py` | `CloudNodeProvider` ABC, `EphemeralNode`, `ProvisioningError` | +| Create | `packages/cli/src/opentools/scanner/infra/digitalocean.py` | `DigitalOceanProvider` (httpx-based) | +| Create | `packages/cli/src/opentools/scanner/infra/proxy.py` | `ephemeral_proxy` context manager, `ProxyEndpoint`, 
`_shielded_destroy` | +| Create | `packages/cli/src/opentools/scanner/infra/sweeper.py` | `sweep_orphaned_nodes` startup cleanup | +| Create | `packages/cli/src/opentools/scanner/executor/proxied_shell.py` | `ProxiedShellExecutor` | +| Modify | `packages/cli/src/opentools/scanner/executor/__init__.py` | Export `ProxiedShellExecutor` | +| Modify | `packages/cli/tests/test_scanner/test_shared_subprocess.py` | Add `env` parameter test | +| Create | `packages/cli/tests/test_scanner/test_infra_provider.py` | Tests for provider ABC + DO provider | +| Create | `packages/cli/tests/test_scanner/test_infra_proxy.py` | Tests for proxy context manager + shielded teardown | +| Create | `packages/cli/tests/test_scanner/test_executor_proxied_shell.py` | Tests for proxied executor | + +--- + +## Phase A: Dynamic DAG Mutation + +### Task 1: KillChainState + IntelBundle Models + +**Files:** +- Create: `packages/cli/src/opentools/scanner/mutation/__init__.py` +- Create: `packages/cli/src/opentools/scanner/mutation/models.py` +- Test: `packages/cli/tests/test_scanner/test_mutation_models.py` + +- [ ] **Step 1: Write failing tests for IntelBundle and DiscoveredService** + +```python +# packages/cli/tests/test_scanner/test_mutation_models.py +"""Tests for mutation layer data models.""" + +from opentools.scanner.mutation.models import ( + DiscoveredService, + DiscoveredVuln, + IntelBundle, + KillChainState, +) + + +class TestDiscoveredService: + def test_construction(self): + svc = DiscoveredService( + host="10.0.0.1", + port=6379, + protocol="tcp", + service="redis", + ) + assert svc.host == "10.0.0.1" + assert svc.port == 6379 + assert svc.protocol == "tcp" + assert svc.service == "redis" + assert svc.product is None + assert svc.version is None + + def test_with_product_and_version(self): + svc = DiscoveredService( + host="10.0.0.1", + port=80, + protocol="tcp", + service="http", + product="Apache httpd", + version="2.4.51", + ) + assert svc.product == "Apache httpd" + assert 
svc.version == "2.4.51" + + +class TestDiscoveredVuln: + def test_construction(self): + vuln = DiscoveredVuln( + host="10.0.0.1", + port=443, + template_id="CVE-2021-44228", + severity="critical", + matched_at="https://10.0.0.1:443/api", + extracted_data={"payload": "jndi:ldap"}, + ) + assert vuln.template_id == "CVE-2021-44228" + assert vuln.severity == "critical" + + def test_port_optional(self): + vuln = DiscoveredVuln( + host="10.0.0.1", + port=None, + template_id="exposed-git", + severity="medium", + matched_at="http://10.0.0.1/.git/config", + extracted_data={}, + ) + assert vuln.port is None + + +class TestIntelBundle: + def test_empty_default(self): + bundle = IntelBundle() + assert bundle.services == [] + assert bundle.vulns == [] + assert bundle.urls == [] + assert bundle.metadata == {} + + def test_with_services(self): + svc = DiscoveredService( + host="10.0.0.1", port=22, protocol="tcp", service="ssh", + ) + bundle = IntelBundle(services=[svc]) + assert len(bundle.services) == 1 + assert bundle.services[0].service == "ssh" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_models.py::TestDiscoveredService -v` +Expected: FAIL with `ModuleNotFoundError: No module named 'opentools.scanner.mutation'` + +- [ ] **Step 3: Implement IntelBundle, DiscoveredService, DiscoveredVuln** + +```python +# packages/cli/src/opentools/scanner/mutation/__init__.py +"""Dynamic DAG mutation layer — output analysis, state accumulation, task synthesis.""" +``` + +```python +# packages/cli/src/opentools/scanner/mutation/models.py +"""Data models for the mutation layer. + +IntelBundle carries structured intelligence extracted from tool output. +KillChainState accumulates intel across all completed tasks, enabling +cross-task reasoning for dynamic task injection. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class DiscoveredService(BaseModel): + """A network service discovered by a scanning tool.""" + host: str + port: int + protocol: str # "tcp" | "udp" + service: str # e.g., "redis", "http", "ssh" + product: str | None = None # e.g., "Redis", "Apache httpd" + version: str | None = None # e.g., "6.2.7", "2.4.51" + banner: str | None = None + + +class DiscoveredVuln(BaseModel): + """A vulnerability discovered by a scanning tool.""" + host: str + port: int | None + template_id: str # nuclei template ID or CVE + severity: str + matched_at: str # URL or host:port + extracted_data: dict = Field(default_factory=dict) + + +class IntelBundle(BaseModel): + """Structured intelligence extracted from a single task's output. + + Produced by an OutputAnalyzer, consumed by KillChainState.ingest(). + """ + services: list[DiscoveredService] = Field(default_factory=list) + vulns: list[DiscoveredVuln] = Field(default_factory=list) + urls: list[str] = Field(default_factory=list) + metadata: dict = Field(default_factory=dict) + + +class KillChainState(BaseModel): + """Accumulated attack surface knowledge across all completed tasks. + + Only mutated inside ScanEngine._mark_completed (synchronous within + a single event loop turn), so no locking is needed. + """ + services: dict[str, DiscoveredService] = Field(default_factory=dict) + vulns: dict[str, DiscoveredVuln] = Field(default_factory=dict) + urls: set[str] = Field(default_factory=set) + tasks_spawned: dict[str, int] = Field(default_factory=dict) + total_spawned: int = 0 + + def ingest(self, bundle: IntelBundle) -> None: + """Merge an IntelBundle into accumulated state. 
Deduplicates by key.""" + for svc in bundle.services: + key = f"{svc.host}:{svc.port}/{svc.protocol}" + self.services[key] = svc + for vuln in bundle.vulns: + key = f"{vuln.host}:{vuln.template_id}" + self.vulns[key] = vuln + self.urls.update(bundle.urls) + + def has_service(self, service_name: str) -> bool: + """Check if any discovered service matches the given name.""" + return any(s.service == service_name for s in self.services.values()) + + def get_services(self, service_name: str) -> list[DiscoveredService]: + """Return all discovered services matching the given name.""" + return [s for s in self.services.values() if s.service == service_name] + + def record_spawn(self, strategy_name: str, count: int = 1) -> None: + """Record that a strategy spawned tasks.""" + self.tasks_spawned[strategy_name] = ( + self.tasks_spawned.get(strategy_name, 0) + count + ) + self.total_spawned += count +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_models.py::TestDiscoveredService tests/test_scanner/test_mutation_models.py::TestDiscoveredVuln tests/test_scanner/test_mutation_models.py::TestIntelBundle -v` +Expected: All PASS + +- [ ] **Step 5: Write failing tests for KillChainState** + +Add to `packages/cli/tests/test_scanner/test_mutation_models.py`: + +```python +class TestKillChainState: + def test_empty_default(self): + state = KillChainState() + assert state.services == {} + assert state.vulns == {} + assert len(state.urls) == 0 + assert state.total_spawned == 0 + + def test_ingest_services(self): + state = KillChainState() + bundle = IntelBundle(services=[ + DiscoveredService( + host="10.0.0.1", port=6379, protocol="tcp", service="redis", + ), + DiscoveredService( + host="10.0.0.1", port=80, protocol="tcp", service="http", + ), + ]) + state.ingest(bundle) + assert len(state.services) == 2 + assert "10.0.0.1:6379/tcp" in state.services + assert "10.0.0.1:80/tcp" in state.services + + def 
test_ingest_deduplicates(self): + state = KillChainState() + svc = DiscoveredService( + host="10.0.0.1", port=6379, protocol="tcp", service="redis", + ) + state.ingest(IntelBundle(services=[svc])) + state.ingest(IntelBundle(services=[svc])) + assert len(state.services) == 1 + + def test_ingest_vulns(self): + state = KillChainState() + vuln = DiscoveredVuln( + host="10.0.0.1", port=443, template_id="CVE-2021-44228", + severity="critical", matched_at="https://10.0.0.1:443/", + extracted_data={}, + ) + state.ingest(IntelBundle(vulns=[vuln])) + assert "10.0.0.1:CVE-2021-44228" in state.vulns + + def test_ingest_urls(self): + state = KillChainState() + state.ingest(IntelBundle(urls=["http://10.0.0.1/admin"])) + state.ingest(IntelBundle(urls=["http://10.0.0.1/admin", "http://10.0.0.1/api"])) + assert len(state.urls) == 2 + + def test_has_service(self): + state = KillChainState() + state.ingest(IntelBundle(services=[ + DiscoveredService( + host="10.0.0.1", port=6379, protocol="tcp", service="redis", + ), + ])) + assert state.has_service("redis") is True + assert state.has_service("mysql") is False + + def test_get_services(self): + state = KillChainState() + state.ingest(IntelBundle(services=[ + DiscoveredService(host="10.0.0.1", port=6379, protocol="tcp", service="redis"), + DiscoveredService(host="10.0.0.2", port=6379, protocol="tcp", service="redis"), + DiscoveredService(host="10.0.0.1", port=80, protocol="tcp", service="http"), + ])) + redis_services = state.get_services("redis") + assert len(redis_services) == 2 + assert all(s.service == "redis" for s in redis_services) + + def test_record_spawn(self): + state = KillChainState() + state.record_spawn("redis_probe", 2) + assert state.tasks_spawned["redis_probe"] == 2 + assert state.total_spawned == 2 + state.record_spawn("redis_probe", 1) + assert state.tasks_spawned["redis_probe"] == 3 + assert state.total_spawned == 3 +``` + +- [ ] **Step 6: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest 
tests/test_scanner/test_mutation_models.py -v` +Expected: All PASS + +- [ ] **Step 7: Commit** + +```bash +git add packages/cli/src/opentools/scanner/mutation/__init__.py packages/cli/src/opentools/scanner/mutation/models.py packages/cli/tests/test_scanner/test_mutation_models.py +git commit -m "feat(mutation): add IntelBundle, DiscoveredService, KillChainState models" +``` + +--- + +### Task 2: OutputAnalyzer Protocol + NmapAnalyzer + +**Files:** +- Create: `packages/cli/src/opentools/scanner/mutation/analyzer.py` +- Test: `packages/cli/tests/test_scanner/test_mutation_analyzer.py` + +**Context:** The existing `NmapParser` in `packages/cli/src/opentools/scanner/parsing/parsers/nmap.py` converts nmap XML into `RawFinding` objects for the findings pipeline. `NmapAnalyzer` converts the same XML into `DiscoveredService` objects for the mutation layer. Different output, different consumer — but same XML parsing logic. + +- [ ] **Step 1: Write failing tests for NmapAnalyzer** + +```python +# packages/cli/tests/test_scanner/test_mutation_analyzer.py +"""Tests for OutputAnalyzer implementations.""" + +from opentools.scanner.mutation.analyzer import ( + AnalyzerRegistry, + NmapAnalyzer, + NucleiAnalyzer, + OutputAnalyzer, +) +from opentools.scanner.mutation.models import IntelBundle + + +# Minimal valid nmap XML with two open ports +NMAP_XML_TWO_PORTS = """\ +<?xml version="1.0" encoding="UTF-8"?> +<nmaprun> +<host> +<address addr="10.0.0.1" addrtype="ipv4"/> +<ports>
+<port protocol="tcp" portid="6379"> +<state state="open"/> +<service name="redis" product="Redis" version="6.2.7"/> +</port> +<port protocol="tcp" portid="80"> +<state state="open"/> +<service name="http"/> +</port> +<port protocol="tcp" portid="22"> +<state state="closed"/> +<service name="ssh"/> +</port> +</ports> +</host> +</nmaprun> +""" + +NMAP_XML_NO_OPEN = """\ +<?xml version="1.0" encoding="UTF-8"?> +<nmaprun> +<host>
+<address addr="10.0.0.1" addrtype="ipv4"/> +<ports> +<port protocol="tcp" portid="22"> +<state state="closed"/> +</port> +</ports> +</host> +</nmaprun> +""" + +NMAP_XML_MULTI_HOST = """\ +<?xml version="1.0" encoding="UTF-8"?> +<nmaprun> +<host>
+<address addr="10.0.0.1" addrtype="ipv4"/> +<ports> +<port protocol="tcp" portid="6379"> +<state state="open"/> +<service name="redis"/> +</port> +</ports> +</host> +<host> +<address addr="10.0.0.2" addrtype="ipv4"/>
+<ports> +<port protocol="tcp" portid="80"> +<state state="open"/> +<service name="http"/> +</port> +</ports> +</host> +</nmaprun> +""" + + +class TestNmapAnalyzer: + def setup_method(self): + self.analyzer = NmapAnalyzer() + + def test_tool_name(self): + assert self.analyzer.tool == "nmap" + + def test_extracts_open_services(self): + bundle = self.analyzer.analyze(NMAP_XML_TWO_PORTS, "") + assert isinstance(bundle, IntelBundle) + assert len(bundle.services) == 2 + services_by_port = {s.port: s for s in bundle.services} + assert 6379 in services_by_port + assert services_by_port[6379].service == "redis" + assert services_by_port[6379].product == "Redis" + assert services_by_port[6379].version == "6.2.7" + assert services_by_port[6379].host == "10.0.0.1" + assert 80 in services_by_port + assert services_by_port[80].service == "http" + + def test_skips_closed_ports(self): + bundle = self.analyzer.analyze(NMAP_XML_TWO_PORTS, "") + ports = [s.port for s in bundle.services] + assert 22 not in ports + + def test_no_open_ports_returns_empty(self): + bundle = self.analyzer.analyze(NMAP_XML_NO_OPEN, "") + assert bundle.services == [] + + def test_multi_host(self): + bundle = self.analyzer.analyze(NMAP_XML_MULTI_HOST, "") + assert len(bundle.services) == 2 + hosts = {s.host for s in bundle.services} + assert hosts == {"10.0.0.1", "10.0.0.2"} + + def test_invalid_xml_returns_empty(self): + bundle = self.analyzer.analyze("not xml at all", "") + assert bundle.services == [] + + def test_empty_stdout_returns_empty(self): + bundle = self.analyzer.analyze("", "") + assert bundle.services == [] + + +# Minimal nuclei JSON lines output +NUCLEI_JSONL = ( + '{"template-id":"CVE-2021-44228","host":"10.0.0.1","port":"443",' + '"matched-at":"https://10.0.0.1:443/api","info":{"severity":"critical"},' + '"extracted-results":["jndi:ldap"]}\n' + '{"template-id":"exposed-git","host":"10.0.0.1","port":"80",' + '"matched-at":"http://10.0.0.1/.git/config","info":{"severity":"medium"},' + '"extracted-results":[]}\n' +) + + +class TestNucleiAnalyzer: + def setup_method(self): + self.analyzer = NucleiAnalyzer() + 
+ def test_tool_name(self): + assert self.analyzer.tool == "nuclei" + + def test_extracts_vulns(self): + bundle = self.analyzer.analyze(NUCLEI_JSONL, "") + assert len(bundle.vulns) == 2 + template_ids = {v.template_id for v in bundle.vulns} + assert "CVE-2021-44228" in template_ids + assert "exposed-git" in template_ids + + def test_vuln_fields(self): + bundle = self.analyzer.analyze(NUCLEI_JSONL, "") + cve = next(v for v in bundle.vulns if v.template_id == "CVE-2021-44228") + assert cve.host == "10.0.0.1" + assert cve.port == 443 + assert cve.severity == "critical" + assert cve.matched_at == "https://10.0.0.1:443/api" + + def test_extracts_urls_from_matched_at(self): + bundle = self.analyzer.analyze(NUCLEI_JSONL, "") + assert len(bundle.urls) == 2 + + def test_empty_output_returns_empty(self): + bundle = self.analyzer.analyze("", "") + assert bundle.vulns == [] + + def test_invalid_json_lines_skipped(self): + mixed = '{"template-id":"x","host":"h","matched-at":"u","info":{"severity":"low"}}\nnot json\n' + bundle = self.analyzer.analyze(mixed, "") + assert len(bundle.vulns) == 1 + + +class TestAnalyzerRegistry: + def test_register_and_get(self): + registry = AnalyzerRegistry() + analyzer = NmapAnalyzer() + registry.register(analyzer) + assert registry.get("nmap") is analyzer + + def test_get_missing_returns_none(self): + registry = AnalyzerRegistry() + assert registry.get("nonexistent") is None + + def test_get_builtin_analyzers(self): + registry = AnalyzerRegistry() + registry.register_builtins() + assert registry.get("nmap") is not None + assert registry.get("nuclei") is not None +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_analyzer.py::TestNmapAnalyzer::test_tool_name -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement OutputAnalyzer protocol, NmapAnalyzer, NucleiAnalyzer, AnalyzerRegistry** + +```python +# 
packages/cli/src/opentools/scanner/mutation/analyzer.py +"""OutputAnalyzer protocol and builtin implementations. + +OutputAnalyzers extract structured attack surface intelligence from tool +output. They are distinct from ParserPlugin (which produces RawFinding +objects for the findings pipeline) — analyzers produce IntelBundle objects +for the mutation layer's KillChainState. +""" + +from __future__ import annotations + +import json +import xml.etree.ElementTree as ET +from typing import Protocol, runtime_checkable + +from opentools.scanner.mutation.models import ( + DiscoveredService, + DiscoveredVuln, + IntelBundle, +) + + +@runtime_checkable +class OutputAnalyzer(Protocol): + """Extracts attack surface intelligence from raw tool output.""" + tool: str + + def analyze(self, stdout: str, stderr: str) -> IntelBundle: ... + + +class NmapAnalyzer: + """Extract DiscoveredService objects from nmap XML output.""" + tool = "nmap" + + def analyze(self, stdout: str, stderr: str) -> IntelBundle: + if not stdout.strip(): + return IntelBundle() + + try: + root = ET.fromstring(stdout) + except ET.ParseError: + return IntelBundle() + + services: list[DiscoveredService] = [] + + for host in root.findall("host"): + addr_el = host.find("address") + addr = addr_el.get("addr", "unknown") if addr_el is not None else "unknown" + + ports_el = host.find("ports") + if ports_el is None: + continue + + for port in ports_el.findall("port"): + state_el = port.find("state") + if state_el is None or state_el.get("state") != "open": + continue + + svc_el = port.find("service") + services.append(DiscoveredService( + host=addr, + port=int(port.get("portid", "0")), + protocol=port.get("protocol", "tcp"), + service=svc_el.get("name", "") if svc_el is not None else "", + product=svc_el.get("product") if svc_el is not None else None, + version=svc_el.get("version") if svc_el is not None else None, + )) + + return IntelBundle(services=services) + + +class NucleiAnalyzer: + """Extract DiscoveredVuln 
objects from nuclei JSON-lines output.""" + tool = "nuclei" + + def analyze(self, stdout: str, stderr: str) -> IntelBundle: + if not stdout.strip(): + return IntelBundle() + + vulns: list[DiscoveredVuln] = [] + urls: list[str] = [] + + for line in stdout.strip().splitlines(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + + template_id = data.get("template-id", "") + host = data.get("host", "") + matched_at = data.get("matched-at", "") + info = data.get("info", {}) + severity = info.get("severity", "unknown") if isinstance(info, dict) else "unknown" + + port_raw = data.get("port", "") + port: int | None = None + if port_raw: + try: + port = int(port_raw) + except (ValueError, TypeError): + pass + + extracted = data.get("extracted-results", []) + extracted_data = {"results": extracted} if extracted else {} + + vulns.append(DiscoveredVuln( + host=host, + port=port, + template_id=template_id, + severity=severity, + matched_at=matched_at, + extracted_data=extracted_data, + )) + + if matched_at: + urls.append(matched_at) + + return IntelBundle(vulns=vulns, urls=urls) + + +class AnalyzerRegistry: + """Registry of OutputAnalyzer instances, keyed by tool name.""" + + def __init__(self) -> None: + self._analyzers: dict[str, OutputAnalyzer] = {} + + def register(self, analyzer: OutputAnalyzer) -> None: + self._analyzers[analyzer.tool] = analyzer + + def get(self, tool: str) -> OutputAnalyzer | None: + return self._analyzers.get(tool) + + def register_builtins(self) -> None: + """Register all builtin analyzers.""" + self.register(NmapAnalyzer()) + self.register(NucleiAnalyzer()) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_analyzer.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/scanner/mutation/analyzer.py 
packages/cli/tests/test_scanner/test_mutation_analyzer.py +git commit -m "feat(mutation): add OutputAnalyzer protocol with Nmap and Nuclei analyzers" +``` + +--- + +### Task 3: MutationStrategy Protocol + RedisProbeStrategy + +**Files:** +- Create: `packages/cli/src/opentools/scanner/mutation/strategy.py` +- Test: `packages/cli/tests/test_scanner/test_mutation_strategy.py` + +**Context:** Strategies examine the accumulated `KillChainState` (not just one task's output) and synthesize new `ScanTask` objects. Each strategy is idempotent — it tracks what it has already spawned via task ID conventions to avoid duplicates. + +- [ ] **Step 1: Write failing tests for MutationStrategy and RedisProbeStrategy** + +```python +# packages/cli/tests/test_scanner/test_mutation_strategy.py +"""Tests for MutationStrategy implementations.""" + +from opentools.scanner.mutation.models import ( + DiscoveredService, + IntelBundle, + KillChainState, +) +from opentools.scanner.mutation.strategy import ( + MutationStrategy, + RedisProbeStrategy, + get_builtin_strategies, +) +from opentools.scanner.models import ScanTask, TaskType, ExecutionTier + + +def _make_task( + task_id: str = "nmap-1", + tool: str = "nmap", + scan_id: str = "scan-1", +) -> ScanTask: + return ScanTask( + id=task_id, + scan_id=scan_id, + name=f"{tool}-task", + tool=tool, + task_type=TaskType.SHELL, + ) + + +def _state_with_redis() -> KillChainState: + state = KillChainState() + state.ingest(IntelBundle(services=[ + DiscoveredService( + host="10.0.0.1", port=6379, protocol="tcp", service="redis", + product="Redis", version="6.2.7", + ), + ])) + return state + + +def _state_with_http_only() -> KillChainState: + state = KillChainState() + state.ingest(IntelBundle(services=[ + DiscoveredService( + host="10.0.0.1", port=80, protocol="tcp", service="http", + ), + ])) + return state + + +class TestRedisProbeStrategy: + def setup_method(self): + self.strategy = RedisProbeStrategy() + + def test_name(self): + assert 
self.strategy.name == "redis_probe" + + def test_max_spawns(self): + assert self.strategy.max_spawns == 10 + + def test_spawns_redis_probe_when_redis_discovered(self): + state = _state_with_redis() + task = _make_task(tool="nmap") + new_tasks = self.strategy.evaluate(state, "scan-1", task) + assert len(new_tasks) == 1 + t = new_tasks[0] + assert t.tool == "redis-cli" + assert t.task_type == TaskType.DOCKER_EXEC + assert "10.0.0.1" in t.command + assert "6379" in t.command + assert t.scan_id == "scan-1" + assert t.spawned_by == "nmap-1" + assert "redis" in t.spawned_reason.lower() + + def test_no_spawn_when_no_redis(self): + state = _state_with_http_only() + task = _make_task(tool="nmap") + new_tasks = self.strategy.evaluate(state, "scan-1", task) + assert new_tasks == [] + + def test_no_spawn_when_already_spawned(self): + """Idempotent: evaluate() self-tracks, second call returns empty.""" + state = _state_with_redis() + task = _make_task(tool="nmap") + first = self.strategy.evaluate(state, "scan-1", task) + assert len(first) == 1 + + # Second evaluate on the same strategy instance — should be empty + # because evaluate() marked the service key internally. 
+ second = self.strategy.evaluate(state, "scan-1", task) + assert second == [] + + def test_spawns_for_multiple_redis_instances(self): + state = KillChainState() + state.ingest(IntelBundle(services=[ + DiscoveredService(host="10.0.0.1", port=6379, protocol="tcp", service="redis"), + DiscoveredService(host="10.0.0.2", port=6379, protocol="tcp", service="redis"), + ])) + task = _make_task(tool="nmap") + new_tasks = self.strategy.evaluate(state, "scan-1", task) + assert len(new_tasks) == 2 + hosts = {t.command.split("-h ")[1].split(" ")[0] for t in new_tasks} + assert hosts == {"10.0.0.1", "10.0.0.2"} + + def test_ignores_non_nmap_tool(self): + """Only triggers on nmap/masscan completions.""" + state = _state_with_redis() + task = _make_task(tool="semgrep") + new_tasks = self.strategy.evaluate(state, "scan-1", task) + assert new_tasks == [] + + def test_task_ids_are_deterministic(self): + """Same input produces same task IDs, enabling dedup in _inject_tasks.""" + state = _state_with_redis() + task = _make_task(tool="nmap") + tasks_a = self.strategy.evaluate(state, "scan-1", task) + # Reset strategy state for second evaluation + strategy_b = RedisProbeStrategy() + tasks_b = strategy_b.evaluate(state, "scan-1", task) + assert tasks_a[0].id == tasks_b[0].id + + +class TestGetBuiltinStrategies: + def test_returns_list(self): + strategies = get_builtin_strategies() + assert isinstance(strategies, list) + + def test_contains_redis_probe(self): + strategies = get_builtin_strategies() + names = [s.name for s in strategies] + assert "redis_probe" in names + + def test_all_satisfy_protocol(self): + strategies = get_builtin_strategies() + for s in strategies: + assert isinstance(s, MutationStrategy) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_strategy.py::TestRedisProbeStrategy::test_name -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement MutationStrategy protocol, 
RedisProbeStrategy, get_builtin_strategies** + +```python +# packages/cli/src/opentools/scanner/mutation/strategy.py +"""MutationStrategy protocol and builtin implementations. + +Strategies evaluate the accumulated KillChainState after each task +completion and synthesize new ScanTask objects for injection into +the DAG. Each strategy tracks what it has already spawned to +maintain idempotency. +""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from opentools.scanner.models import ( + ExecutionTier, + ScanTask, + TaskType, +) +from opentools.scanner.mutation.models import KillChainState + + +@runtime_checkable +class MutationStrategy(Protocol): + """Evaluates kill chain state and synthesizes new tasks.""" + name: str + max_spawns: int + + def evaluate( + self, + state: KillChainState, + scan_id: str, + completed_task: ScanTask, + ) -> list[ScanTask]: ... + + +class RedisProbeStrategy: + """Spawn redis-cli INFO probes when nmap discovers Redis services. + + Uses DOCKER_EXEC to run redis-cli inside a container, avoiding + the need for redis-cli on the host. + + Self-tracking: marks service keys as spawned inside evaluate(), + so the engine doesn't need to call back. 
+ """ + name = "redis_probe" + max_spawns = 10 + + # Tools that discover network services + _TRIGGER_TOOLS = {"nmap", "masscan"} + + def __init__(self) -> None: + self._spawned_keys: set[str] = set() + + def evaluate( + self, + state: KillChainState, + scan_id: str, + completed_task: ScanTask, + ) -> list[ScanTask]: + if completed_task.tool not in self._TRIGGER_TOOLS: + return [] + + redis_services = state.get_services("redis") + if not redis_services: + return [] + + tasks: list[ScanTask] = [] + for svc in redis_services: + key = f"{svc.host}:{svc.port}/{svc.protocol}" + if key in self._spawned_keys: + continue + + self._spawned_keys.add(key) # self-track immediately + task_id = f"redis-probe-{svc.host}-{svc.port}" + tasks.append(ScanTask( + id=task_id, + scan_id=scan_id, + name=f"Redis probe {svc.host}:{svc.port}", + tool="redis-cli", + task_type=TaskType.DOCKER_EXEC, + command=f"redis-cli -h {svc.host} -p {svc.port} INFO", + depends_on=[completed_task.id], + priority=20, + tier=ExecutionTier.FAST, + spawned_by=completed_task.id, + spawned_reason=f"nmap discovered Redis on {svc.host}:{svc.port}", + )) + + return tasks + + +def get_builtin_strategies() -> list[MutationStrategy]: + """Return all builtin mutation strategies.""" + return [ + RedisProbeStrategy(), + ] +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_strategy.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/scanner/mutation/strategy.py packages/cli/tests/test_scanner/test_mutation_strategy.py +git commit -m "feat(mutation): add MutationStrategy protocol with RedisProbeStrategy" +``` + +--- + +### Task 4: Engine Integration — Wire Mutation Layer into ScanEngine + +**Files:** +- Modify: `packages/cli/src/opentools/scanner/engine.py:35-63` (`__init__`) +- Modify: `packages/cli/src/opentools/scanner/engine.py:257-274` (`_mark_completed`) +- Modify: 
`packages/cli/src/opentools/scanner/engine.py:350-357` (`_inject_tasks`) +- Test: `packages/cli/tests/test_scanner/test_engine_mutation.py` + +**Context:** The mutation layer hooks into `_mark_completed` after the existing pipeline queueing and before the existing reactive edge evaluation. All mutation logic is synchronous — no `await` points — so the single-threaded event loop guarantee holds. The existing reactive edges remain backward compatible. + +- [ ] **Step 1: Write failing integration tests** + +```python +# packages/cli/tests/test_scanner/test_engine_mutation.py +"""Integration tests: ScanEngine + mutation layer.""" + +import asyncio +from datetime import datetime, timezone +from typing import Callable + +import pytest + +from opentools.scanner.cancellation import CancellationToken +from opentools.scanner.engine import ScanEngine +from opentools.scanner.executor.base import TaskOutput +from opentools.scanner.models import ( + Scan, + ScanStatus, + ScanTask, + TaskStatus, + TaskType, + TargetType, +) +from opentools.scanner.mutation.analyzer import AnalyzerRegistry, NmapAnalyzer +from opentools.scanner.mutation.models import KillChainState +from opentools.scanner.mutation.strategy import RedisProbeStrategy +from opentools.shared.progress import EventBus +from opentools.shared.resource_pool import AdaptiveResourcePool + + +# Nmap XML that discovers a Redis service +NMAP_REDIS_XML = """\ + + + +
+<nmaprun scanner="nmap">
+  <host>
+    <status state="up"/>
+    <address addr="10.0.0.1" addrtype="ipv4"/>
+    <ports>
+      <port protocol="tcp" portid="6379">
+        <state state="open"/>
+        <service name="redis"/>
+      </port>
+    </ports>
+  </host>
+</nmaprun>
+"""
+
+# Nmap XML with no interesting services
+NMAP_HTTP_ONLY_XML = """\
+<nmaprun scanner="nmap">
+  <host>
+    <status state="up"/>
+    <address addr="10.0.0.1" addrtype="ipv4"/>
+    <ports>
+      <port protocol="tcp" portid="80">
+        <state state="open"/>
+        <service name="http"/>
+      </port>
+    </ports>
+  </host>
+</nmaprun>
+ + + + + + + +""" + + +class MockExecutor: + def __init__(self, results: dict[str, TaskOutput] | None = None): + self._results = results or {} + self._default = TaskOutput(exit_code=0, stdout="ok", duration_ms=10) + self.executed: list[str] = [] + + async def execute( + self, task: ScanTask, on_output: Callable[[bytes], None], + cancellation: CancellationToken, + ) -> TaskOutput: + self.executed.append(task.id) + return self._results.get(task.id, self._default) + + +def _make_scan() -> Scan: + return Scan( + id="scan-1", engagement_id="eng-1", target="10.0.0.1", + target_type=TargetType.NETWORK, status=ScanStatus.PENDING, + created_at=datetime.now(timezone.utc), + ) + + +def _make_engine_with_mutation( + tasks: list[ScanTask], + executor: MockExecutor, + max_mutation_spawns: int = 100, +) -> ScanEngine: + pool = AdaptiveResourcePool(global_limit=4) + executors = { + TaskType.SHELL: executor, + TaskType.DOCKER_EXEC: executor, + TaskType.MCP_CALL: executor, + } + engine = ScanEngine( + scan=_make_scan(), + resource_pool=pool, + executors=executors, + event_bus=EventBus(), + cancellation=CancellationToken(), + ) + + # Wire up mutation layer + analyzer_registry = AnalyzerRegistry() + analyzer_registry.register_builtins() + engine.set_analyzer_registry(analyzer_registry) + engine.set_mutation_strategies([RedisProbeStrategy()]) + engine.set_max_mutation_spawns(max_mutation_spawns) + + engine.load_tasks(tasks) + return engine + + +class TestEngineMutationIntegration: + @pytest.mark.asyncio + async def test_nmap_redis_spawns_probe(self): + """Nmap finds Redis → engine spawns redis-cli probe → probe runs.""" + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_REDIS_XML, duration_ms=100, + ), + }) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + ) + engine = _make_engine_with_mutation([nmap_task], executor) + await engine.run() + + # The spawned 
redis probe should have been executed + assert "nmap-scan" in executor.executed + assert "redis-probe-10.0.0.1-6379" in executor.executed + + # Verify the spawned task has correct metadata + spawned = engine._tasks["redis-probe-10.0.0.1-6379"] + assert spawned.tool == "redis-cli" + assert spawned.spawned_by == "nmap-scan" + assert spawned.status == TaskStatus.COMPLETED + + @pytest.mark.asyncio + async def test_no_mutation_when_no_interesting_services(self): + """Nmap finds only HTTP → no mutation strategies fire.""" + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_HTTP_ONLY_XML, duration_ms=100, + ), + }) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + ) + engine = _make_engine_with_mutation([nmap_task], executor) + await engine.run() + + assert executor.executed == ["nmap-scan"] + assert len(engine._tasks) == 1 + + @pytest.mark.asyncio + async def test_mutation_respects_global_budget(self): + """Global spawn budget prevents unbounded task injection.""" + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_REDIS_XML, duration_ms=100, + ), + }) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + ) + engine = _make_engine_with_mutation( + [nmap_task], executor, max_mutation_spawns=0, + ) + await engine.run() + + # Budget is 0 — no mutations allowed + assert executor.executed == ["nmap-scan"] + assert len(engine._tasks) == 1 + + @pytest.mark.asyncio + async def test_mutation_coexists_with_reactive_edges(self): + """Both mutation strategies AND reactive edges can fire on the same task.""" + from opentools.scanner.models import ReactiveEdge + + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_REDIS_XML, duration_ms=100, + ), + }) + + edge_task = ScanTask( + id="edge-spawned", 
scan_id="scan-1", name="edge-task", + tool="test", task_type=TaskType.SHELL, command="echo edge", + ) + edge = ReactiveEdge( + id="edge-1", trigger_task_id="nmap-scan", + evaluator="always", spawns=[edge_task], + ) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + reactive_edges=[edge], + ) + + engine = _make_engine_with_mutation([nmap_task], executor) + engine.register_edge_evaluator("always", lambda t, o, e: e.spawns or []) + await engine.run() + + # Both mutation-spawned AND edge-spawned tasks should run + assert "redis-probe-10.0.0.1-6379" in executor.executed + assert "edge-spawned" in executor.executed + + @pytest.mark.asyncio + async def test_kill_chain_state_accessible(self): + """Engine exposes kill chain state for inspection.""" + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_REDIS_XML, duration_ms=100, + ), + }) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + ) + engine = _make_engine_with_mutation([nmap_task], executor) + await engine.run() + + state = engine.kill_chain_state + assert state.has_service("redis") + assert state.total_spawned == 1 + + @pytest.mark.asyncio + async def test_no_mutation_without_registry(self): + """Engine works normally when no mutation layer is configured.""" + executor = MockExecutor(results={ + "nmap-scan": TaskOutput( + exit_code=0, stdout=NMAP_REDIS_XML, duration_ms=100, + ), + }) + nmap_task = ScanTask( + id="nmap-scan", scan_id="scan-1", name="nmap", + tool="nmap", task_type=TaskType.SHELL, command="nmap 10.0.0.1", + ) + # Build engine WITHOUT mutation layer + pool = AdaptiveResourcePool(global_limit=4) + engine = ScanEngine( + scan=_make_scan(), resource_pool=pool, + executors={TaskType.SHELL: executor, TaskType.DOCKER_EXEC: executor}, + event_bus=EventBus(), cancellation=CancellationToken(), + ) 
+ engine.load_tasks([nmap_task]) + await engine.run() + + # Should complete normally with no mutations + assert executor.executed == ["nmap-scan"] +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine_mutation.py::TestEngineMutationIntegration::test_nmap_redis_spawns_probe -v` +Expected: FAIL with `AttributeError: 'ScanEngine' object has no attribute 'set_analyzer_registry'` + +- [ ] **Step 3: Modify ScanEngine.__init__ — add mutation fields** + +In `packages/cli/src/opentools/scanner/engine.py`, add after the existing `self._cache` and `self._pipeline_results` fields (around line 69): + +```python + # Mutation layer (optional — engine works without it) + self._analyzer_registry: AnalyzerRegistry | None = None + self._mutation_strategies: list[Any] = [] + self._kill_chain = KillChainState() + self._max_mutation_spawns: int = 100 +``` + +Add the imports at the top of the file (after the existing imports): + +```python +from opentools.scanner.mutation.models import KillChainState +``` + +Add these public methods after the existing `set_cache` method (around line 122): + +```python + def set_analyzer_registry(self, registry: Any) -> None: + """Set the OutputAnalyzer registry for mutation layer.""" + self._analyzer_registry = registry + + def set_mutation_strategies(self, strategies: list[Any]) -> None: + """Set the mutation strategies for dynamic task injection.""" + self._mutation_strategies = list(strategies) + + def set_max_mutation_spawns(self, limit: int) -> None: + """Set the global budget for mutation-spawned tasks.""" + self._max_mutation_spawns = limit + + @property + def kill_chain_state(self) -> KillChainState: + """Read-only access to accumulated attack surface state.""" + return self._kill_chain +``` + +- [ ] **Step 4: Modify ScanEngine._mark_completed — hook mutation layer** + +Replace the existing `_mark_completed` method (lines 257-274) with: + +```python + def 
_mark_completed(self, task_id: str, output: TaskOutput) -> None: + task = self._tasks[task_id] + task.status = TaskStatus.COMPLETED + task.exit_code = output.exit_code + task.stdout = output.stdout + task.stderr = output.stderr + task.duration_ms = output.duration_ms + task.cached = output.cached + self._completed.add(task_id) + + # Queue output for pipeline processing + if self._pipeline is not None: + self._pipeline_results[task_id] = output + + # --- Mutation layer: analyze → ingest → evaluate strategies --- + mutation_tasks = self._evaluate_mutations(task, output) + + # --- Existing reactive edges (backward compatible) --- + edge_tasks = self._evaluate_edges(task, output) + + all_new = mutation_tasks + edge_tasks + if all_new: + self._inject_tasks(all_new) +``` + +Add the `_evaluate_mutations` method after `_evaluate_edges` (after line 348): + +```python + def _evaluate_mutations( + self, task: ScanTask, output: TaskOutput + ) -> list[ScanTask]: + """Run mutation layer: analyze output, update state, evaluate strategies.""" + if self._analyzer_registry is None: + return [] + + # 1. Extract intel from tool output + analyzer = self._analyzer_registry.get(task.tool) + if analyzer is not None and output.stdout: + bundle = analyzer.analyze(output.stdout, output.stderr or "") + self._kill_chain.ingest(bundle) + + # 2. 
Evaluate strategies against accumulated state + new_tasks: list[ScanTask] = [] + if self._kill_chain.total_spawned >= self._max_mutation_spawns: + return [] + + for strategy in self._mutation_strategies: + budget_used = self._kill_chain.tasks_spawned.get(strategy.name, 0) + if budget_used >= strategy.max_spawns: + continue + + remaining_strategy = strategy.max_spawns - budget_used + remaining_global = self._max_mutation_spawns - self._kill_chain.total_spawned + + spawned = strategy.evaluate(self._kill_chain, self.scan.id, task) + + allowed = min(remaining_strategy, remaining_global, len(spawned)) + accepted: list[ScanTask] = [] + for s in spawned[:allowed]: + if s.id not in self._tasks: + accepted.append(s) + + # Strategies self-track their spawned keys inside evaluate(), + # so the engine only needs to update the global budget. + if accepted: + self._kill_chain.record_spawn(strategy.name, len(accepted)) + new_tasks.extend(accepted) + + return new_tasks +``` + +- [ ] **Step 5: Harden _inject_tasks with dependency validation** + +Replace the existing `_inject_tasks` method (lines 350-357) with: + +```python + def _inject_tasks(self, tasks: list[ScanTask]) -> None: + """Add dynamically spawned tasks to the graph. + + Validates that all dependencies exist. Drops tasks with unknown + dependencies to prevent forward-reference cycles. 
+ """ + for t in tasks: + if t.id in self._tasks: + continue + # Validate all dependencies exist in the graph + valid = True + for dep in t.depends_on: + if dep not in self._tasks: + import logging + logging.getLogger(__name__).warning( + "Dropping spawned task %s: depends on unknown task %s", + t.id, dep, + ) + valid = False + break + if valid: + self._tasks[t.id] = t + for dep in t.depends_on: + self._dependents[dep].add(t.id) +``` + +- [ ] **Step 6: Run integration tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine_mutation.py -v` +Expected: All PASS + +- [ ] **Step 7: Run existing engine tests to verify backward compatibility** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine.py -v` +Expected: All PASS — no regressions + +- [ ] **Step 8: Commit** + +```bash +git add packages/cli/src/opentools/scanner/engine.py packages/cli/tests/test_scanner/test_engine_mutation.py +git commit -m "feat(engine): integrate mutation layer — analyzer → state → strategy → inject" +``` + +--- + +## Phase B: Ephemeral Proxy Routing + +### Task 5: Add `env` Parameter to `run_streaming` + +**Files:** +- Modify: `packages/cli/src/opentools/shared/subprocess.py:43-48` (function signature) +- Modify: `packages/cli/src/opentools/shared/subprocess.py:68` (`create_subprocess_exec` call) +- Modify: `packages/cli/tests/test_scanner/test_shared_subprocess.py` + +- [ ] **Step 1: Write failing test for env parameter** + +Add to `packages/cli/tests/test_scanner/test_shared_subprocess.py`: + +```python +class TestRunStreamingEnv: + @pytest.mark.asyncio + async def test_env_vars_passed_to_subprocess(self): + """Custom env dict is forwarded to the child process.""" + import os + custom_env = {**os.environ, "OT_TEST_PROXY": "socks5://127.0.0.1:1080"} + result = await run_streaming( + [sys.executable, "-c", "import os; print(os.environ.get('OT_TEST_PROXY', 'MISSING'))"], + on_output=lambda _: None, + env=custom_env, + ) + 
assert result.exit_code == 0 + assert "socks5://127.0.0.1:1080" in result.stdout + + @pytest.mark.asyncio + async def test_env_none_inherits_parent(self): + """env=None (default) inherits the parent process environment.""" + result = await run_streaming( + [sys.executable, "-c", "import os; print(os.environ.get('PATH', 'MISSING'))"], + on_output=lambda _: None, + ) + assert result.exit_code == 0 + assert "MISSING" not in result.stdout +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_shared_subprocess.py::TestRunStreamingEnv -v` +Expected: FAIL with `TypeError: run_streaming() got an unexpected keyword argument 'env'` + +- [ ] **Step 3: Add env parameter to run_streaming** + +In `packages/cli/src/opentools/shared/subprocess.py`, change the function signature (line 43): + +```python +async def run_streaming( + args: list[str], + on_output: Callable[[bytes], None], + timeout: int = 300, + cancellation: object | None = None, # CancellationToken + env: dict[str, str] | None = None, +) -> SubprocessResult: +``` + +And pass it through at line 68: + +```python + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) +``` + +- [ ] **Step 4: Run all subprocess tests to verify pass + no regressions** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_shared_subprocess.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/shared/subprocess.py packages/cli/tests/test_scanner/test_shared_subprocess.py +git commit -m "feat(subprocess): add env parameter to run_streaming for proxy injection" +``` + +--- + +### Task 6: CloudNodeProvider ABC + DigitalOcean Implementation + +**Files:** +- Create: `packages/cli/src/opentools/scanner/infra/__init__.py` +- Create: `packages/cli/src/opentools/scanner/infra/provider.py` +- Create: 
`packages/cli/src/opentools/scanner/infra/digitalocean.py` +- Test: `packages/cli/tests/test_scanner/test_infra_provider.py` + +- [ ] **Step 1: Write failing tests** + +```python +# packages/cli/tests/test_scanner/test_infra_provider.py +"""Tests for CloudNodeProvider and DigitalOceanProvider.""" + +import asyncio +import json + +import httpx +import pytest + +from opentools.scanner.infra.provider import ( + CloudNodeProvider, + EphemeralNode, + ProvisioningError, + ProvisioningTimeout, +) +from opentools.scanner.infra.digitalocean import DigitalOceanProvider + + +class TestEphemeralNode: + def test_construction(self): + node = EphemeralNode( + provider_id="12345", + ip_address="1.2.3.4", + region="nyc3", + ssh_key_fingerprint="aa:bb:cc", + ) + assert node.provider_id == "12345" + assert node.ip_address == "1.2.3.4" + assert node.tags == [] + + +class TestDigitalOceanProvider: + @pytest.mark.asyncio + async def test_create_node_sends_correct_request(self): + """Verify the POST payload sent to the DO API.""" + captured_request = None + + async def mock_handler(request: httpx.Request) -> httpx.Response: + nonlocal captured_request + captured_request = request + return httpx.Response( + 200, json={"droplet": {"id": 12345}}, + ) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + result = await provider.create_node( + region="nyc3", ssh_public_key="fp:aa:bb", tags=["test"], + ) + assert result == "12345" + body = json.loads(captured_request.content) + assert body["region"] == "nyc3" + assert "test" in body["tags"] + + @pytest.mark.asyncio + async def test_poll_status_active(self): + async def mock_handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "droplet": { + "id": 12345, "status": "active", + "networks": {"v4": [{"ip_address": "1.2.3.4", "type": "public"}]}, + }, + }) + + transport = 
httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + status, ip = await provider.poll_status("12345") + assert status == "active" + assert ip == "1.2.3.4" + + @pytest.mark.asyncio + async def test_poll_status_creating(self): + async def mock_handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "droplet": {"id": 12345, "status": "new", "networks": {"v4": []}}, + }) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + status, ip = await provider.poll_status("12345") + assert status == "creating" + assert ip is None + + @pytest.mark.asyncio + async def test_destroy_node_204(self): + async def mock_handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(204) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + await provider.destroy_node("12345") # should not raise + + @pytest.mark.asyncio + async def test_destroy_node_404_is_idempotent(self): + async def mock_handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + await provider.destroy_node("12345") # should not raise + + @pytest.mark.asyncio + async def test_wait_until_ready_success(self): + call_count = 0 + + async def mock_handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count < 3: + return httpx.Response(200, json={ + "droplet": {"id": 12345, "status": "new", "networks": {"v4": []}}, 
+ }) + return httpx.Response(200, json={ + "droplet": { + "id": 12345, "status": "active", + "networks": {"v4": [{"ip_address": "1.2.3.4", "type": "public"}]}, + }, + }) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + ip = await provider.wait_until_ready("12345", poll_interval=0.01, max_polls=10) + assert ip == "1.2.3.4" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_wait_until_ready_timeout(self): + async def mock_handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "droplet": {"id": 12345, "status": "new", "networks": {"v4": []}}, + }) + + transport = httpx.MockTransport(mock_handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.digitalocean.com/v2") + provider = DigitalOceanProvider(client=client) + + with pytest.raises(ProvisioningTimeout): + await provider.wait_until_ready("12345", poll_interval=0.01, max_polls=3) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_provider.py::TestEphemeralNode -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement CloudNodeProvider, EphemeralNode, DigitalOceanProvider** + +```python +# packages/cli/src/opentools/scanner/infra/__init__.py +"""Ephemeral infrastructure provisioning for proxied scan execution.""" +``` + +```python +# packages/cli/src/opentools/scanner/infra/provider.py +"""CloudNodeProvider ABC and shared types for ephemeral infrastructure.""" + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class EphemeralNode(BaseModel): + """A provisioned ephemeral cloud node.""" + provider_id: str + ip_address: str + region: str + 
ssh_key_fingerprint: str + tags: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ProvisioningError(Exception): + """Cloud node provisioning failed.""" + + +class ProvisioningTimeout(ProvisioningError): + """Cloud node did not become ready within the polling budget.""" + + +class CloudNodeProvider(ABC): + """Abstract base for provisioning ephemeral proxy nodes. + + All methods are async — they make HTTP calls and must never block + the event loop. + """ + + @abstractmethod + async def create_node( + self, region: str, ssh_public_key: str, tags: list[str], + ) -> str: + """Create a node. Returns the provider resource ID (not yet ready).""" + ... + + @abstractmethod + async def poll_status(self, provider_id: str) -> tuple[str, str | None]: + """Returns (status, ip_address | None). Status: "creating" | "active" | "error".""" + ... + + @abstractmethod + async def destroy_node(self, provider_id: str) -> None: + """Destroy a node. Idempotent — must not raise if already gone.""" + ... + + async def wait_until_ready( + self, + provider_id: str, + poll_interval: float = 3.0, + max_polls: int = 60, + ) -> str: + """Poll until active. Returns IP address. 
Yields to event loop between polls.""" + for attempt in range(max_polls): + status, ip = await self.poll_status(provider_id) + if status == "active" and ip is not None: + logger.info("Node %s ready at %s after %d polls", provider_id, ip, attempt + 1) + return ip + if status == "error": + raise ProvisioningError(f"Node {provider_id} entered error state") + await asyncio.sleep(poll_interval) + raise ProvisioningTimeout( + f"Node {provider_id} not ready after {max_polls * poll_interval:.0f}s" + ) +``` + +```python +# packages/cli/src/opentools/scanner/infra/digitalocean.py +"""DigitalOcean CloudNodeProvider implementation.""" + +from __future__ import annotations + +import uuid + +import httpx + +from opentools.scanner.infra.provider import CloudNodeProvider + + +class DigitalOceanProvider(CloudNodeProvider): + """Provision ephemeral droplets via the DigitalOcean API.""" + + def __init__(self, client: httpx.AsyncClient) -> None: + self._client = client + + @classmethod + def from_token(cls, api_token: str) -> DigitalOceanProvider: + """Create a provider with a new httpx client using the given API token.""" + client = httpx.AsyncClient( + base_url="https://api.digitalocean.com/v2", + headers={"Authorization": f"Bearer {api_token}"}, + timeout=30.0, + ) + return cls(client=client) + + async def create_node( + self, region: str, ssh_public_key: str, tags: list[str], + ) -> str: + resp = await self._client.post("/droplets", json={ + "name": f"ot-proxy-{uuid.uuid4().hex[:8]}", + "region": region, + "size": "s-1vcpu-512mb-10gb", + "image": "ubuntu-24-04-x64", + "ssh_keys": [ssh_public_key], + "tags": tags, + }) + resp.raise_for_status() + return str(resp.json()["droplet"]["id"]) + + async def poll_status(self, provider_id: str) -> tuple[str, str | None]: + resp = await self._client.get(f"/droplets/{provider_id}") + resp.raise_for_status() + droplet = resp.json()["droplet"] + status = "active" if droplet["status"] == "active" else "creating" + ip = None + for net in 
droplet.get("networks", {}).get("v4", []): + if net.get("type") == "public": + ip = net["ip_address"] + break + return status, ip + + async def destroy_node(self, provider_id: str) -> None: + resp = await self._client.delete(f"/droplets/{provider_id}") + if resp.status_code not in (204, 404): + resp.raise_for_status() +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_provider.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/scanner/infra/__init__.py packages/cli/src/opentools/scanner/infra/provider.py packages/cli/src/opentools/scanner/infra/digitalocean.py packages/cli/tests/test_scanner/test_infra_provider.py +git commit -m "feat(infra): add CloudNodeProvider ABC with DigitalOcean implementation" +``` + +--- + +### Task 7: Ephemeral Proxy Context Manager + Shielded Teardown + +**Files:** +- Create: `packages/cli/src/opentools/scanner/infra/proxy.py` +- Test: `packages/cli/tests/test_scanner/test_infra_proxy.py` + +**Context:** This is the critical piece — the `AsyncContextManager` that composes provision → tunnel → teardown with guaranteed cleanup. The `_shielded_destroy` pattern ensures the cloud API DELETE fires even when the parent task is cancelled. + +- [ ] **Step 1: Write failing tests for proxy lifecycle** + +```python +# packages/cli/tests/test_scanner/test_infra_proxy.py +"""Tests for ephemeral_proxy context manager and shielded teardown.""" + +import asyncio + +import pytest + +from opentools.scanner.infra.provider import ( + CloudNodeProvider, + ProvisioningError, +) +from opentools.scanner.infra.proxy import ( + ProxyEndpoint, + _shielded_destroy, + ephemeral_proxy, +) + + +class FakeProvider(CloudNodeProvider): + """In-memory provider for testing. 
No real cloud calls."""
+
+    def __init__(
+        self,
+        ready_after_polls: int = 1,
+        ip_address: str = "1.2.3.4",
+        fail_create: bool = False,
+        fail_destroy: bool = False,
+    ):
+        self.ready_after_polls = ready_after_polls
+        self.ip_address = ip_address
+        self.fail_create = fail_create
+        self.fail_destroy = fail_destroy
+        self.created_ids: list[str] = []
+        self.destroyed_ids: list[str] = []
+        self._poll_count = 0
+
+    async def create_node(self, region, ssh_public_key, tags):
+        if self.fail_create:
+            raise ProvisioningError("create failed")
+        node_id = f"fake-{len(self.created_ids)}"
+        self.created_ids.append(node_id)
+        return node_id
+
+    async def poll_status(self, provider_id):
+        self._poll_count += 1
+        if self._poll_count >= self.ready_after_polls:
+            return "active", self.ip_address
+        return "creating", None
+
+    async def destroy_node(self, provider_id):
+        if self.fail_destroy:
+            raise ProvisioningError("destroy failed")
+        self.destroyed_ids.append(provider_id)
+
+
+class TestShieldedDestroy:
+    @pytest.mark.asyncio
+    async def test_normal_destroy(self):
+        provider = FakeProvider()
+        provider.created_ids.append("node-1")
+        await _shielded_destroy(provider, "node-1")
+        assert "node-1" in provider.destroyed_ids
+
+    @pytest.mark.asyncio
+    async def test_destroy_survives_cancellation(self):
+        """The destroy call completes even when the outer task is cancelled mid-destroy."""
+        provider = FakeProvider()
+        provider.created_ids.append("node-1")
+        orig_destroy = provider.destroy_node
+
+        async def slow_destroy(provider_id):
+            # Simulate a destroy API call that takes some time
+            await asyncio.sleep(0.05)
+            await orig_destroy(provider_id)
+
+        provider.destroy_node = slow_destroy  # type: ignore[method-assign]
+
+        task = asyncio.ensure_future(_shielded_destroy(provider, "node-1"))
+        await asyncio.sleep(0.01)  # the shielded destroy is now in flight
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+
+        # Give the shielded destroy time to complete
+        await asyncio.sleep(0.1)
+        assert "node-1" in provider.destroyed_ids
+
+    @pytest.mark.asyncio
+    async def 
test_destroy_failure_does_not_raise(self): + """Destroy errors are logged, not raised — we can't do more.""" + provider = FakeProvider(fail_destroy=True) + # Should not raise + await _shielded_destroy(provider, "node-1") + + +class TestProxyEndpoint: + def test_env_includes_proxy_vars(self): + endpoint = ProxyEndpoint(host="1.2.3.4", socks_port=10800) + env = endpoint.env + assert "socks5://127.0.0.1:10800" in env["HTTP_PROXY"] + assert "socks5://127.0.0.1:10800" in env["HTTPS_PROXY"] + assert "socks5://127.0.0.1:10800" in env["ALL_PROXY"] + # Both upper and lower case for tool compatibility + assert "socks5://127.0.0.1:10800" in env["http_proxy"] + assert "socks5://127.0.0.1:10800" in env["https_proxy"] + + def test_env_inherits_parent_env(self): + """Proxy env vars are merged with parent environment.""" + import os + endpoint = ProxyEndpoint(host="1.2.3.4", socks_port=10800) + env = endpoint.env + assert "PATH" in env or "Path" in env # platform-dependent casing + + +class TestEphemeralProxyLifecycle: + @pytest.mark.asyncio + async def test_provision_and_teardown(self): + """Happy path: provision → yield endpoint → destroy.""" + provider = FakeProvider(ready_after_polls=1) + + async with ephemeral_proxy( + provider=provider, + region="nyc3", + ssh_key="test-key", + local_socks_port=10800, + scan_id="scan-1", + _skip_tunnel=True, # skip real SSH for unit test + ) as endpoint: + assert isinstance(endpoint, ProxyEndpoint) + assert endpoint.host == "1.2.3.4" + assert endpoint.socks_port == 10800 + + # After exit, node must be destroyed + assert len(provider.created_ids) == 1 + assert len(provider.destroyed_ids) == 1 + assert provider.created_ids[0] == provider.destroyed_ids[0] + + @pytest.mark.asyncio + async def test_teardown_on_exception(self): + """Node is destroyed even if the body raises.""" + provider = FakeProvider(ready_after_polls=1) + + with pytest.raises(ValueError, match="deliberate"): + async with ephemeral_proxy( + provider=provider, region="nyc3", 
ssh_key="k", + local_socks_port=10800, scan_id="s1", + _skip_tunnel=True, + ): + raise ValueError("deliberate") + + assert len(provider.destroyed_ids) == 1 + + @pytest.mark.asyncio + async def test_create_failure_no_destroy(self): + """If creation fails, no destroy is attempted (nothing to destroy).""" + provider = FakeProvider(fail_create=True) + + with pytest.raises(ProvisioningError): + async with ephemeral_proxy( + provider=provider, region="nyc3", ssh_key="k", + local_socks_port=10800, scan_id="s1", + _skip_tunnel=True, + ): + pass # pragma: no cover + + assert provider.destroyed_ids == [] +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_proxy.py::TestProxyEndpoint -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement ephemeral_proxy, ProxyEndpoint, _shielded_destroy** + +```python +# packages/cli/src/opentools/scanner/infra/proxy.py +"""Ephemeral proxy context manager with guaranteed teardown. + +Usage:: + + async with ephemeral_proxy(provider, region="nyc3", ...) as endpoint: + result = await run_streaming(args, on_output, env=endpoint.env) + # Node is destroyed here, guaranteed. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import os +from contextlib import asynccontextmanager +from typing import AsyncIterator + +from opentools.scanner.infra.provider import CloudNodeProvider, ProvisioningError + +logger = logging.getLogger(__name__) + +PROXY_TAG = "opentools-ephemeral-proxy" + + +class ProxyEndpoint: + """The usable result of a provisioned proxy — env vars for subprocess injection.""" + + def __init__(self, host: str, socks_port: int) -> None: + self.host = host + self.socks_port = socks_port + + @property + def env(self) -> dict[str, str]: + """Environment dict with proxy vars merged into parent env.""" + proxy_url = f"socks5://127.0.0.1:{self.socks_port}" + return { + **os.environ, + "HTTP_PROXY": proxy_url, + "HTTPS_PROXY": proxy_url, + "http_proxy": proxy_url, + "https_proxy": proxy_url, + "ALL_PROXY": proxy_url, + } + + +@asynccontextmanager +async def ephemeral_proxy( + provider: CloudNodeProvider, + region: str = "nyc3", + ssh_key: str = "", + ssh_key_path: str = "~/.ssh/id_ed25519", + local_socks_port: int = 10800, + scan_id: str = "", + _skip_tunnel: bool = False, +) -> AsyncIterator[ProxyEndpoint]: + """Provision ephemeral proxy, optionally establish tunnel, yield endpoint. + + GUARANTEE: The cloud node is destroyed on exit regardless of how the + body terminates — normal return, exception, timeout, or cancellation. + The destroy call is shielded from asyncio.CancelledError. + + Args: + _skip_tunnel: If True, skip SSH tunnel setup (for unit testing). 
+ """ + tags = [PROXY_TAG, f"scan:{scan_id}"] + provider_id: str | None = None + tunnel_proc: asyncio.subprocess.Process | None = None + + try: + # Phase 1: Provision + provider_id = await provider.create_node( + region=region, ssh_public_key=ssh_key, tags=tags, + ) + logger.info("Provisioning node %s in %s", provider_id, region) + + ip_address = await provider.wait_until_ready(provider_id) + + # Phase 2: SSH tunnel (skippable for testing) + if not _skip_tunnel: + tunnel_proc = await _establish_tunnel( + ip_address=ip_address, + ssh_key_path=ssh_key_path, + local_port=local_socks_port, + ) + + endpoint = ProxyEndpoint(host=ip_address, socks_port=local_socks_port) + logger.info( + "Proxy ready: 127.0.0.1:%d → %s (node %s)", + local_socks_port, ip_address, provider_id, + ) + + yield endpoint + + finally: + # Phase 3: Teardown (reverse order) + + # 3a. Kill SSH tunnel + if tunnel_proc is not None and tunnel_proc.returncode is None: + tunnel_proc.terminate() + try: + await asyncio.wait_for(tunnel_proc.wait(), timeout=5.0) + except asyncio.TimeoutError: + tunnel_proc.kill() + await tunnel_proc.wait() + logger.info("SSH tunnel terminated") + + # 3b. Destroy cloud node (shielded from cancellation) + if provider_id is not None: + await _shielded_destroy(provider, provider_id) + + +async def _shielded_destroy(provider: CloudNodeProvider, provider_id: str) -> None: + """Destroy a cloud node, shielded from asyncio.CancelledError. + + Pattern: create a concrete Task, shield it, and if CancelledError fires + in our scope, await the task directly (it's still running unaffected). 
+ """ + destroy_task = asyncio.ensure_future(provider.destroy_node(provider_id)) + try: + await asyncio.shield(destroy_task) + logger.info("Node %s destroyed", provider_id) + except asyncio.CancelledError: + try: + await destroy_task + logger.info("Node %s destroyed (post-cancellation)", provider_id) + except Exception: + logger.exception("Failed to destroy node %s during cancellation", provider_id) + raise + except Exception: + logger.exception("Failed to destroy node %s", provider_id) + + +async def _establish_tunnel( + ip_address: str, + ssh_key_path: str, + local_port: int, + max_retries: int = 5, + retry_delay: float = 3.0, +) -> asyncio.subprocess.Process: + """Start SSH dynamic SOCKS5 tunnel with retry for sshd startup race.""" + expanded_key = os.path.expanduser(ssh_key_path) + + for attempt in range(max_retries): + proc = await asyncio.create_subprocess_exec( + "ssh", + "-D", str(local_port), + "-N", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=10", + "-o", "ServerAliveInterval=15", + "-o", "ServerAliveCountMax=3", + "-o", "ExitOnForwardFailure=yes", + "-i", expanded_key, + f"root@{ip_address}", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + await asyncio.sleep(2.0) + if proc.returncode is None: + return proc + + stderr = b"" + if proc.stderr: + stderr = await proc.stderr.read() + if attempt < max_retries - 1: + logger.debug( + "SSH attempt %d failed: %s — retrying", + attempt + 1, stderr.decode(errors="replace").strip(), + ) + await asyncio.sleep(retry_delay) + continue + + raise ProvisioningError( + f"SSH tunnel failed after {max_retries} attempts: {stderr.decode(errors='replace')}" + ) + + raise ProvisioningError("SSH tunnel failed: exhausted retries") +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_proxy.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add 
packages/cli/src/opentools/scanner/infra/proxy.py packages/cli/tests/test_scanner/test_infra_proxy.py +git commit -m "feat(infra): ephemeral_proxy context manager with shielded teardown" +``` + +--- + +### Task 8: ProxiedShellExecutor + +**Files:** +- Create: `packages/cli/src/opentools/scanner/executor/proxied_shell.py` +- Modify: `packages/cli/src/opentools/scanner/executor/__init__.py` +- Test: `packages/cli/tests/test_scanner/test_executor_proxied_shell.py` + +**Context:** Wraps the existing shell execution path. Tasks opt in to proxying via `task.isolation == TaskIsolation.NETWORK_ISOLATED`. Non-proxied tasks fall through to a plain `run_streaming` call identical to `ShellExecutor`. + +- [ ] **Step 1: Write failing tests** + +```python +# packages/cli/tests/test_scanner/test_executor_proxied_shell.py +"""Tests for ProxiedShellExecutor.""" + +import sys + +import pytest + +from opentools.scanner.cancellation import CancellationToken +from opentools.scanner.executor.base import TaskOutput +from opentools.scanner.executor.proxied_shell import ProxiedShellExecutor +from opentools.scanner.infra.provider import CloudNodeProvider, ProvisioningError +from opentools.scanner.models import ScanTask, TaskIsolation, TaskType + + +class FakeProvider(CloudNodeProvider): + def __init__(self): + self.created = [] + self.destroyed = [] + self._poll_count = 0 + + async def create_node(self, region, ssh_public_key, tags): + self.created.append(region) + return "fake-node" + + async def poll_status(self, provider_id): + self._poll_count += 1 + return "active", "1.2.3.4" + + async def destroy_node(self, provider_id): + self.destroyed.append(provider_id) + + +def _make_task( + task_id: str = "t1", + command: str = f"{sys.executable} -c \"print('hello')\"", + isolation: TaskIsolation = TaskIsolation.NONE, + tool: str = "nuclei", +) -> ScanTask: + return ScanTask( + id=task_id, scan_id="scan-1", name="test", + tool=tool, task_type=TaskType.SHELL, + command=command, 
isolation=isolation, + ) + + +class TestProxiedShellExecutor: + @pytest.mark.asyncio + async def test_non_proxied_task_runs_directly(self): + """Tasks without NETWORK_ISOLATED run as plain subprocess.""" + executor = ProxiedShellExecutor(provider=None) + task = _make_task(isolation=TaskIsolation.NONE) + cancel = CancellationToken() + result = await executor.execute(task, lambda _: None, cancel) + assert result.exit_code == 0 + assert "hello" in result.stdout + + @pytest.mark.asyncio + async def test_no_provider_falls_through(self): + """NETWORK_ISOLATED task with no provider configured runs directly.""" + executor = ProxiedShellExecutor(provider=None) + task = _make_task(isolation=TaskIsolation.NETWORK_ISOLATED) + cancel = CancellationToken() + result = await executor.execute(task, lambda _: None, cancel) + assert result.exit_code == 0 + + @pytest.mark.asyncio + async def test_proxied_task_provisions_and_destroys(self): + """NETWORK_ISOLATED task with provider triggers full lifecycle.""" + provider = FakeProvider() + executor = ProxiedShellExecutor( + provider=provider, ssh_key="k", ssh_key_path="/dev/null", + _skip_tunnel=True, + ) + task = _make_task(isolation=TaskIsolation.NETWORK_ISOLATED) + cancel = CancellationToken() + result = await executor.execute(task, lambda _: None, cancel) + assert result.exit_code == 0 + assert len(provider.created) == 1 + assert len(provider.destroyed) == 1 + + @pytest.mark.asyncio + async def test_missing_command_raises(self): + executor = ProxiedShellExecutor(provider=None) + task = ScanTask( + id="t1", scan_id="s1", name="no-cmd", + tool="test", task_type=TaskType.SHELL, + ) + cancel = CancellationToken() + with pytest.raises(ValueError, match="no command"): + await executor.execute(task, lambda _: None, cancel) + + @pytest.mark.asyncio + async def test_socks_port_increments(self): + """Each proxied task gets a unique SOCKS port.""" + executor = ProxiedShellExecutor( + provider=FakeProvider(), ssh_key="k", + 
base_socks_port=10800, _skip_tunnel=True, + ) + assert executor._next_socks_port() == 10800 + assert executor._next_socks_port() == 10801 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_executor_proxied_shell.py::TestProxiedShellExecutor::test_non_proxied_task_runs_directly -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement ProxiedShellExecutor** + +```python +# packages/cli/src/opentools/scanner/executor/proxied_shell.py +"""ProxiedShellExecutor — optionally routes tool traffic through ephemeral proxy.""" + +from __future__ import annotations + +import shlex +from typing import Callable + +from opentools.scanner.cancellation import CancellationToken +from opentools.scanner.executor.base import TaskOutput +from opentools.scanner.infra.provider import CloudNodeProvider +from opentools.scanner.infra.proxy import ephemeral_proxy +from opentools.scanner.models import ScanTask, TaskIsolation +from opentools.shared.subprocess import run_streaming + + +class ProxiedShellExecutor: + """Shell executor with optional ephemeral proxy routing. + + Tasks with ``isolation == TaskIsolation.NETWORK_ISOLATED`` are routed + through an ephemeral SOCKS5 proxy. All other tasks run as plain + subprocesses, identical to ShellExecutor. 
+ """ + + def __init__( + self, + provider: CloudNodeProvider | None, + ssh_key: str = "", + ssh_key_path: str = "~/.ssh/id_ed25519", + default_timeout: int = 300, + base_socks_port: int = 10800, + _skip_tunnel: bool = False, + ) -> None: + self._provider = provider + self._ssh_key = ssh_key + self._ssh_key_path = ssh_key_path + self._default_timeout = default_timeout + self._base_socks_port = base_socks_port + self._port_counter = 0 + self._skip_tunnel = _skip_tunnel + + def _next_socks_port(self) -> int: + port = self._base_socks_port + self._port_counter + self._port_counter += 1 + return port + + async def execute( + self, + task: ScanTask, + on_output: Callable[[bytes], None], + cancellation: CancellationToken, + ) -> TaskOutput: + if task.command is None: + raise ValueError(f"Task {task.id} has no command") + + needs_proxy = ( + self._provider is not None + and task.isolation == TaskIsolation.NETWORK_ISOLATED + ) + + if not needs_proxy: + return await self._run_direct(task, on_output, cancellation) + + return await self._run_proxied(task, on_output, cancellation) + + async def _run_direct( + self, task: ScanTask, on_output: Callable[[bytes], None], + cancellation: CancellationToken, + ) -> TaskOutput: + args = shlex.split(task.command) + result = await run_streaming( + args=args, on_output=on_output, + timeout=self._default_timeout, cancellation=cancellation, + ) + return TaskOutput( + exit_code=result.exit_code, stdout=result.stdout, + stderr=result.stderr, duration_ms=result.duration_ms, + ) + + async def _run_proxied( + self, task: ScanTask, on_output: Callable[[bytes], None], + cancellation: CancellationToken, + ) -> TaskOutput: + socks_port = self._next_socks_port() + + async with ephemeral_proxy( + provider=self._provider, + region="nyc3", + ssh_key=self._ssh_key, + ssh_key_path=self._ssh_key_path, + local_socks_port=socks_port, + scan_id=task.scan_id, + _skip_tunnel=self._skip_tunnel, + ) as proxy: + args = shlex.split(task.command) + result = await 
run_streaming( + args=args, on_output=on_output, + timeout=self._default_timeout, cancellation=cancellation, + env=proxy.env, + ) + + return TaskOutput( + exit_code=result.exit_code, stdout=result.stdout, + stderr=result.stderr, duration_ms=result.duration_ms, + ) +``` + +- [ ] **Step 4: Update executor __init__.py exports** + +In `packages/cli/src/opentools/scanner/executor/__init__.py`, add: + +```python +from opentools.scanner.executor.proxied_shell import ProxiedShellExecutor +``` + +And add `"ProxiedShellExecutor"` to the `__all__` list. + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_executor_proxied_shell.py -v` +Expected: All PASS + +- [ ] **Step 6: Run all executor tests for regression check** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_executor_shell.py tests/test_scanner/test_executor_docker.py -v` +Expected: All PASS + +- [ ] **Step 7: Commit** + +```bash +git add packages/cli/src/opentools/scanner/executor/proxied_shell.py packages/cli/src/opentools/scanner/executor/__init__.py packages/cli/tests/test_scanner/test_executor_proxied_shell.py +git commit -m "feat(executor): add ProxiedShellExecutor with ephemeral proxy routing" +``` + +--- + +### Task 9: Orphan Sweeper + Final Integration + +**Files:** +- Create: `packages/cli/src/opentools/scanner/infra/sweeper.py` +- Test: `packages/cli/tests/test_scanner/test_infra_sweeper.py` (inline below) + +- [ ] **Step 1: Write failing tests for sweep_orphaned_nodes** + +```python +# packages/cli/tests/test_scanner/test_infra_sweeper.py +"""Tests for orphan node sweeper.""" + +import pytest + +from opentools.scanner.infra.sweeper import sweep_orphaned_nodes + + +class FakeListableProvider: + """Provider that also supports listing nodes by tag.""" + def __init__(self, orphan_ids: list[str]): + self._orphan_ids = orphan_ids + self.destroyed_ids: list[str] = [] + + async def list_nodes_by_tag(self, tag: str) -> list[str]: 
+ return list(self._orphan_ids) + + async def destroy_node(self, provider_id: str) -> None: + self.destroyed_ids.append(provider_id) + + +class TestSweepOrphanedNodes: + @pytest.mark.asyncio + async def test_destroys_orphans(self): + provider = FakeListableProvider(orphan_ids=["orphan-1", "orphan-2"]) + count = await sweep_orphaned_nodes(provider) + assert count == 2 + assert set(provider.destroyed_ids) == {"orphan-1", "orphan-2"} + + @pytest.mark.asyncio + async def test_no_orphans(self): + provider = FakeListableProvider(orphan_ids=[]) + count = await sweep_orphaned_nodes(provider) + assert count == 0 + assert provider.destroyed_ids == [] + + @pytest.mark.asyncio + async def test_destroy_failure_continues(self): + """One destroy failure should not stop cleanup of remaining nodes.""" + class PartialFailProvider(FakeListableProvider): + async def destroy_node(self, provider_id): + if provider_id == "orphan-1": + raise Exception("API error") + self.destroyed_ids.append(provider_id) + + provider = PartialFailProvider(orphan_ids=["orphan-1", "orphan-2"]) + count = await sweep_orphaned_nodes(provider) + # orphan-2 was destroyed, orphan-1 failed + assert count == 1 + assert "orphan-2" in provider.destroyed_ids +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_sweeper.py -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement sweep_orphaned_nodes** + +```python +# packages/cli/src/opentools/scanner/infra/sweeper.py +"""Startup sweeper for orphaned ephemeral proxy nodes. + +Call sweep_orphaned_nodes() during ScanEngine initialization to clean +up nodes from crashed previous runs. Uses the PROXY_TAG to identify +resources that belong to OpenTools. 
+""" + +from __future__ import annotations + +import logging + +from opentools.scanner.infra.proxy import PROXY_TAG + +logger = logging.getLogger(__name__) + + +async def sweep_orphaned_nodes(provider: object) -> int: + """Destroy any nodes tagged with PROXY_TAG from previous runs. + + The provider must implement ``list_nodes_by_tag(tag) -> list[str]`` + and ``destroy_node(provider_id)``. + + Returns the number of nodes successfully destroyed. + """ + if not hasattr(provider, "list_nodes_by_tag"): + logger.debug("Provider does not support list_nodes_by_tag, skipping sweep") + return 0 + + orphan_ids = await provider.list_nodes_by_tag(PROXY_TAG) + if not orphan_ids: + return 0 + + logger.info("Found %d orphaned proxy nodes to sweep", len(orphan_ids)) + destroyed = 0 + + for node_id in orphan_ids: + try: + await provider.destroy_node(node_id) + destroyed += 1 + logger.info("Destroyed orphaned node %s", node_id) + except Exception: + logger.exception("Failed to destroy orphaned node %s", node_id) + + return destroyed +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_sweeper.py -v` +Expected: All PASS + +- [ ] **Step 5: Run the full test suite for both phases** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_mutation_models.py tests/test_scanner/test_mutation_analyzer.py tests/test_scanner/test_mutation_strategy.py tests/test_scanner/test_engine_mutation.py tests/test_scanner/test_engine.py tests/test_scanner/test_shared_subprocess.py tests/test_scanner/test_infra_provider.py tests/test_scanner/test_infra_proxy.py tests/test_scanner/test_executor_proxied_shell.py tests/test_scanner/test_infra_sweeper.py -v` +Expected: All PASS + +- [ ] **Step 6: Commit** + +```bash +git add packages/cli/src/opentools/scanner/infra/sweeper.py packages/cli/tests/test_scanner/test_infra_sweeper.py +git commit -m "feat(infra): add orphan node sweeper for startup cleanup" +``` + +--- + +## 
Appendix: Concurrency Safety Proof + +The mutation layer introduces zero new concurrency hazards. Here is the execution trace through a single event loop turn: + +``` +await asyncio.wait(in_flight, FIRST_COMPLETED) # ← yield point +│ +├─ for completed_future in done: # synchronous +│ ├─ _mark_completed(task_id, output) # synchronous +│ │ ├─ analyzer.analyze(stdout, stderr) # synchronous (XML/JSON parse) +│ │ ├─ kill_chain.ingest(bundle) # synchronous (dict mutation) +│ │ ├─ strategy.evaluate(state, scan_id, task) # synchronous +│ │ ├─ _evaluate_edges(task, output) # synchronous (existing) +│ │ └─ _inject_tasks(new_tasks) # synchronous (dict mutation) +│ └─ (next completed_future) +│ +├─ ready_tasks_by_priority() # sees injected tasks +├─ dispatch → asyncio.ensure_future(...) # schedules, no yield +└─ await asyncio.wait(...) # ← next yield point +``` + +No `await` between state mutation and the next readiness check. The asyncio event loop cannot context-switch during this chain. diff --git a/docs/superpowers/plans/2026-04-13-hitl-approval-gate-vultr-provider.md b/docs/superpowers/plans/2026-04-13-hitl-approval-gate-vultr-provider.md new file mode 100644 index 0000000..fe7664c --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-hitl-approval-gate-vultr-provider.md @@ -0,0 +1,1384 @@ +# HITL Approval Gate + Vultr Provider Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a durable Human-In-The-Loop approval gate to the DAG execution engine and a Vultr cloud provider for ephemeral proxy routing. + +**Architecture:** The approval gate is an execution wrapper — not a separate TaskType — that inserts a sleep-until-approved phase into `_execute_task` before the real executor fires. 
Persistence is SQLite-first: the FastAPI route writes decisions to the database before signaling the in-memory `asyncio.Event`. The Vultr provider implements the existing `CloudNodeProvider` ABC with explicit SSH key injection for automated tunnel bootstrapping. + +**Tech Stack:** Python 3.12+, Pydantic v2, asyncio, httpx, aiosqlite, FastAPI, pytest + pytest-asyncio + +**Spec:** [2026-04-13-hitl-approval-gate-vultr-provider-design.md](../specs/2026-04-13-hitl-approval-gate-vultr-provider-design.md) + +**Depends on:** [2026-04-13-dag-mutation-ephemeral-proxy.md](2026-04-13-dag-mutation-ephemeral-proxy.md) (Phase A Tasks 1-4 and Phase B Task 6) + +--- + +## File Structure + +### Phase C — HITL Approval Gate + +| Action | Path | Responsibility | +|--------|------|----------------| +| Modify | `packages/cli/src/opentools/scanner/models.py:54-60,201-231` | Add `AWAITING_APPROVAL` to `TaskStatus`, `ApprovalRequirement` model, new fields on `ScanTask` | +| Create | `packages/cli/src/opentools/scanner/approval.py` | `ApprovalRegistry` singleton | +| Modify | `packages/cli/src/opentools/scanner/engine.py:215-251` | Insert gate phase into `_execute_task` | +| Modify | `packages/web/backend/app/models.py:252-284` | Add approval columns to `ScanTaskRecord` | +| Modify | `packages/web/backend/app/routes/scans.py` | Add gate list/approve/reject endpoints | +| Create | `packages/cli/tests/test_scanner/test_approval_registry.py` | Tests for ApprovalRegistry | +| Create | `packages/cli/tests/test_scanner/test_engine_approval.py` | Integration tests for gate in engine | +| Create | `packages/web/backend/tests/test_gate_routes.py` | Tests for FastAPI gate endpoints | + +### Task 6b — Vultr Provider + +| Action | Path | Responsibility | +|--------|------|----------------| +| Create | `packages/cli/src/opentools/scanner/infra/vultr.py` | `VultrProvider` implementation | +| Create | `packages/cli/tests/test_scanner/test_infra_vultr.py` | Tests with httpx MockTransport | + +--- + +## 
Phase C: HITL Approval Gate + +### Task 10: Model Additions — ApprovalRequirement, TaskStatus, ScanTask Fields + +**Files:** +- Modify: `packages/cli/src/opentools/scanner/models.py:54-60` (TaskStatus enum) +- Modify: `packages/cli/src/opentools/scanner/models.py:188-231` (new model + ScanTask fields) +- Test: `packages/cli/tests/test_scanner/test_mutation_models.py` (append) + +- [ ] **Step 1: Write failing tests for ApprovalRequirement and new ScanTask fields** + +```python +# Append to packages/cli/tests/test_scanner/test_mutation_models.py +# OR create packages/cli/tests/test_scanner/test_approval_models.py + +from opentools.scanner.models import ( + ApprovalRequirement, + ScanTask, + TaskStatus, + TaskType, +) + + +class TestApprovalRequirement: + def test_defaults(self): + req = ApprovalRequirement() + assert req.timeout_seconds == 3600 + assert req.description == "" + + def test_custom_values(self): + req = ApprovalRequirement( + timeout_seconds=600, + description="Deploy Sliver agent on 10.0.0.1", + ) + assert req.timeout_seconds == 600 + assert req.description == "Deploy Sliver agent on 10.0.0.1" + + +class TestTaskStatusApproval: + def test_awaiting_approval_value(self): + assert TaskStatus.AWAITING_APPROVAL == "awaiting_approval" + + def test_awaiting_approval_in_enum(self): + assert "awaiting_approval" in [s.value for s in TaskStatus] + + +class TestScanTaskApprovalFields: + def test_requires_approval_default_none(self): + task = ScanTask( + id="t1", scan_id="s1", name="test", + tool="nmap", task_type=TaskType.SHELL, + ) + assert task.requires_approval is None + assert task.approval_ticket_id is None + assert task.approval_expires_at is None + + def test_requires_approval_set(self): + req = ApprovalRequirement( + timeout_seconds=1800, + description="Dangerous action", + ) + task = ScanTask( + id="t1", scan_id="s1", name="test", + tool="c2", task_type=TaskType.SHELL, + requires_approval=req, + ) + assert task.requires_approval.timeout_seconds == 1800 + 
assert task.requires_approval.description == "Dangerous action" + + def test_approval_ticket_fields(self): + from datetime import datetime, timezone + task = ScanTask( + id="t1", scan_id="s1", name="test", + tool="nmap", task_type=TaskType.SHELL, + approval_ticket_id="gate-t1-abc123", + approval_expires_at=datetime(2026, 4, 13, 15, 0, 0, tzinfo=timezone.utc), + ) + assert task.approval_ticket_id == "gate-t1-abc123" + assert task.approval_expires_at.year == 2026 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_approval_models.py::TestApprovalRequirement::test_defaults -v` +Expected: FAIL with `ImportError: cannot import name 'ApprovalRequirement'` + +- [ ] **Step 3: Implement model additions** + +In `packages/cli/src/opentools/scanner/models.py`, add `AWAITING_APPROVAL` to `TaskStatus` (after line 60): + +```python +class TaskStatus(StrEnum): + PENDING = "pending" + BLOCKED = "blocked" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + AWAITING_APPROVAL = "awaiting_approval" +``` + +Add `ApprovalRequirement` model before `ReactiveEdge` (before line 188): + +```python +class ApprovalRequirement(BaseModel): + """Gate metadata for tasks requiring operator approval before execution.""" + timeout_seconds: int = 3600 + description: str = "" +``` + +Add new fields to `ScanTask` (after line 231, before the closing of the class): + +```python +class ScanTask(BaseModel): + # ... existing fields through completed_at ... 
+ requires_approval: Optional[ApprovalRequirement] = None + approval_ticket_id: Optional[str] = None + approval_expires_at: Optional[datetime] = None +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_approval_models.py -v` +Expected: All PASS + +- [ ] **Step 5: Run existing model tests for regression check** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_models.py tests/test_scanner/test_engine.py -v` +Expected: All PASS + +- [ ] **Step 6: Commit** + +```bash +git add packages/cli/src/opentools/scanner/models.py packages/cli/tests/test_scanner/test_approval_models.py +git commit -m "feat(models): add ApprovalRequirement, AWAITING_APPROVAL status, approval fields on ScanTask" +``` + +--- + +### Task 11: ApprovalRegistry — In-Memory Notification Hub + +**Files:** +- Create: `packages/cli/src/opentools/scanner/approval.py` +- Test: `packages/cli/tests/test_scanner/test_approval_registry.py` + +- [ ] **Step 1: Write failing tests for ApprovalRegistry** + +```python +# packages/cli/tests/test_scanner/test_approval_registry.py +"""Tests for ApprovalRegistry — in-memory notification hub.""" + +import asyncio + +import pytest + +from opentools.scanner.approval import ApprovalRegistry + + +class TestApprovalRegistry: + def test_register_returns_event(self): + registry = ApprovalRegistry() + event = registry.register("ticket-1") + assert isinstance(event, asyncio.Event) + assert not event.is_set() + + def test_signal_sets_event(self): + registry = ApprovalRegistry() + event = registry.register("ticket-1") + result = registry.signal("ticket-1") + assert result is True + assert event.is_set() + + def test_signal_missing_returns_false(self): + registry = ApprovalRegistry() + result = registry.signal("nonexistent") + assert result is False + + def test_remove_cleans_up(self): + registry = ApprovalRegistry() + registry.register("ticket-1") + registry.remove("ticket-1") + assert 
registry.signal("ticket-1") is False + + def test_remove_missing_does_not_raise(self): + registry = ApprovalRegistry() + registry.remove("nonexistent") # should not raise + + def test_has_ticket(self): + registry = ApprovalRegistry() + assert registry.has_ticket("ticket-1") is False + registry.register("ticket-1") + assert registry.has_ticket("ticket-1") is True + + def test_pending_tickets(self): + registry = ApprovalRegistry() + registry.register("ticket-1") + registry.register("ticket-2") + assert registry.pending_ticket_ids() == {"ticket-1", "ticket-2"} + + @pytest.mark.asyncio + async def test_event_wakes_awaiter(self): + """Verify the full async flow: register → await → signal → wake.""" + registry = ApprovalRegistry() + event = registry.register("ticket-1") + woke = False + + async def waiter(): + nonlocal woke + await asyncio.wait_for(event.wait(), timeout=5.0) + woke = True + + task = asyncio.ensure_future(waiter()) + await asyncio.sleep(0.05) + assert not woke + + registry.signal("ticket-1") + await task + assert woke + + @pytest.mark.asyncio + async def test_multiple_gates_independent(self): + """Two gates: signaling one does not wake the other.""" + registry = ApprovalRegistry() + event_a = registry.register("a") + event_b = registry.register("b") + + registry.signal("a") + assert event_a.is_set() + assert not event_b.is_set() +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_approval_registry.py::TestApprovalRegistry::test_register_returns_event -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement ApprovalRegistry** + +```python +# packages/cli/src/opentools/scanner/approval.py +"""ApprovalRegistry — in-memory notification hub for HITL approval gates. + +NOT the source of truth. The database is always authoritative for gate +decisions. This registry only provides asyncio.Event handles so that +FastAPI routes can wake sleeping gate coroutines. 
+ +If a ticket is missing from the registry (e.g., server restarted before +the engine reconstructed it), the route still writes the decision to +SQLite — the executor will pick it up on its next DB read. +""" + +from __future__ import annotations + +import asyncio + + +class ApprovalRegistry: + """In-memory notification hub mapping ticket IDs to asyncio.Events.""" + + def __init__(self) -> None: + self._events: dict[str, asyncio.Event] = {} + + def register(self, ticket_id: str) -> asyncio.Event: + """Create and store an event for a gate ticket. Returns the event.""" + event = asyncio.Event() + self._events[ticket_id] = event + return event + + def signal(self, ticket_id: str) -> bool: + """Signal the event if it exists. Returns False if not in registry.""" + event = self._events.get(ticket_id) + if event is None: + return False + event.set() + return True + + def remove(self, ticket_id: str) -> None: + """Clean up after a gate resolves.""" + self._events.pop(ticket_id, None) + + def has_ticket(self, ticket_id: str) -> bool: + """Check if a ticket is registered.""" + return ticket_id in self._events + + def pending_ticket_ids(self) -> set[str]: + """Return all registered ticket IDs.""" + return set(self._events.keys()) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_approval_registry.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/scanner/approval.py packages/cli/tests/test_scanner/test_approval_registry.py +git commit -m "feat(approval): add ApprovalRegistry in-memory notification hub" +``` + +--- + +### Task 12: Engine Integration — Gate Phase in `_execute_task` + +**Files:** +- Modify: `packages/cli/src/opentools/scanner/engine.py:35-69` (`__init__` — add registry + store) +- Modify: `packages/cli/src/opentools/scanner/engine.py:215-251` (`_execute_task` — insert gate phase) +- Test: 
`packages/cli/tests/test_scanner/test_engine_approval.py`
+
+**Context:** The gate phase inserts between the cache check and resource acquisition, so a sleeping gate never holds a real executor slot. A gated task persists `AWAITING_APPROVAL` + `expires_at` to the store, registers an event, sleeps, then reads the decision from the store on wake. The resource pool uses an `approval_gate` group with effectively unlimited slots so sleeping gates don't starve real workers.
+
+- [ ] **Step 1: Write failing integration tests**
+
+```python
+# packages/cli/tests/test_scanner/test_engine_approval.py
+"""Integration tests: ScanEngine + HITL approval gate."""
+
+import asyncio
+from datetime import datetime, timezone
+from typing import Any, Callable
+
+import pytest
+
+from opentools.scanner.approval import ApprovalRegistry
+from opentools.scanner.cancellation import CancellationToken
+from opentools.scanner.engine import ScanEngine
+from opentools.scanner.executor.base import TaskOutput
+from opentools.scanner.models import (
+    ApprovalRequirement,
+    Scan,
+    ScanStatus,
+    ScanTask,
+    TaskStatus,
+    TaskType,
+    TargetType,
+)
+from opentools.shared.progress import EventBus
+from opentools.shared.resource_pool import AdaptiveResourcePool
+
+
+class MockExecutor:
+    def __init__(self, results: dict[str, TaskOutput] | None = None):
+        self._results = results or {}
+        self._default = TaskOutput(exit_code=0, stdout="ok", duration_ms=10)
+        self.executed: list[str] = []
+
+    async def execute(
+        self, task: ScanTask, on_output: Callable[[bytes], None],
+        cancellation: CancellationToken,
+    ) -> TaskOutput:
+        self.executed.append(task.id)
+        return self._results.get(task.id, self._default)
+
+
+class FakeStore:
+    """Minimal in-memory store for approval gate testing."""
+
+    def __init__(self) -> None:
+        self._task_statuses: dict[str, dict[str, Any]] = {}
+
+    async def update_task_status(self, task_id: str, status: str, **fields) -> None:
+        self._task_statuses[task_id] = {"status": status, **fields}
+
+    async def get_task_status(self, task_id: str) ->
dict[str, Any] | None:
+        return self._task_statuses.get(task_id)
+
+
+def _make_scan() -> Scan:
+    return Scan(
+        id="scan-1", engagement_id="eng-1", target="10.0.0.1",
+        target_type=TargetType.NETWORK, status=ScanStatus.PENDING,
+        created_at=datetime.now(timezone.utc),
+    )
+
+
+def _make_gated_task(
+    task_id: str = "gated-task",
+    timeout_seconds: int = 3600,
+    description: str = "Dangerous action",
+) -> ScanTask:
+    return ScanTask(
+        id=task_id, scan_id="scan-1", name="gated",
+        tool="c2-agent", task_type=TaskType.SHELL,
+        command="echo approved",
+        requires_approval=ApprovalRequirement(
+            timeout_seconds=timeout_seconds,
+            description=description,
+        ),
+    )
+
+
+def _make_engine(
+    tasks: list[ScanTask],
+    executor: MockExecutor,
+    store: FakeStore | None = None,
+    registry: ApprovalRegistry | None = None,
+) -> ScanEngine:
+    pool = AdaptiveResourcePool(
+        global_limit=4,
+        group_limits={"approval_gate": 9999},
+    )
+    executors = {
+        TaskType.SHELL: executor,
+        TaskType.DOCKER_EXEC: executor,
+        TaskType.MCP_CALL: executor,
+    }
+    engine = ScanEngine(
+        scan=_make_scan(),
+        resource_pool=pool,
+        executors=executors,
+        event_bus=EventBus(),
+        cancellation=CancellationToken(),
+    )
+
+    if registry is not None:
+        engine.set_approval_registry(registry)
+    if store is not None:
+        engine.set_approval_store(store)
+
+    engine.load_tasks(tasks)
+    return engine
+
+
+class TestEngineApprovalGate:
+    @pytest.mark.asyncio
+    async def test_approved_task_executes(self):
+        """Gate approved → real executor fires → task completes."""
+        executor = MockExecutor()
+        store = FakeStore()
+        registry = ApprovalRegistry()
+        task = _make_gated_task(timeout_seconds=5)
+        engine = _make_engine([task], executor, store, registry)
+
+        async def approve_soon():
+            # Wait for the gate to register, then approve via store + signal
+            for _ in range(50):
+                await asyncio.sleep(0.05)
+                if registry.pending_ticket_ids():
+                    break
+            tickets = registry.pending_ticket_ids()
+            assert len(tickets) == 1
+            ticket_id = tickets.pop()
+            # Write-before-signal: persist to store first
+            await store.update_task_status(
+                task.id, "approved",
+                approval_ticket_id=ticket_id,
+            )
+            registry.signal(ticket_id)
+
+        asyncio.ensure_future(approve_soon())
+        await engine.run()
+
+        assert "gated-task" in executor.executed
+        assert engine._tasks["gated-task"].status == TaskStatus.COMPLETED
+
+    @pytest.mark.asyncio
+    async def test_rejected_task_fails(self):
+        """Gate rejected → executor never fires → task fails."""
+        executor = MockExecutor()
+        store = FakeStore()
+        registry = ApprovalRegistry()
+        task = _make_gated_task(timeout_seconds=5)
+        engine = _make_engine([task], executor, store, registry)
+
+        async def reject_soon():
+            for _ in range(50):
+                await asyncio.sleep(0.05)
+                if registry.pending_ticket_ids():
+                    break
+            ticket_id = registry.pending_ticket_ids().pop()
+            await store.update_task_status(
+                task.id, "rejected",
+                approval_ticket_id=ticket_id,
+            )
+            registry.signal(ticket_id)
+
+        asyncio.ensure_future(reject_soon())
+        await engine.run()
+
+        assert "gated-task" not in executor.executed
+        assert engine._tasks["gated-task"].status == TaskStatus.FAILED
+
+    @pytest.mark.asyncio
+    async def test_timeout_rejects_automatically(self):
+        """Gate times out → executor never fires → task fails."""
+        executor = MockExecutor()
+        store = FakeStore()
+        registry = ApprovalRegistry()
+        task = _make_gated_task(timeout_seconds=1)  # 1 second timeout
+        engine = _make_engine([task], executor, store, registry)
+
+        await engine.run()
+
+        assert "gated-task" not in executor.executed
+        assert engine._tasks["gated-task"].status == TaskStatus.FAILED
+        # Verify store was updated with expiry
+        stored = await store.get_task_status("gated-task")
+        assert stored is not None
+        assert stored["status"] == "approval_expired"
+
+    @pytest.mark.asyncio
+    async def test_non_gated_task_unaffected(self):
+        """Tasks without requires_approval execute normally."""
+ executor = MockExecutor() + task = ScanTask( + id="normal", scan_id="scan-1", name="normal", + tool="nmap", task_type=TaskType.SHELL, command="echo normal", + ) + engine = _make_engine([task], executor) + await engine.run() + + assert "normal" in executor.executed + assert engine._tasks["normal"].status == TaskStatus.COMPLETED + + @pytest.mark.asyncio + async def test_parallel_branches_continue_during_gate(self): + """Safe branch executes while gated branch is sleeping.""" + executor = MockExecutor() + store = FakeStore() + registry = ApprovalRegistry() + + safe_task = ScanTask( + id="safe", scan_id="scan-1", name="safe", + tool="nmap", task_type=TaskType.SHELL, command="echo safe", + ) + gated_task = _make_gated_task(timeout_seconds=1) + + engine = _make_engine([safe_task, gated_task], executor, store, registry) + await engine.run() + + # Safe branch ran immediately; gated branch timed out + assert "safe" in executor.executed + assert "gated-task" not in executor.executed + assert engine._tasks["safe"].status == TaskStatus.COMPLETED + assert engine._tasks["gated-task"].status == TaskStatus.FAILED + + @pytest.mark.asyncio + async def test_no_gate_without_registry(self): + """If no registry set, gated tasks execute immediately (no gate).""" + executor = MockExecutor() + task = _make_gated_task(timeout_seconds=5) + engine = _make_engine([task], executor) # no registry, no store + await engine.run() + + assert "gated-task" in executor.executed + assert engine._tasks["gated-task"].status == TaskStatus.COMPLETED +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine_approval.py::TestEngineApprovalGate::test_non_gated_task_unaffected -v` +Expected: FAIL with `AttributeError: 'ScanEngine' object has no attribute 'set_approval_registry'` + +- [ ] **Step 3: Add approval fields to ScanEngine.__init__** + +In `packages/cli/src/opentools/scanner/engine.py`, add after the existing `_pipeline_results` 
field (around line 69): + +```python + # Approval gate (optional — engine works without it) + self._approval_registry: Any | None = None + self._approval_store: Any | None = None +``` + +Add public setters after the existing `set_cache` method: + +```python + def set_approval_registry(self, registry: Any) -> None: + """Set the ApprovalRegistry for HITL gate support.""" + self._approval_registry = registry + + def set_approval_store(self, store: Any) -> None: + """Set the store for persisting gate state.""" + self._approval_store = store +``` + +- [ ] **Step 4: Modify _execute_task — insert gate phase** + +Replace the existing `_execute_task` method (lines 215-251) with: + +```python + async def _execute_task( + self, task: ScanTask, executor: TaskExecutor + ) -> TaskOutput: + """Check cache → gate phase → acquire resource → dispatch to executor → release.""" + # Cache check + if task.cache_key and task.cache_key in self._cache: + return self._cache[task.cache_key] + + # --- GATE PHASE (if task requires approval and registry is configured) --- + if ( + task.requires_approval is not None + and self._approval_registry is not None + and self._approval_store is not None + ): + gate_result = await self._run_approval_gate(task) + if gate_result is not None: + return gate_result # rejected or expired + + # --- NORMAL EXECUTION --- + resource_group = task.resource_group or task.task_type.value + + if task.retry_policy is not None: + from opentools.shared.retry import execute_with_retry + + async def _attempt() -> TaskOutput: + await self._pool.acquire(task.id, task.priority, resource_group) + try: + return await executor.execute( + task, lambda _chunk: None, self._cancellation + ) + finally: + self._pool.release(resource_group) + + output = await execute_with_retry(_attempt, task.retry_policy) + else: + await self._pool.acquire(task.id, task.priority, resource_group) + try: + output = await executor.execute( + task, lambda _chunk: None, self._cancellation + ) + finally: + 
self._pool.release(resource_group) + + # Populate cache on success + if task.cache_key and output.exit_code == 0: + self._cache[task.cache_key] = output.model_copy(update={"cached": True}) + + return output + + async def _run_approval_gate(self, task: ScanTask) -> TaskOutput | None: + """Execute the approval gate phase. Returns None if approved (proceed), + or a TaskOutput if rejected/expired (stop).""" + import uuid + from datetime import datetime, timedelta, timezone + + timeout = task.requires_approval.timeout_seconds + ticket_id = f"gate-{task.id}-{uuid.uuid4().hex[:8]}" + expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout) + + # 1. Persist to store (source of truth) + task.approval_ticket_id = ticket_id + task.approval_expires_at = expires_at + await self._approval_store.update_task_status( + task.id, TaskStatus.AWAITING_APPROVAL.value, + approval_ticket_id=ticket_id, + approval_expires_at=expires_at.isoformat(), + ) + + # 2. Publish SSE event (if event bus supports it) + # The event bus is best-effort; failure here doesn't block the gate. + try: + self._event_bus.publish({ + "type": "approval_required", + "scan_id": self.scan.id, + "task_id": task.id, + "ticket_id": ticket_id, + "tool": task.tool, + "command": task.command, + "description": task.requires_approval.description, + "expires_at": expires_at.isoformat(), + }) + except Exception: + pass + + # 3. 
Register event and sleep + event = self._approval_registry.register(ticket_id) + + remaining = (expires_at - datetime.now(timezone.utc)).total_seconds() + if remaining <= 0: + self._approval_registry.remove(ticket_id) + await self._approval_store.update_task_status( + task.id, "approval_expired", + ) + return TaskOutput(exit_code=2, stderr="approval expired before gate could sleep") + + # Acquire from unlimited approval_gate group while sleeping + await self._pool.acquire(task.id, task.priority, "approval_gate") + try: + try: + await asyncio.wait_for(event.wait(), timeout=remaining) + except asyncio.TimeoutError: + pass # handled by DB read below + finally: + self._pool.release("approval_gate") + self._approval_registry.remove(ticket_id) + + # 4. READ TRUTH from store (never trust why we woke up) + stored = await self._approval_store.get_task_status(task.id) + if stored is not None: + status = stored.get("status", "") + else: + status = "" + + if status == "approved": + # Gate passed — update task status back to RUNNING and return None + # to signal _execute_task to proceed to the real executor + task.status = TaskStatus.RUNNING + return None + + if status == "rejected": + return TaskOutput( + exit_code=1, + stderr="rejected by operator", + ) + + # Timeout or unknown state + await self._approval_store.update_task_status( + task.id, "approval_expired", + ) + return TaskOutput( + exit_code=2, + stderr="approval expired", + ) +``` + +- [ ] **Step 5: Run integration tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine_approval.py -v` +Expected: All PASS + +- [ ] **Step 6: Run existing engine tests for regression check** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_engine.py -v` +Expected: All PASS — no regressions. Tasks without `requires_approval` are unaffected. 
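The gate's core discipline — wake on signal or timeout, then let the persisted decision decide — can be exercised in isolation. A minimal sketch with a plain dict standing in for the SQLite store (all names here are illustrative, not part of the codebase):

```python
import asyncio


async def gate(event: asyncio.Event, store: dict, task_id: str, timeout: float) -> str:
    """Sleep until signaled or timed out, then read the authoritative store."""
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        pass  # ignored on purpose: the store read below decides, not the wake reason
    return store.get(task_id, "approval_expired")


async def main() -> str:
    store: dict[str, str] = {}
    event = asyncio.Event()
    waiter = asyncio.create_task(gate(event, store, "t1", timeout=5.0))
    await asyncio.sleep(0)     # let the gate start waiting
    store["t1"] = "approved"   # write-before-signal: persist the decision first...
    event.set()                # ...then fire the in-memory tripwire
    return await waiter


result = asyncio.run(main())
print(result)  # approved
```

Because the decision is read from the store after the wake, a crash between write and signal loses nothing: the decision is already durable, and a timeout-driven wake reaches the same read.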
+ +- [ ] **Step 7: Commit** + +```bash +git add packages/cli/src/opentools/scanner/engine.py packages/cli/tests/test_scanner/test_engine_approval.py +git commit -m "feat(engine): integrate HITL approval gate phase into _execute_task" +``` + +--- + +### Task 13: FastAPI Gate Endpoints — List, Approve, Reject + +**Files:** +- Modify: `packages/web/backend/app/models.py:252-284` (add columns to ScanTaskRecord) +- Modify: `packages/web/backend/app/routes/scans.py` (add gate endpoints after cancel endpoint) +- Test: `packages/web/backend/tests/test_gate_routes.py` + +**Context:** The approve/reject routes enforce **write-before-signal**: they update the database first, then signal the in-memory event. The `ApprovalRegistry` singleton is stored in `_active_scans[scan_id]` alongside the engine reference, following the existing pattern at [api.py:199](packages/cli/src/opentools/scanner/api.py#L199). + +- [ ] **Step 1: Add approval columns to ScanTaskRecord** + +In `packages/web/backend/app/models.py`, add after line 284 (after `completed_at`): + +```python +class ScanTaskRecord(SQLModel, table=True): + # ... existing fields ... 
+ # Approval gate fields + approval_ticket_id: Optional[str] = None + approval_expires_at: Optional[datetime] = Field(default=None, **_TZ_KW) +``` + +- [ ] **Step 2: Add gate request/response models and endpoints to scans.py** + +Append to `packages/web/backend/app/routes/scans.py`, after the cancel endpoint (after line 380): + +```python +# --------------------------------------------------------------------------- +# Approval gate endpoints +# --------------------------------------------------------------------------- + + +class GateResponse(BaseModel): + ticket_id: str + task_id: str + tool: str + command: Optional[str] = None + description: str + status: str + expires_at: Optional[str] = None + + +class GateDecisionResponse(BaseModel): + ticket_id: str + decision: str + + +class GateRejectRequest(BaseModel): + reason: str = "operator rejected" + + +@router.get("/{scan_id}/gates") +async def list_pending_gates( + scan_id: str, + session: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +): + """List tasks awaiting operator approval for a scan.""" + svc = ScanService(session, user) + scan = await svc.get_scan(scan_id) + if scan is None: + raise HTTPException(status_code=404, detail="Scan not found") + + tasks = await svc.get_scan_tasks(scan_id) + gates = [] + for t in tasks: + if t.status == "awaiting_approval" and t.approval_ticket_id: + gates.append(GateResponse( + ticket_id=t.approval_ticket_id, + task_id=t.id, + tool=t.tool, + command=t.command, + description="", # extracted from requires_approval JSON if available + status=t.status, + expires_at=t.approval_expires_at.isoformat() if t.approval_expires_at else None, + )) + return {"scan_id": scan_id, "gates": gates} + + +@router.post("/{scan_id}/gates/{ticket_id}/approve") +async def approve_gate( + scan_id: str, + ticket_id: str, + session: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +): + """Approve a pending approval gate. 
Write-before-signal.""" + svc = ScanService(session, user) + + # 1. Validate + task_record = await svc.get_task_by_ticket(scan_id, ticket_id) + if task_record is None: + raise HTTPException(status_code=404, detail="Gate ticket not found") + if task_record.status != "awaiting_approval": + raise HTTPException( + status_code=409, + detail=f"Gate already resolved: {task_record.status}", + ) + + # 2. PERSIST FIRST (source of truth) + await svc.update_task_approval_status(task_record.id, "approved") + await session.commit() + + # 3. Signal event (best-effort tripwire) + from opentools.scanner.api import _active_scans + entry = _active_scans.get(scan_id, {}) + registry = entry.get("approval_registry") + if registry is not None: + registry.signal(ticket_id) + + return GateDecisionResponse(ticket_id=ticket_id, decision="approved") + + +@router.post("/{scan_id}/gates/{ticket_id}/reject") +async def reject_gate( + scan_id: str, + ticket_id: str, + body: GateRejectRequest = GateRejectRequest(), + session: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +): + """Reject a pending approval gate. Write-before-signal.""" + svc = ScanService(session, user) + + # 1. Validate + task_record = await svc.get_task_by_ticket(scan_id, ticket_id) + if task_record is None: + raise HTTPException(status_code=404, detail="Gate ticket not found") + if task_record.status != "awaiting_approval": + raise HTTPException( + status_code=409, + detail=f"Gate already resolved: {task_record.status}", + ) + + # 2. PERSIST FIRST + await svc.update_task_approval_status(task_record.id, "rejected") + await session.commit() + + # 3. 
Signal event + from opentools.scanner.api import _active_scans + entry = _active_scans.get(scan_id, {}) + registry = entry.get("approval_registry") + if registry is not None: + registry.signal(ticket_id) + + return GateDecisionResponse(ticket_id=ticket_id, decision="rejected") +``` + +- [ ] **Step 3: Add helper methods to ScanService** + +These are the two new methods the gate routes need. Add to `packages/web/backend/app/services/scan_service.py`: + +```python + async def get_task_by_ticket( + self, scan_id: str, ticket_id: str + ) -> ScanTaskRecord | None: + """Find a task by its approval ticket ID within a scan.""" + from sqlalchemy import select + stmt = ( + select(ScanTaskRecord) + .where(ScanTaskRecord.scan_id == scan_id) + .where(ScanTaskRecord.approval_ticket_id == ticket_id) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def update_task_approval_status( + self, task_id: str, status: str + ) -> None: + """Update a task's status (for gate approval/rejection).""" + from sqlalchemy import update + stmt = ( + update(ScanTaskRecord) + .where(ScanTaskRecord.id == task_id) + .values(status=status) + ) + await self.session.execute(stmt) +``` + +- [ ] **Step 4: Write route tests** + +```python +# packages/web/backend/tests/test_gate_routes.py +"""Tests for approval gate API routes.""" + +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_list_gates_empty(client: AsyncClient, auth_headers: dict): + """No pending gates returns empty list.""" + # This test depends on the conftest fixtures creating a scan + # Minimal smoke test — full integration requires seeded data + resp = await client.get( + "/api/v1/scans/nonexistent/gates", headers=auth_headers, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_approve_missing_ticket(client: AsyncClient, auth_headers: dict): + resp = await client.post( + "/api/v1/scans/scan-1/gates/nonexistent/approve", + 
headers=auth_headers, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_reject_missing_ticket(client: AsyncClient, auth_headers: dict): + resp = await client.post( + "/api/v1/scans/scan-1/gates/nonexistent/reject", + headers=auth_headers, + json={"reason": "test"}, + ) + assert resp.status_code == 404 +``` + +- [ ] **Step 5: Run tests** + +Run: `cd packages/web/backend && python -m pytest tests/test_gate_routes.py -v` +Expected: PASS (404 cases) + +- [ ] **Step 6: Commit** + +```bash +git add packages/web/backend/app/models.py packages/web/backend/app/routes/scans.py packages/web/backend/app/services/scan_service.py packages/web/backend/tests/test_gate_routes.py +git commit -m "feat(api): add HITL gate endpoints — list, approve, reject with write-before-signal" +``` + +--- + +## Task 6b: Vultr Provider + +### Task 6b: VultrProvider Implementation + +**Files:** +- Create: `packages/cli/src/opentools/scanner/infra/vultr.py` +- Test: `packages/cli/tests/test_scanner/test_infra_vultr.py` + +**Context:** Implements `CloudNodeProvider` ABC (defined in Phase B Task 6). Uses `httpx.AsyncClient` with `MockTransport` for testing. The `sshkey_id` array is a hard requirement in the creation payload — without it, the instance boots with no authorized keys and the SSH tunnel fails. 
+ +- [ ] **Step 1: Write failing tests for VultrProvider** + +```python +# packages/cli/tests/test_scanner/test_infra_vultr.py +"""Tests for VultrProvider — ephemeral Vultr instance provisioning.""" + +import json + +import httpx +import pytest + +from opentools.scanner.infra.provider import ( + ProvisioningError, + ProvisioningTimeout, +) +from opentools.scanner.infra.vultr import VultrProvider + + +class TestVultrProviderCreate: + @pytest.mark.asyncio + async def test_create_sends_correct_payload(self): + captured = {} + + async def handler(request: httpx.Request) -> httpx.Response: + captured["body"] = json.loads(request.content) + captured["url"] = str(request.url) + return httpx.Response(202, json={ + "instance": {"id": "vtr-abc123"}, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient( + transport=transport, + base_url="https://api.vultr.com/v2", + ) + provider = VultrProvider(client=client) + + result = await provider.create_node( + region="ewr", + ssh_public_key="ssh-key-uuid-1234", + tags=["opentools-ephemeral-proxy", "scan:scan-1"], + ) + + assert result == "vtr-abc123" + body = captured["body"] + assert body["region"] == "ewr" + assert body["plan"] == "vc2-1c-0.5gb" + assert body["os_id"] == 2284 + assert body["sshkey_id"] == ["ssh-key-uuid-1234"] + assert "opentools-ephemeral-proxy" in body["tags"] + assert body["backups"] == "disabled" + assert body["activation_email"] is False + + @pytest.mark.asyncio + async def test_create_returns_instance_id(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(202, json={ + "instance": {"id": "vtr-def456"}, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + result = await provider.create_node("lax", "key-id", ["tag"]) + assert result == "vtr-def456" + + +class TestVultrProviderPoll: + @pytest.mark.asyncio + async def 
test_active_with_ip(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", + "status": "active", + "power_status": "running", + "main_ip": "149.28.1.1", + }, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + status, ip = await provider.poll_status("vtr-abc") + assert status == "active" + assert ip == "149.28.1.1" + + @pytest.mark.asyncio + async def test_active_with_zero_ip_is_creating(self): + """Vultr reports active before IP is assigned — treat as creating.""" + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", + "status": "active", + "power_status": "running", + "main_ip": "0.0.0.0", + }, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + status, ip = await provider.poll_status("vtr-abc") + assert status == "creating" + assert ip is None + + @pytest.mark.asyncio + async def test_pending_status(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", + "status": "pending", + "power_status": "stopped", + "main_ip": "0.0.0.0", + }, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + status, ip = await provider.poll_status("vtr-abc") + assert status == "creating" + assert ip is None + + +class TestVultrProviderDestroy: + @pytest.mark.asyncio + async def test_destroy_204(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(204) + + transport = httpx.MockTransport(handler) + client = 
httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + await provider.destroy_node("vtr-abc") # should not raise + + @pytest.mark.asyncio + async def test_destroy_404_idempotent(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + await provider.destroy_node("vtr-abc") # should not raise + + +class TestVultrProviderListByTag: + @pytest.mark.asyncio + async def test_list_nodes_by_tag(self): + async def handler(request: httpx.Request) -> httpx.Response: + assert "tag=opentools-ephemeral-proxy" in str(request.url) + return httpx.Response(200, json={ + "instances": [ + {"id": "vtr-1"}, + {"id": "vtr-2"}, + ], + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + ids = await provider.list_nodes_by_tag("opentools-ephemeral-proxy") + assert ids == ["vtr-1", "vtr-2"] + + @pytest.mark.asyncio + async def test_list_empty(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"instances": []}) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + ids = await provider.list_nodes_by_tag("opentools-ephemeral-proxy") + assert ids == [] + + +class TestVultrProviderWaitUntilReady: + @pytest.mark.asyncio + async def test_waits_for_ip_assignment(self): + call_count = 0 + + async def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count < 3: + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", "status": "active", + "power_status": 
"running", "main_ip": "0.0.0.0", + }, + }) + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", "status": "active", + "power_status": "running", "main_ip": "149.28.1.1", + }, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + ip = await provider.wait_until_ready("vtr-abc", poll_interval=0.01, max_polls=10) + assert ip == "149.28.1.1" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_timeout_raises(self): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={ + "instance": { + "id": "vtr-abc", "status": "pending", + "power_status": "stopped", "main_ip": "0.0.0.0", + }, + }) + + transport = httpx.MockTransport(handler) + client = httpx.AsyncClient(transport=transport, base_url="https://api.vultr.com/v2") + provider = VultrProvider(client=client) + + with pytest.raises(ProvisioningTimeout): + await provider.wait_until_ready("vtr-abc", poll_interval=0.01, max_polls=3) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_vultr.py::TestVultrProviderCreate::test_create_sends_correct_payload -v` +Expected: FAIL with `ModuleNotFoundError` + +- [ ] **Step 3: Implement VultrProvider** + +```python +# packages/cli/src/opentools/scanner/infra/vultr.py +"""VultrProvider — ephemeral Vultr instance provisioning via REST API. + +Implements CloudNodeProvider ABC. Uses httpx.AsyncClient for all HTTP +calls — fully async, never blocks the event loop. + +IMPORTANT: The sshkey_id array is REQUIRED in the creation payload. +Without it, the instance boots with no authorized SSH keys and the +automated tunnel establishment will fail with Permission denied. 
+""" + +from __future__ import annotations + +import uuid + +import httpx + +from opentools.scanner.infra.provider import CloudNodeProvider + + +class VultrProvider(CloudNodeProvider): + """Provision ephemeral instances via the Vultr API.""" + + def __init__(self, client: httpx.AsyncClient) -> None: + self._client = client + + @classmethod + def from_token(cls, api_token: str) -> VultrProvider: + """Create a provider with a new httpx client using the given API token.""" + client = httpx.AsyncClient( + base_url="https://api.vultr.com/v2", + headers={"Authorization": f"Bearer {api_token}"}, + timeout=30.0, + ) + return cls(client=client) + + async def create_node( + self, region: str, ssh_public_key: str, tags: list[str], + ) -> str: + """Create a Vultr instance. ssh_public_key must be a Vultr SSH key UUID.""" + resp = await self._client.post("/instances", json={ + "region": region, + "plan": "vc2-1c-0.5gb", + "os_id": 2284, # Ubuntu 24.04 LTS + "label": f"ot-proxy-{uuid.uuid4().hex[:8]}", + "sshkey_id": [ssh_public_key], + "tags": tags, + "backups": "disabled", + "activation_email": False, + }) + resp.raise_for_status() + return resp.json()["instance"]["id"] + + async def poll_status(self, provider_id: str) -> tuple[str, str | None]: + """Poll instance status. Returns ("active", ip) only when IP is assigned.""" + resp = await self._client.get(f"/instances/{provider_id}") + resp.raise_for_status() + instance = resp.json()["instance"] + + vultr_status = instance.get("status", "") + power_status = instance.get("power_status", "") + main_ip = instance.get("main_ip", "0.0.0.0") + + # Vultr can report "active" before IP is assigned + if ( + vultr_status == "active" + and power_status == "running" + and main_ip != "0.0.0.0" + ): + return "active", main_ip + + return "creating", None + + async def destroy_node(self, provider_id: str) -> None: + """Destroy a Vultr instance. 
Idempotent — 404 is not an error.""" + resp = await self._client.delete(f"/instances/{provider_id}") + if resp.status_code not in (204, 404): + resp.raise_for_status() + + async def list_nodes_by_tag(self, tag: str) -> list[str]: + """List all instance IDs with the given tag (for orphan sweeping).""" + resp = await self._client.get("/instances", params={"tag": tag}) + resp.raise_for_status() + instances = resp.json().get("instances", []) + return [inst["id"] for inst in instances] +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/test_scanner/test_infra_vultr.py -v` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/scanner/infra/vultr.py packages/cli/tests/test_scanner/test_infra_vultr.py +git commit -m "feat(infra): add VultrProvider with SSH key injection and 0.0.0.0 IP guard" +``` + +--- + +## Appendix: Combined Build Order (All Phases) + +For reference, here is the complete task sequence across all three phases: + +**Phase A — Dynamic DAG Mutation (existing plan)** +1. KillChainState + IntelBundle models +2. OutputAnalyzer protocol + NmapAnalyzer +3. MutationStrategy protocol + RedisProbeStrategy +4. Engine integration (mutation layer in `_mark_completed`) + +**Phase B — Ephemeral Proxy Routing (existing plan)** +5. Add `env` param to `run_streaming` +6. CloudNodeProvider ABC + DigitalOcean implementation +6b. **VultrProvider implementation** ← NEW +7. `ephemeral_proxy` context manager + shielded teardown +8. ProxiedShellExecutor +9. Orphan sweeper + +**Phase C — HITL Approval Gate** ← NEW +10. Model additions (ApprovalRequirement, AWAITING_APPROVAL, ScanTask fields) +11. ApprovalRegistry (in-memory notification hub) +12. Engine integration (gate phase in `_execute_task`) +13. 
FastAPI gate endpoints (list, approve, reject) diff --git a/docs/superpowers/plans/2026-04-13-phase3c2-attack-chain-graph-view.md b/docs/superpowers/plans/2026-04-13-phase3c2-attack-chain-graph-view.md new file mode 100644 index 0000000..88b0dd1 --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-phase3c2-attack-chain-graph-view.md @@ -0,0 +1,1865 @@ +# Phase 3C.2: Attack Chain Graph View — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add an interactive per-engagement attack chain graph visualization to the web dashboard, with edge curation (confirm/reject), MITRE ATT&CK phase coloring, and server-side filtering for scale. + +**Architecture:** Standalone Vue page at `/engagements/:id/chain` wraps `force-graph` (vasturiano). Backend serves filtered subgraphs via a new `GET /api/chain/subgraph` endpoint that caps nodes and filters by severity/status. Edge curation uses `PATCH /api/chain/relations/:id`. No new database tables — builds on 3C.1 models. 
+ +**Tech Stack:** FastAPI, SQLAlchemy async, Vue 3, PrimeVue, `force-graph` (vasturiano), TanStack Query + +**Spec:** `docs/superpowers/specs/2026-04-13-phase3c2-attack-chain-graph-view-design.md` + +--- + +## File Map + +### Backend (new/modified) + +| File | Action | Responsibility | +|------|--------|---------------| +| `packages/web/backend/app/routes/chain.py` | Modify | Add `GET /api/chain/subgraph` and `PATCH /api/chain/relations/{relation_id}` endpoints | +| `packages/web/backend/app/services/chain_service.py` | Modify | Add `subgraph_for_engagement()` and `update_relation_status()` methods | +| `packages/web/backend/app/services/chain_dto.py` | Modify | Add `relation_to_link_dict()` for force-graph link shape with drift computation | +| `packages/web/backend/tests/test_chain_subgraph.py` | Create | Tests for subgraph endpoint filtering, capping, neighborhood, drift | +| `packages/web/backend/tests/test_chain_curation.py` | Create | Tests for relation PATCH (valid transitions, invalid status, auth scoping) | + +### Frontend (new/modified) + +| File | Action | Responsibility | +|------|--------|---------------| +| `packages/web/frontend/src/views/ChainGraphView.vue` | Create | Page component — data fetching, filter state, layout orchestration | +| `packages/web/frontend/src/components/ForceGraphCanvas.vue` | Create | Wrapper around `force-graph` — rendering config, custom draw callbacks, interaction events | +| `packages/web/frontend/src/components/ChainDetailPanel.vue` | Create | Right drawer — node details, edge details with reasons, curation buttons | +| `packages/web/frontend/src/components/ChainFilterToolbar.vue` | Create | Severity/status toggle buttons | +| `packages/web/frontend/src/components/ChainLegend.vue` | Create | Bottom bar — severity color key, edge style key, node count | +| `packages/web/frontend/src/components/ChainEmptyState.vue` | Create | Empty state + rebuild progress polling | +| `packages/web/frontend/src/router/index.ts` | Modify 
| Add `/engagements/:id/chain` route | +| `packages/web/frontend/src/views/EngagementDetailView.vue` | Modify | Add "View Attack Chain" button | + +--- + +## Task 1: Backend — `relation_to_link_dict` DTO with drift computation + +**Files:** +- Modify: `packages/web/backend/app/services/chain_dto.py` + +This task adds the conversion function that produces the force-graph link shape with inline drift computation. All subsequent backend tasks depend on this. + +- [ ] **Step 1: Write the `relation_to_link_dict` function** + +Add to `packages/web/backend/app/services/chain_dto.py`: + +```python +def relation_to_link_dict(relation: FindingRelation) -> dict[str, Any]: + """Convert a CLI ``FindingRelation`` to a force-graph link dict. + + Includes drift detection: if the relation has status USER_CONFIRMED + and the current reasons differ from the confirmed_at_reasons snapshot, + drift is True. + """ + status_value = ( + relation.status.value + if hasattr(relation.status, "value") + else str(relation.status) + ) + + # Drift: true if user confirmed but reasons have since changed + drift = False + if status_value == "user_confirmed" and relation.confirmed_at_reasons is not None: + current_rules = sorted(r.rule for r in relation.reasons) + confirmed_rules = sorted(r.rule for r in relation.confirmed_at_reasons) + drift = current_rules != confirmed_rules + + return { + "id": relation.id, + "source": relation.source_finding_id, + "target": relation.target_finding_id, + "value": relation.weight, + "status": status_value, + "drift": drift, + "reasons": [r.rule for r in relation.reasons], + "relation_type": relation.llm_relation_type, + "rationale": relation.llm_rationale, + } +``` + +- [ ] **Step 2: Verify the module still imports cleanly** + +Run: `cd packages/web/backend && python -c "from app.services.chain_dto import relation_to_link_dict; print('OK')"` +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/app/services/chain_dto.py +git commit -m 
"feat(chain): add relation_to_link_dict DTO with drift detection" +``` + +--- + +## Task 2: Backend — `subgraph_for_engagement` service method + +**Files:** +- Modify: `packages/web/backend/app/services/chain_service.py` + +Adds the service method that queries the store for findings + relations scoped to an engagement, applies severity/status filters, enforces max_nodes cap, and returns the force-graph-shaped response. + +- [ ] **Step 1: Add imports at top of chain_service.py** + +Add these imports to the existing import block: + +```python +from app.services.chain_dto import relation_to_link_dict +``` + +- [ ] **Step 2: Add `subgraph_for_engagement` method to `ChainService`** + +```python + async def subgraph_for_engagement( + self, + session: AsyncSession, + *, + user_id: uuid.UUID, + engagement_id: str, + severities: set[str] | None = None, + statuses: set[str] | None = None, + max_nodes: int = 500, + seed_finding_id: str | None = None, + hops: int = 2, + format: str = "force-graph", + ) -> dict[str, Any]: + """Build a filtered subgraph for one engagement. + + Returns a dict with 'graph' (force-graph or canonical shape) + and 'meta' (total_findings, rendered_findings, filtered, generation). 
+ """ + from opentools.chain.config import get_chain_config + from opentools.chain.query.graph_cache import GraphCache + from opentools.chain.query.adapters import to_canonical_json, to_force_graph + from opentools.chain.types import RelationStatus + + from sqlalchemy import select, func + from app.models import Finding, ChainFindingRelation + + store = chain_store_from_session(session) + await store.initialize() + + # Count total findings in engagement (for meta) + total_stmt = select(func.count()).select_from(Finding).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + Finding.deleted_at.is_(None), + ) + total_result = await session.execute(total_stmt) + total_findings = total_result.scalar() or 0 + + # Fetch findings for this engagement, applying severity filter + finding_stmt = select(Finding).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + Finding.deleted_at.is_(None), + ) + if severities: + finding_stmt = finding_stmt.where(Finding.severity.in_(severities)) + finding_stmt = finding_stmt.limit(max_nodes) + + finding_result = await session.execute(finding_stmt) + findings = list(finding_result.scalars().all()) + finding_ids = {f.id for f in findings} + + if not finding_ids: + empty_graph = {"nodes": [], "links": []} if format == "force-graph" else {"schema_version": "1.0", "nodes": [], "edges": [], "metadata": {}} + return { + "graph": empty_graph, + "meta": { + "total_findings": total_findings, + "rendered_findings": 0, + "filtered": bool(severities) or total_findings > max_nodes, + "generation": 0, + }, + } + + # Default status filter + if statuses is None: + statuses = {"auto_confirmed", "user_confirmed", "candidate"} + + # Fetch relations where both endpoints are in finding_ids + rel_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.source_finding_id.in_(finding_ids), + ChainFindingRelation.target_finding_id.in_(finding_ids), + 
ChainFindingRelation.status.in_(statuses), + ) + rel_result = await session.execute(rel_stmt) + relations_orm = list(rel_result.scalars().all()) + + # Build nodes + nodes = [ + { + "id": f.id, + "name": f.title, + "severity": f.severity, + "tool": f.tool, + "phase": f.phase, + } + for f in findings + ] + + # Build links via DTO + from opentools.chain.stores.postgres_async import _orm_to_relation + + links = [ + relation_to_link_dict(_orm_to_relation(r)) + for r in relations_orm + ] + + # Get latest generation from most recent linker run + from app.models import ChainLinkerRun + gen_stmt = ( + select(ChainLinkerRun.generation) + .where(ChainLinkerRun.user_id == user_id) + .order_by(ChainLinkerRun.started_at.desc()) + .limit(1) + ) + gen_result = await session.execute(gen_stmt) + generation = gen_result.scalar() or 0 + + if format == "force-graph": + graph = {"nodes": nodes, "links": links} + else: + graph = { + "schema_version": "1.0", + "nodes": [{"id": n["id"], "type": "finding", "severity": n["severity"], "tool": n["tool"], "title": n["name"]} for n in nodes], + "edges": [{"source": l["source"], "target": l["target"], "weight": l["value"], "status": l["status"], "symmetric": False, "reasons": l["reasons"], "relation_type": l["relation_type"], "rationale": l["rationale"]} for l in links], + "metadata": {"generation": generation, "max_weight": max((l["value"] for l in links), default=0)}, + } + + return { + "graph": graph, + "meta": { + "total_findings": total_findings, + "rendered_findings": len(findings), + "filtered": bool(severities) or len(findings) < total_findings, + "generation": generation, + }, + } +``` + +- [ ] **Step 3: Verify the module still imports** + +Run: `cd packages/web/backend && python -c "from app.services.chain_service import ChainService; print('OK')"` +Expected: `OK` + +- [ ] **Step 4: Commit** + +```bash +git add 
packages/web/backend/app/services/chain_service.py +git commit -m "feat(chain): add subgraph_for_engagement service method" +``` + +--- + +## Task 3: Backend — `update_relation_status` service method + +**Files:** +- Modify: `packages/web/backend/app/services/chain_service.py` + +Adds the service method for edge curation — updates relation status to `user_confirmed` or `user_rejected`, snapshots `confirmed_at_reasons_json` on confirm. + +- [ ] **Step 1: Add `update_relation_status` method to `ChainService`** + +```python + async def update_relation_status( + self, + session: AsyncSession, + *, + user_id: uuid.UUID, + relation_id: str, + new_status: str, + ) -> dict[str, Any] | None: + """Update a relation's status for edge curation. + + Only 'user_confirmed' and 'user_rejected' are valid. + On confirm, snapshots current reasons_json into confirmed_at_reasons_json. + Returns the updated relation dict, or None if not found. + """ + from sqlalchemy import select + from app.models import ChainFindingRelation + from datetime import datetime, timezone + from opentools.chain.stores.postgres_async import _orm_to_relation + from app.services.chain_dto import relation_to_dict + + # Fetch the relation, scoped to user + stmt = select(ChainFindingRelation).where( + ChainFindingRelation.id == relation_id, + ChainFindingRelation.user_id == user_id, + ) + result = await session.execute(stmt) + relation = result.scalar_one_or_none() + if relation is None: + return None + + # Update status + relation.status = new_status + relation.updated_at = datetime.now(timezone.utc) + + # On confirm, snapshot current reasons for drift detection + if new_status == "user_confirmed": + relation.confirmed_at_reasons_json = relation.reasons_json + + session.add(relation) + await session.commit() + await session.refresh(relation) + + return relation_to_dict(_orm_to_relation(relation)) +``` + +- [ ] **Step 2: Verify import** + +Run: `cd packages/web/backend && python -c "from
app.services.chain_service import ChainService; print('OK')"` +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/app/services/chain_service.py +git commit -m "feat(chain): add update_relation_status for edge curation" +``` + +--- + +## Task 4: Backend — Subgraph and curation route endpoints + +**Files:** +- Modify: `packages/web/backend/app/routes/chain.py` + +Adds `GET /api/chain/subgraph` and `PATCH /api/chain/relations/{relation_id}`. + +- [ ] **Step 1: Add new Pydantic models for the endpoints** + +Add to `packages/web/backend/app/routes/chain.py`, after the existing model classes: + +```python +class SubgraphMeta(BaseModel): + total_findings: int + rendered_findings: int + filtered: bool + generation: int + + +class SubgraphResponse(BaseModel): + graph: dict + meta: SubgraphMeta + + +class RelationStatusUpdate(BaseModel): + status: str +``` + +- [ ] **Step 2: Add the subgraph endpoint** + +```python +@router.get("/subgraph", response_model=SubgraphResponse) +async def get_subgraph( + engagement_id: str, + severity: Optional[str] = None, + status: Optional[str] = None, + max_nodes: int = 500, + seed_finding_id: Optional[str] = None, + hops: int = 2, + format: str = "force-graph", + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), + service: ChainService = Depends(get_chain_service), +) -> SubgraphResponse: + severities = set(severity.split(",")) if severity else None + statuses = set(status.split(",")) if status else None + + result = await service.subgraph_for_engagement( + db, + user_id=user.id, + engagement_id=engagement_id, + severities=severities, + statuses=statuses, + max_nodes=max_nodes, + seed_finding_id=seed_finding_id, + hops=hops, + format=format, + ) + return SubgraphResponse( + graph=result["graph"], + meta=SubgraphMeta(**result["meta"]), + ) +``` + +- [ ] **Step 3: Add the relation curation endpoint** + +```python +@router.patch("/relations/{relation_id}") +async def update_relation( 
+ relation_id: str, + body: RelationStatusUpdate, + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), + service: ChainService = Depends(get_chain_service), +): + valid_statuses = {"user_confirmed", "user_rejected"} + if body.status not in valid_statuses: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"status must be one of: {', '.join(valid_statuses)}", + ) + + result = await service.update_relation_status( + db, user_id=user.id, relation_id=relation_id, new_status=body.status, + ) + if result is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="relation not found") + return result +``` + +- [ ] **Step 4: Verify the app starts** + +Run: `cd packages/web/backend && python -c "from app.main import app; print('OK')"` +Expected: `OK` + +- [ ] **Step 5: Commit** + +```bash +git add packages/web/backend/app/routes/chain.py +git commit -m "feat(chain): add subgraph and relation curation endpoints" +``` + +--- + +## Task 5: Backend — Subgraph endpoint tests + +**Files:** +- Create: `packages/web/backend/tests/test_chain_subgraph.py` + +- [ ] **Step 1: Write subgraph endpoint tests** + +Create `packages/web/backend/tests/test_chain_subgraph.py`: + +```python +"""Subgraph endpoint tests (Phase 3C.2).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = 
result.scalar_one() + return eng.user_id + + +async def _seed_engagement(user_id, eng_id): + async with test_session_factory() as session: + session.add(Engagement( + id=eng_id, user_id=user_id, name="Test", target="10.0.0.0/24", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +async def _seed_finding(user_id, eng_id, finding_id, severity="high", phase=None): + async with test_session_factory() as session: + session.add(Finding( + id=finding_id, user_id=user_id, engagement_id=eng_id, + tool="nmap", severity=severity, title=f"Finding {finding_id}", + phase=phase, created_at=NOW, + )) + await session.commit() + + +async def _seed_relation(user_id, src_id, tgt_id, rel_id, status="auto_confirmed", weight=0.8): + async with test_session_factory() as session: + session.add(ChainFindingRelation( + id=rel_id, user_id=user_id, source_finding_id=src_id, + target_finding_id=tgt_id, weight=weight, status=status, + symmetric=False, created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +@pytest.mark.asyncio +async def test_subgraph_empty_engagement(auth_client): + """Engagement with no findings returns empty graph.""" + user_id = await _get_user_id(auth_client) + await _seed_engagement(user_id, "eng-empty") + + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-empty") + assert resp.status_code == 200 + data = resp.json() + assert data["graph"]["nodes"] == [] + assert data["graph"]["links"] == [] + assert data["meta"]["total_findings"] == 0 + assert data["meta"]["rendered_findings"] == 0 + + +@pytest.mark.asyncio +async def test_subgraph_returns_nodes_and_links(auth_client): + """Seeded findings and relations appear in subgraph response.""" + user_id = await _get_user_id(auth_client) + await _seed_engagement(user_id, "eng-sub") + await _seed_finding(user_id, "eng-sub", "f-1", severity="critical") + await _seed_finding(user_id, "eng-sub", "f-2", severity="high") + await _seed_relation(user_id, "f-1", "f-2", "rel-1") 
+ + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-sub") + assert resp.status_code == 200 + data = resp.json() + assert len(data["graph"]["nodes"]) == 2 + assert len(data["graph"]["links"]) == 1 + link = data["graph"]["links"][0] + assert link["id"] == "rel-1" + assert link["source"] == "f-1" + assert link["target"] == "f-2" + assert "drift" in link + + +@pytest.mark.asyncio +async def test_subgraph_severity_filter(auth_client): + """Severity filter excludes non-matching findings.""" + user_id = await _get_user_id(auth_client) + await _seed_engagement(user_id, "eng-sev") + await _seed_finding(user_id, "eng-sev", "f-crit", severity="critical") + await _seed_finding(user_id, "eng-sev", "f-low", severity="low") + + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-sev&severity=critical") + assert resp.status_code == 200 + data = resp.json() + assert len(data["graph"]["nodes"]) == 1 + assert data["graph"]["nodes"][0]["severity"] == "critical" + + +@pytest.mark.asyncio +async def test_subgraph_status_filter(auth_client): + """Status filter excludes non-matching relations.""" + user_id = await _get_user_id(auth_client) + await _seed_engagement(user_id, "eng-stat") + await _seed_finding(user_id, "eng-stat", "f-a") + await _seed_finding(user_id, "eng-stat", "f-b") + await _seed_relation(user_id, "f-a", "f-b", "rel-conf", status="auto_confirmed") + await _seed_relation(user_id, "f-b", "f-a", "rel-cand", status="candidate") + + # Only auto_confirmed + resp = await auth_client.get( + "/api/chain/subgraph?engagement_id=eng-stat&status=auto_confirmed" + ) + data = resp.json() + assert len(data["graph"]["links"]) == 1 + assert data["graph"]["links"][0]["status"] == "auto_confirmed" + + +@pytest.mark.asyncio +async def test_subgraph_max_nodes_cap(auth_client): + """max_nodes caps the number of returned findings.""" + user_id = await _get_user_id(auth_client) + await _seed_engagement(user_id, "eng-cap") + for i in range(10): + await 
_seed_finding(user_id, "eng-cap", f"f-cap-{i}") + + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-cap&max_nodes=3") + data = resp.json() + assert len(data["graph"]["nodes"]) == 3 + assert data["meta"]["total_findings"] == 10 + assert data["meta"]["filtered"] is True + + +@pytest.mark.asyncio +async def test_subgraph_unauthenticated(client): + """Unauthenticated request returns 401.""" + resp = await client.get("/api/chain/subgraph?engagement_id=eng-x") + assert resp.status_code == 401 +``` + +- [ ] **Step 2: Run the tests** + +Run: `cd packages/web/backend && python -m pytest tests/test_chain_subgraph.py -v` +Expected: all tests PASS + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/tests/test_chain_subgraph.py +git commit -m "test(chain): subgraph endpoint tests — filters, cap, auth" +``` + +--- + +## Task 6: Backend — Curation endpoint tests + +**Files:** +- Create: `packages/web/backend/tests/test_chain_curation.py` + +- [ ] **Step 1: Write curation endpoint tests** + +Create `packages/web/backend/tests/test_chain_curation.py`: + +```python +"""Relation curation (PATCH) endpoint tests (Phase 3C.2).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed_with_relation(user_id, rel_id="rel-cur", status="candidate"): + async with 
test_session_factory() as session: + session.add(Engagement( + id="eng-cur", user_id=user_id, name="Test", target="10.0.0.1", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + session.add(Finding( + id="f-cur-1", user_id=user_id, engagement_id="eng-cur", + tool="nmap", severity="high", title="Finding 1", created_at=NOW, + )) + session.add(Finding( + id="f-cur-2", user_id=user_id, engagement_id="eng-cur", + tool="nuclei", severity="medium", title="Finding 2", created_at=NOW, + )) + await session.flush() + session.add(ChainFindingRelation( + id=rel_id, user_id=user_id, source_finding_id="f-cur-1", + target_finding_id="f-cur-2", weight=0.75, status=status, + symmetric=False, reasons_json=b'[{"rule":"shared_strong_entity","weight_contribution":0.5,"idf_factor":null,"details":{}}]', + created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +@pytest.mark.asyncio +async def test_confirm_candidate(auth_client): + """Confirming a candidate relation succeeds.""" + user_id = await _get_user_id(auth_client) + await _seed_with_relation(user_id, "rel-c1", status="candidate") + + resp = await auth_client.patch( + "/api/chain/relations/rel-c1", + json={"status": "user_confirmed"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "user_confirmed" + + +@pytest.mark.asyncio +async def test_reject_candidate(auth_client): + """Rejecting a candidate relation succeeds.""" + user_id = await _get_user_id(auth_client) + await _seed_with_relation(user_id, "rel-c2", status="candidate") + + resp = await auth_client.patch( + "/api/chain/relations/rel-c2", + json={"status": "user_rejected"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "user_rejected" + + +@pytest.mark.asyncio +async def test_toggle_confirmed_to_rejected(auth_client): + """User can change from confirmed to rejected.""" + user_id = await _get_user_id(auth_client) + await _seed_with_relation(user_id, "rel-c3", status="user_confirmed") + + resp 
= await auth_client.patch( + "/api/chain/relations/rel-c3", + json={"status": "user_rejected"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "user_rejected" + + +@pytest.mark.asyncio +async def test_invalid_status_returns_422(auth_client): + """Setting auto_confirmed via PATCH returns 422.""" + user_id = await _get_user_id(auth_client) + await _seed_with_relation(user_id, "rel-c4") + + resp = await auth_client.patch( + "/api/chain/relations/rel-c4", + json={"status": "auto_confirmed"}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_nonexistent_relation_returns_404(auth_client): + """Patching a nonexistent relation returns 404.""" + resp = await auth_client.patch( + "/api/chain/relations/rel-does-not-exist", + json={"status": "user_confirmed"}, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_confirm_snapshots_reasons(auth_client): + """Confirming snapshots reasons_json into confirmed_at_reasons_json.""" + user_id = await _get_user_id(auth_client) + await _seed_with_relation(user_id, "rel-c5", status="candidate") + + await auth_client.patch( + "/api/chain/relations/rel-c5", + json={"status": "user_confirmed"}, + ) + + # Verify in DB + async with test_session_factory() as session: + from sqlalchemy import select + result = await session.execute( + select(ChainFindingRelation).where(ChainFindingRelation.id == "rel-c5") + ) + rel = result.scalar_one() + assert rel.confirmed_at_reasons_json is not None + assert rel.confirmed_at_reasons_json == rel.reasons_json + + +@pytest.mark.asyncio +async def test_unauthenticated_returns_401(client): + """Unauthenticated curation request returns 401.""" + resp = await client.patch( + "/api/chain/relations/rel-x", + json={"status": "user_confirmed"}, + ) + assert resp.status_code == 401 +``` + +- [ ] **Step 2: Run the tests** + +Run: `cd packages/web/backend && python -m pytest tests/test_chain_curation.py -v` +Expected: all tests PASS + +- [ ] 
**Step 3: Commit** + +```bash +git add packages/web/backend/tests/test_chain_curation.py +git commit -m "test(chain): curation endpoint tests — transitions, validation, auth" +``` + +--- + +## Task 7: Frontend — Install `force-graph` and add route + +**Files:** +- Modify: `packages/web/frontend/package.json` (via npm) +- Modify: `packages/web/frontend/src/router/index.ts` +- Modify: `packages/web/frontend/src/views/EngagementDetailView.vue` + +- [ ] **Step 1: Install force-graph** + +Run: `cd packages/web/frontend && npm install force-graph` + +- [ ] **Step 2: Add the chain route to router** + +In `packages/web/frontend/src/router/index.ts`, add after the `finding-detail` route: + +```typescript + { path: '/engagements/:id/chain', name: 'engagement-chain', component: () => import('@/views/ChainGraphView.vue') }, +``` + +- [ ] **Step 3: Add "View Attack Chain" button to EngagementDetailView** + +In `packages/web/frontend/src/views/EngagementDetailView.vue`, add a button next to the existing Delete button in the header: + +Find the `