diff --git a/docs/superpowers/plans/2026-04-13-phase3c3-global-view-bayesian-calibration.md b/docs/superpowers/plans/2026-04-13-phase3c3-global-view-bayesian-calibration.md new file mode 100644 index 0000000..d034dad --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-phase3c3-global-view-bayesian-calibration.md @@ -0,0 +1,2367 @@ +# Phase 3C.3: Global View, Bayesian Calibration & Advanced Features — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add global cross-engagement graph view, Bayesian weight calibration, timeline playback, Markdown path export, swim lane Kill Chain layout, and attack vector scoring to the chain visualization. + +**Architecture:** Extends the 3C.2 subgraph endpoint (optional `engagement_id` for global mode), adds a calibration service with Beta priors, timeline scrubber component with temporal anchoring, Markdown export endpoint, Kill Chain layout mode in ForceGraphCanvas, and betweenness centrality scoring. One new DB table (`chain_calibration_state`). 
+ +**Tech Stack:** FastAPI, SQLAlchemy async, rustworkx (betweenness centrality), Vue 3, PrimeVue, force-graph, TanStack Query + +**Spec:** `docs/superpowers/specs/2026-04-13-phase3c3-global-view-bayesian-calibration-design.md` + +--- + +## File Map + +### Backend (new/modified) + +| File | Action | Responsibility | +|------|--------|---------------| +| `packages/web/backend/app/models.py` | Modify | Add `ChainCalibrationState` table | +| `packages/web/backend/alembic/versions/007_chain_calibration_state.py` | Create | Migration for calibration_state table | +| `packages/web/backend/app/services/chain_service.py` | Modify | Make `engagement_id` optional in subgraph, add `calibrate`, `export_path`, pivotality computation | +| `packages/web/backend/app/services/chain_calibration.py` | Create | Bayesian calibration logic (Beta posteriors, re-scoring) | +| `packages/web/backend/app/services/chain_export.py` | Create | Markdown path report generation | +| `packages/web/backend/app/routes/chain.py` | Modify | Add calibrate endpoint, export endpoint, update subgraph params | +| `packages/cli/src/opentools/chain/cli.py` | Modify | Add `calibrate` command, `--format markdown` to `path` command | +| `packages/web/backend/tests/test_chain_global.py` | Create | Global subgraph, engagement_ids filter, new node fields | +| `packages/web/backend/tests/test_chain_calibration.py` | Create | Calibration endpoint + math tests | +| `packages/web/backend/tests/test_chain_export.py` | Create | Export endpoint tests | + +### Frontend (new/modified) + +| File | Action | Responsibility | +|------|--------|---------------| +| `packages/web/frontend/src/views/GlobalChainView.vue` | Create | Global cross-engagement graph page | +| `packages/web/frontend/src/components/EngagementFilterChips.vue` | Create | Engagement toggle chips for global view | +| `packages/web/frontend/src/components/ChainTimelineScrubber.vue` | Create | Dual-handle time range slider with activity heatmap | +| 
`packages/web/frontend/src/components/ForceGraphCanvas.vue` | Modify | Add timeRange prop, layoutMode prop, pivotality glow, engagement color mode | +| `packages/web/frontend/src/components/ChainDetailPanel.vue` | Modify | Add calibrated badge, export button, risk score display | +| `packages/web/frontend/src/components/AppLayout.vue` | Modify | Add "Attack Chain" nav item | +| `packages/web/frontend/src/router/index.ts` | Modify | Add `/chain/global` route | + +--- + +## Task 1: Backend — CalibrationState model + migration + +**Files:** +- Modify: `packages/web/backend/app/models.py` +- Create: `packages/web/backend/alembic/versions/007_chain_calibration_state.py` + +- [ ] **Step 1: Add CalibrationState model to models.py** + +Add at the end of `packages/web/backend/app/models.py`, after the `ChainFindingParserOutput` class: + +```python +class ChainCalibrationState(SQLModel, table=True): + """Per-rule Bayesian calibration state for a user.""" + __tablename__ = "chain_calibration_state" + id: str = Field(primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + rule: str = Field(index=True) + alpha: float = Field(default=1.0) + beta_param: float = Field(default=1.0) + observations: int = Field(default=0) + last_calibrated_at: datetime = Field(**_TZ_KW) + + __table_args__ = ( + UniqueConstraint("user_id", "rule", name="uq_calibration_state"), + ) +``` + +Note: field is named `beta_param` (not `beta`) to avoid shadowing Python's `beta` in math contexts. + +- [ ] **Step 2: Create Alembic migration** + +Create `packages/web/backend/alembic/versions/007_chain_calibration_state.py`: + +```python +"""Add chain_calibration_state table. 
+ +Revision ID: 007 +Revises: 006 +""" +import sqlalchemy as sa +from alembic import op + +revision = "007" +down_revision = "006" + + +def upgrade() -> None: + op.create_table( + "chain_calibration_state", + sa.Column("id", sa.String(), primary_key=True), + sa.Column("user_id", sa.Uuid(), sa.ForeignKey("user.id"), nullable=False, index=True), + sa.Column("rule", sa.String(), nullable=False, index=True), + sa.Column("alpha", sa.Float(), nullable=False, server_default="1.0"), + sa.Column("beta_param", sa.Float(), nullable=False, server_default="1.0"), + sa.Column("observations", sa.Integer(), nullable=False, server_default="0"), + sa.Column("last_calibrated_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("user_id", "rule", name="uq_calibration_state"), + ) + + +def downgrade() -> None: + op.drop_table("chain_calibration_state") +``` + +- [ ] **Step 3: Verify model imports** + +Run: `cd packages/web/backend && python -c "from app.models import ChainCalibrationState; print('OK')"` +Expected: `OK` + +- [ ] **Step 4: Commit** + +```bash +git add packages/web/backend/app/models.py packages/web/backend/alembic/versions/007_chain_calibration_state.py +git commit -m "feat(chain): add ChainCalibrationState model and migration" +``` + +--- + +## Task 2: Backend — Calibration service + +**Files:** +- Create: `packages/web/backend/app/services/chain_calibration.py` + +- [ ] **Step 1: Create the calibration service** + +Create `packages/web/backend/app/services/chain_calibration.py`: + +```python +"""Bayesian weight calibration service. + +Uses Beta distribution priors per linking rule, updated from user +confirm/reject decisions. Posterior mean = alpha / (alpha + beta_param) +estimates each rule's reliability. 
+""" +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import ChainCalibrationState, ChainFindingRelation + +# Default Beta priors per rule +DEFAULT_PRIORS: dict[str, tuple[float, float]] = { + "shared_strong_entity": (2.0, 1.0), + "cve_adjacency": (2.0, 1.0), + "temporal_proximity": (1.0, 1.0), + "kill_chain": (1.0, 1.0), + "tool_chain": (1.0, 1.0), + "cross_engagement_ioc": (1.0, 1.0), +} + +MINIMUM_DECISIONS = 20 + + +async def get_or_create_priors( + session: AsyncSession, user_id: uuid.UUID +) -> dict[str, ChainCalibrationState]: + """Load existing calibration state or seed defaults.""" + stmt = select(ChainCalibrationState).where( + ChainCalibrationState.user_id == user_id + ) + result = await session.execute(stmt) + existing = {row.rule: row for row in result.scalars()} + + now = datetime.now(timezone.utc) + for rule, (alpha, beta) in DEFAULT_PRIORS.items(): + if rule not in existing: + row = ChainCalibrationState( + id=f"cal-{user_id}-{rule}", + user_id=user_id, + rule=rule, + alpha=alpha, + beta_param=beta, + observations=0, + last_calibrated_at=now, + ) + session.add(row) + existing[rule] = row + + await session.flush() + return existing + + +async def count_user_decisions( + session: AsyncSession, user_id: uuid.UUID, engagement_id: str | None = None +) -> int: + """Count total user-confirmed + user-rejected edges.""" + stmt = select(func.count()).select_from(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.in_(["user_confirmed", "user_rejected"]), + ) + if engagement_id: + from app.models import Finding + finding_ids_stmt = select(Finding.id).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + ) + stmt = stmt.where( + ChainFindingRelation.source_finding_id.in_(finding_ids_stmt) + ) + result = await 
session.execute(stmt) + return result.scalar() or 0 + + +async def calibrate( + session: AsyncSession, + *, + user_id: uuid.UUID, + engagement_id: str | None = None, + dry_run: bool = False, +) -> dict[str, Any]: + """Run Bayesian calibration from user decisions. + + Returns dict with 'rules' (per-rule posteriors), 'edges_updated', + 'below_threshold'. + """ + import orjson + + total_decisions = await count_user_decisions(session, user_id, engagement_id) + if total_decisions < MINIMUM_DECISIONS: + return { + "rules": [], + "edges_updated": 0, + "below_threshold": True, + "total_decisions": total_decisions, + "minimum_required": MINIMUM_DECISIONS, + } + + # Load or seed priors + priors = await get_or_create_priors(session, user_id) + + # Reset to defaults before re-counting + for rule, (alpha, beta) in DEFAULT_PRIORS.items(): + if rule in priors: + priors[rule].alpha = alpha + priors[rule].beta_param = beta + priors[rule].observations = 0 + + # Fetch all user-decided edges + decided_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.in_(["user_confirmed", "user_rejected"]), + ) + if engagement_id: + from app.models import Finding + finding_ids_stmt = select(Finding.id).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + ) + decided_stmt = decided_stmt.where( + ChainFindingRelation.source_finding_id.in_(finding_ids_stmt) + ) + + decided_result = await session.execute(decided_stmt) + decided_edges = list(decided_result.scalars()) + + # Update priors from decisions + for edge in decided_edges: + reasons_data = orjson.loads(edge.reasons_json) if edge.reasons_json else [] + rules_fired = {r["rule"] for r in reasons_data if "rule" in r} + + for rule in rules_fired: + if rule not in priors: + continue + if edge.status == "user_confirmed": + priors[rule].alpha += 1 + elif edge.status == "user_rejected": + priors[rule].beta_param += 1 + priors[rule].observations += 1 + + now = 
datetime.now(timezone.utc) + for p in priors.values(): + p.last_calibrated_at = now + + # Build posteriors summary + rules_summary = [ + { + "rule": rule, + "alpha": priors[rule].alpha, + "beta": priors[rule].beta_param, + "posterior": priors[rule].alpha / (priors[rule].alpha + priors[rule].beta_param), + "observations": priors[rule].observations, + } + for rule in sorted(priors.keys()) + ] + + edges_updated = 0 + if not dry_run: + # Re-score all non-rejected edges with bayesian weights + posteriors = { + rule: priors[rule].alpha / (priors[rule].alpha + priors[rule].beta_param) + for rule in priors + } + + all_edges_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.notin_(["rejected", "user_rejected"]), + ) + all_result = await session.execute(all_edges_stmt) + all_edges = list(all_result.scalars()) + + for edge in all_edges: + reasons_data = orjson.loads(edge.reasons_json) if edge.reasons_json else [] + new_weight = 0.0 + for reason in reasons_data: + rule = reason.get("rule", "") + contribution = reason.get("weight_contribution", 0.0) + posterior = posteriors.get(rule, 1.0) + new_weight += contribution * posterior + + # Cap at 1.0 + new_weight = min(new_weight, 1.0) + + if abs(edge.weight - new_weight) > 0.001: + edge.weight = new_weight + edge.weight_model_version = "bayesian_v1" + edge.updated_at = now + edges_updated += 1 + + # Persist calibration state and edge updates + for p in priors.values(): + session.add(p) + await session.commit() + + return { + "rules": rules_summary, + "edges_updated": edges_updated, + "below_threshold": False, + "total_decisions": total_decisions, + "minimum_required": MINIMUM_DECISIONS, + } +``` + +- [ ] **Step 2: Verify import** + +Run: `cd packages/web/backend && python -c "from app.services.chain_calibration import calibrate; print('OK')"` +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/app/services/chain_calibration.py +git 
commit -m "feat(chain): Bayesian calibration service with Beta priors" +``` + +--- + +## Task 3: Backend — Export service (Markdown path report) + +**Files:** +- Create: `packages/web/backend/app/services/chain_export.py` + +- [ ] **Step 1: Create the export service** + +Create `packages/web/backend/app/services/chain_export.py`: + +```python +"""Markdown attack path report generation.""" +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import ChainFindingRelation, Engagement, Finding + + +async def export_path_markdown( + session: AsyncSession, + *, + user_id: uuid.UUID, + finding_ids: list[str], + engagement_id: str | None = None, +) -> str: + """Generate a Markdown attack path report from an ordered list of finding IDs.""" + import orjson + + # Fetch engagement name if provided + eng_name = "Unknown Engagement" + if engagement_id: + eng_stmt = select(Engagement).where( + Engagement.id == engagement_id, Engagement.user_id == user_id + ) + eng_result = await session.execute(eng_stmt) + eng = eng_result.scalar_one_or_none() + if eng: + eng_name = eng.name + + # Fetch all findings in order + findings: list[Any] = [] + for fid in finding_ids: + stmt = select(Finding).where(Finding.id == fid, Finding.user_id == user_id) + result = await session.execute(stmt) + f = result.scalar_one_or_none() + if f is None: + raise ValueError(f"Finding {fid} not found") + findings.append(f) + + # Fetch relations between consecutive findings + relations: list[Any] = [] + for i in range(len(findings) - 1): + src_id = findings[i].id + tgt_id = findings[i + 1].id + rel_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.source_finding_id == src_id, + ChainFindingRelation.target_finding_id == tgt_id, + ) + rel_result = await session.execute(rel_stmt) + rel = 
rel_result.scalar_one_or_none()
+        relations.append(rel)
+
+    # Compute risk score
+    severity_multipliers = {"critical": 5, "high": 4, "medium": 3, "low": 2, "info": 1}
+    max_sev = max(severity_multipliers.get(f.severity, 1) for f in findings)
+    edge_weight_sum = sum(r.weight for r in relations if r)
+    hop_count = len(findings) - 1
+    import math
+    raw_score = (edge_weight_sum * max_sev) / max(math.sqrt(hop_count), 1)
+    risk_score = min(raw_score, 10.0)
+
+    # Build markdown
+    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
+    lines = [
+        "# Attack Path Report",
+        "",
+        f"**Engagement:** {eng_name}",
+        f"**Generated:** {now}",
+        f"**Path length:** {len(findings)} steps",
+        f"**Risk score:** {risk_score:.1f}/10",
+        "",
+        "## Summary",
+        "",
+        _build_summary(findings, relations),
+        "",
+    ]
+
+    for i, finding in enumerate(findings):
+        sev = finding.severity.upper() if finding.severity else "UNKNOWN"
+        lines.append(f"## Step {i + 1}: {finding.title} ({sev})")
+        lines.append("")
+        lines.append(f"- **Tool:** {finding.tool}")
+        if finding.phase:
+            lines.append(f"- **Phase:** {finding.phase}")
+        if finding.evidence:
+            evidence = finding.evidence[:500]
+            lines.append(f"- **Evidence:** {evidence}")
+        if finding.remediation:
+            lines.append(f"- **Remediation:** {finding.remediation}")
+
+        if i < len(relations) and relations[i]:
+            rel = relations[i]
+            reasons_data = orjson.loads(rel.reasons_json) if rel.reasons_json else []
+            reason_names = [r.get("rule", "unknown") for r in reasons_data]
+            lines.append("")
+            lines.append(
+                f"**Link to Step {i + 2}:** {', '.join(reason_names)}, "
+                f"weight: {rel.weight:.2f}"
+            )
+        lines.append("")
+
+    # Recommendations
+    remediations = [f.remediation for f in findings if f.remediation]
+    if remediations:
+        lines.append("## Recommendations")
+        lines.append("")
+        seen = set()
+        for rem in remediations:
+            if rem not in seen:
+                seen.add(rem)
+                lines.append(f"{len(seen)}. 
{rem}") + lines.append("") + + return "\n".join(lines) + + +def _build_summary(findings: list, relations: list) -> str: + """Template-based path summary.""" + if not findings: + return "No findings in path." + + first = findings[0] + last = findings[-1] + steps = len(findings) + + return ( + f"This attack path spans {steps} steps, starting from " + f"**{first.title}** ({first.severity}) and culminating in " + f"**{last.title}** ({last.severity}). " + f"The path traverses {steps - 1} link(s) through the target environment." + ) +``` + +- [ ] **Step 2: Verify import** + +Run: `cd packages/web/backend && python -c "from app.services.chain_export import export_path_markdown; print('OK')"` +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/app/services/chain_export.py +git commit -m "feat(chain): Markdown attack path report export service" +``` + +--- + +## Task 4: Backend — Extend subgraph service for global mode + scoring + +**Files:** +- Modify: `packages/web/backend/app/services/chain_service.py` + +This task modifies `subgraph_for_engagement` to support optional `engagement_id`, adds `engagement_ids` filter, adds `created_at`/`pivotality`/`engagement_id` to node objects, and adds `engagements` to meta. + +- [ ] **Step 1: Update method signature** + +In `packages/web/backend/app/services/chain_service.py`, change the `subgraph_for_engagement` method signature: + +```python + async def subgraph_for_engagement( + self, + session: AsyncSession, + *, + user_id: uuid.UUID, + engagement_id: str | None = None, # was required, now optional + engagement_ids: list[str] | None = None, # new: filter for global mode + severities: set[str] | None = None, + statuses: set[str] | None = None, + max_nodes: int = 500, + seed_finding_id: str | None = None, + hops: int = 2, + format: str = "force-graph", + ) -> dict[str, Any]: +``` + +- [ ] **Step 2: Update the finding query for global mode** + +Replace the finding query section. 
When `engagement_id` is None, query across all engagements (optionally filtered by `engagement_ids`): + +```python + # Fetch findings — scoped to engagement or global + finding_stmt = select(Finding).where( + Finding.user_id == user_id, + Finding.deleted_at.is_(None), + ) + if engagement_id: + finding_stmt = finding_stmt.where(Finding.engagement_id == engagement_id) + elif engagement_ids: + finding_stmt = finding_stmt.where(Finding.engagement_id.in_(engagement_ids)) + + # Total count (before severity filter and cap) + total_stmt = select(func.count()).select_from(finding_stmt.subquery()) + total_result = await session.execute(total_stmt) + total_findings = total_result.scalar() or 0 + + if severities: + finding_stmt = finding_stmt.where(Finding.severity.in_(severities)) + finding_stmt = finding_stmt.limit(max_nodes) + + finding_result = await session.execute(finding_stmt) + findings = list(finding_result.scalars().all()) + finding_ids = {f.id for f in findings} +``` + +- [ ] **Step 3: Add created_at, engagement_id, and pivotality to nodes** + +Replace the node building section: + +```python + # Compute betweenness centrality for pivotality scores + pivotality_scores: dict[str, float] = {} + if finding_ids and len(finding_ids) > 1: + import rustworkx as rx + g = rx.PyDiGraph() + id_to_idx: dict[str, int] = {} + for fid in finding_ids: + idx = g.add_node(fid) + id_to_idx[fid] = idx + for r in relations_orm: + src = r.source_finding_id + tgt = r.target_finding_id + if src in id_to_idx and tgt in id_to_idx: + g.add_edge(id_to_idx[src], id_to_idx[tgt], r.weight) + centrality = rx.betweenness_centrality(g) + max_c = max(centrality.values()) if centrality else 1.0 + for fid, idx in id_to_idx.items(): + raw = centrality.get(idx, 0.0) + pivotality_scores[fid] = raw / max_c if max_c > 0 else 0.0 + + # Build nodes with new fields + nodes = [ + { + "id": f.id, + "name": f.title, + "severity": f.severity, + "tool": f.tool, + "phase": f.phase, + "created_at": 
f.created_at.isoformat() if f.created_at else None, + "engagement_id": f.engagement_id, + "pivotality": round(pivotality_scores.get(f.id, 0.0), 3), + } + for f in findings + ] +``` + +- [ ] **Step 4: Add engagements to meta** + +Replace the meta building section: + +```python + # Collect distinct engagements represented in the result + from app.models import Engagement as EngModel + eng_ids_in_result = {f.engagement_id for f in findings} + engagements_meta = [] + if eng_ids_in_result: + eng_stmt = select(EngModel).where(EngModel.id.in_(eng_ids_in_result)) + eng_result = await session.execute(eng_stmt) + engagements_meta = [ + {"id": e.id, "name": e.name} + for e in eng_result.scalars() + ] + + return { + "graph": graph, + "meta": { + "total_findings": total_findings, + "rendered_findings": len(findings), + "filtered": bool(severities) or len(findings) < total_findings, + "generation": generation, + "engagements": engagements_meta, + }, + } +``` + +- [ ] **Step 5: Verify import** + +Run: `cd packages/web/backend && python -c "from app.services.chain_service import ChainService; print('OK')"` +Expected: `OK` + +- [ ] **Step 6: Commit** + +```bash +git add packages/web/backend/app/services/chain_service.py +git commit -m "feat(chain): global subgraph mode with pivotality, created_at, engagement meta" +``` + +--- + +## Task 5: Backend — Route endpoints (calibrate, export, subgraph updates) + +**Files:** +- Modify: `packages/web/backend/app/routes/chain.py` + +- [ ] **Step 1: Add new Pydantic models** + +Add after the existing `RelationStatusUpdate` class: + +```python +class CalibrateRequest(BaseModel): + scope: str = "user" + engagement_id: Optional[str] = None + dry_run: bool = False + + +class CalibrateResponse(BaseModel): + rules: list[dict] + edges_updated: int + below_threshold: bool + total_decisions: int + minimum_required: int + + +class ExportPathRequest(BaseModel): + finding_ids: list[str] + engagement_id: Optional[str] = None +``` + +- [ ] **Step 2: Update 
subgraph endpoint — make engagement_id optional, add engagement_ids** + +Change the `get_subgraph` endpoint signature: + +```python +@router.get("/subgraph", response_model=SubgraphResponse) +async def get_subgraph( + engagement_id: Optional[str] = None, + engagement_ids: Optional[str] = None, + severity: Optional[str] = None, + status_filter: Optional[str] = Query(default=None, alias="status"), + max_nodes: int = 500, + seed_finding_id: Optional[str] = None, + hops: int = 2, + format: str = "force-graph", + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), + service: ChainService = Depends(get_chain_service), +) -> SubgraphResponse: + severities = set(severity.split(",")) if severity else None + statuses = set(status_filter.split(",")) if status_filter else None + eng_ids_list = engagement_ids.split(",") if engagement_ids else None + + result = await service.subgraph_for_engagement( + db, + user_id=user.id, + engagement_id=engagement_id, + engagement_ids=eng_ids_list, + severities=severities, + statuses=statuses, + max_nodes=max_nodes, + seed_finding_id=seed_finding_id, + hops=hops, + format=format, + ) + return SubgraphResponse( + graph=result["graph"], + meta=SubgraphMeta(**result["meta"]), + ) +``` + +- [ ] **Step 3: Update SubgraphMeta model to include engagements** + +```python +class SubgraphMeta(BaseModel): + total_findings: int + rendered_findings: int + filtered: bool + generation: int + engagements: list[dict] = [] +``` + +- [ ] **Step 4: Add calibrate endpoint** + +```python +@router.post("/calibrate", response_model=CalibrateResponse) +async def calibrate_weights( + request: CalibrateRequest, + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +) -> CalibrateResponse: + from app.services.chain_calibration import calibrate + + if request.scope not in ("user", "engagement"): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="scope must be 'user' or 'engagement'", + 
)
+    if request.scope == "engagement" and not request.engagement_id:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
+            detail="engagement_id required when scope is 'engagement'",
+        )
+
+    result = await calibrate(
+        db,
+        user_id=user.id,
+        engagement_id=request.engagement_id if request.scope == "engagement" else None,
+        dry_run=request.dry_run,
+    )
+
+    if result["below_threshold"]:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
+            detail=f"Need at least {result['minimum_required']} user decisions, have {result['total_decisions']}",
+        )
+
+    return CalibrateResponse(**result)
+```
+
+- [ ] **Step 5: Add export endpoint**
+
+```python
+@router.post("/export/path")
+async def export_path(
+    request: ExportPathRequest,
+    db: AsyncSession = Depends(get_db),
+    user: User = Depends(get_current_user),
+):
+    from app.services.chain_export import export_path_markdown
+
+    if len(request.finding_ids) < 2:
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
+            detail="Path must contain at least 2 findings",
+        )
+
+    try:
+        markdown = await export_path_markdown(
+            db,
+            user_id=user.id,
+            finding_ids=request.finding_ids,
+            engagement_id=request.engagement_id,
+        )
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
+
+    return {"markdown": markdown}
+```
+
+- [ ] **Step 6: Verify app starts**
+
+Run: `cd packages/web/backend && python -c "from app.main import app; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 7: Commit**
+
+```bash
+git add packages/web/backend/app/routes/chain.py
+git commit -m "feat(chain): calibrate, export, and global subgraph endpoints"
+```
+
+---
+
+## Task 6: Backend — Tests for global subgraph
+
+**Files:**
+- Create: `packages/web/backend/tests/test_chain_global.py`
+
+- [ ] **Step 1: Write tests**
+
+Create `packages/web/backend/tests/test_chain_global.py`:
+
+```python
+"""Global subgraph endpoint tests (Phase 3C.3)."""
+
+import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed(user_id): + """Seed two engagements with findings and a cross-engagement relation.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-g1", user_id=user_id, name="Pentest Q1", target="10.0.0.0/24", + type="pentest", created_at=NOW, updated_at=NOW, + )) + session.add(Engagement( + id="eng-g2", user_id=user_id, name="Web App", target="app.example.com", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + session.add(Finding( + id="f-g1", user_id=user_id, engagement_id="eng-g1", + tool="nmap", severity="high", title="Open SSH", created_at=NOW, + )) + session.add(Finding( + id="f-g2", user_id=user_id, engagement_id="eng-g2", + tool="nuclei", severity="critical", title="RCE in /api", created_at=NOW, + )) + await session.flush() + session.add(ChainFindingRelation( + id="rel-cross", user_id=user_id, source_finding_id="f-g1", + target_finding_id="f-g2", weight=0.6, status="auto_confirmed", + symmetric=False, created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +@pytest.mark.asyncio +async def test_global_subgraph_returns_cross_engagement(auth_client): + """Omitting engagement_id returns findings from all engagements.""" + user_id = await _get_user_id(auth_client) 
+ await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + assert resp.status_code == 200 + data = resp.json() + node_ids = {n["id"] for n in data["graph"]["nodes"]} + assert "f-g1" in node_ids + assert "f-g2" in node_ids + assert len(data["graph"]["links"]) >= 1 + + +@pytest.mark.asyncio +async def test_global_subgraph_includes_engagements_meta(auth_client): + """Meta includes engagements array with id and name.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + data = resp.json() + eng_ids = {e["id"] for e in data["meta"]["engagements"]} + assert "eng-g1" in eng_ids + assert "eng-g2" in eng_ids + + +@pytest.mark.asyncio +async def test_global_subgraph_engagement_ids_filter(auth_client): + """engagement_ids param filters to specific engagements.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?engagement_ids=eng-g1&max_nodes=100") + data = resp.json() + for n in data["graph"]["nodes"]: + assert n["engagement_id"] == "eng-g1" + + +@pytest.mark.asyncio +async def test_subgraph_nodes_have_created_at(auth_client): + """Node objects include created_at field.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-g1") + data = resp.json() + for n in data["graph"]["nodes"]: + assert "created_at" in n + + +@pytest.mark.asyncio +async def test_subgraph_nodes_have_pivotality(auth_client): + """Node objects include pivotality field.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + data = resp.json() + for n in data["graph"]["nodes"]: + assert "pivotality" in n + assert isinstance(n["pivotality"], (int, float)) +``` + +- [ ] **Step 2: Run tests** + +Run: `cd packages/web/backend && python -m pytest 
tests/test_chain_global.py -v` +Expected: all PASS + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/backend/tests/test_chain_global.py +git commit -m "test(chain): global subgraph, engagement filter, new node fields" +``` + +--- + +## Task 7: Backend — Tests for calibration + export + +**Files:** +- Create: `packages/web/backend/tests/test_chain_calibration.py` +- Create: `packages/web/backend/tests/test_chain_export.py` + +- [ ] **Step 1: Write calibration tests** + +Create `packages/web/backend/tests/test_chain_calibration.py`: + +```python +"""Calibration endpoint tests (Phase 3C.3).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed_decisions(user_id, count=25, confirmed_ratio=0.8): + """Seed engagement, findings, and user-decided edges.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-cal", user_id=user_id, name="Cal Test", target="10.0.0.1", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + + # Create pairs of findings with user-decided relations + for i in range(count): + f1_id = f"f-cal-{i}-a" + f2_id = f"f-cal-{i}-b" + session.add(Finding( + id=f1_id, user_id=user_id, engagement_id="eng-cal", + tool="nmap", severity="high", title=f"Finding {f1_id}", created_at=NOW, + )) + 
session.add(Finding( + id=f2_id, user_id=user_id, engagement_id="eng-cal", + tool="nuclei", severity="medium", title=f"Finding {f2_id}", created_at=NOW, + )) + await session.flush() + + is_confirmed = i < int(count * confirmed_ratio) + session.add(ChainFindingRelation( + id=f"rel-cal-{i}", user_id=user_id, + source_finding_id=f1_id, target_finding_id=f2_id, + weight=0.5, status="user_confirmed" if is_confirmed else "user_rejected", + symmetric=False, + reasons_json=f'[{{"rule":"shared_strong_entity","weight_contribution":0.5,"idf_factor":null,"details":{{}}}}]', + created_at=NOW, updated_at=NOW, + )) + + await session.commit() + + +@pytest.mark.asyncio +async def test_calibrate_below_threshold(auth_client): + """Calibration with too few decisions returns 422.""" + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "user"}) + assert resp.status_code == 422 + assert "Need at least" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_calibrate_success(auth_client): + """Calibration with enough decisions returns posteriors.""" + user_id = await _get_user_id(auth_client) + await _seed_decisions(user_id, count=25, confirmed_ratio=0.8) + + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "user"}) + assert resp.status_code == 200 + data = resp.json() + assert data["below_threshold"] is False + assert len(data["rules"]) > 0 + + # shared_strong_entity should have posterior > 0.5 (mostly confirmed) + sse = next(r for r in data["rules"] if r["rule"] == "shared_strong_entity") + assert sse["posterior"] > 0.5 + + +@pytest.mark.asyncio +async def test_calibrate_dry_run(auth_client): + """Dry run returns posteriors but edges_updated=0.""" + user_id = await _get_user_id(auth_client) + await _seed_decisions(user_id, count=25) + + resp = await auth_client.post("/api/chain/calibrate", json={ + "scope": "user", "dry_run": True, + }) + assert resp.status_code == 200 + assert resp.json()["edges_updated"] == 0 + + +@pytest.mark.asyncio 
+async def test_calibrate_invalid_scope(auth_client): + """Invalid scope returns 422.""" + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "global"}) + assert resp.status_code == 422 +``` + +- [ ] **Step 2: Write export tests** + +Create `packages/web/backend/tests/test_chain_export.py`: + +```python +"""Export endpoint tests (Phase 3C.3).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed_path(user_id): + """Seed engagement with a 3-step path.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-exp", user_id=user_id, name="Export Test", target="10.0.0.1", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + for i, (sev, title) in enumerate([ + ("critical", "SQL Injection"), + ("high", "Credential Dump"), + ("medium", "Lateral Movement"), + ]): + session.add(Finding( + id=f"f-exp-{i}", user_id=user_id, engagement_id="eng-exp", + tool="test", severity=sev, title=title, created_at=NOW, + evidence=f"Evidence for step {i}", + remediation=f"Fix step {i}", + )) + await session.flush() + session.add(ChainFindingRelation( + id="rel-exp-0", user_id=user_id, source_finding_id="f-exp-0", + target_finding_id="f-exp-1", weight=0.9, status="auto_confirmed", + symmetric=False, 
reasons_json='[{"rule":"shared_strong_entity","weight_contribution":0.9}]',
+            created_at=NOW, updated_at=NOW,
+        ))
+        session.add(ChainFindingRelation(
+            id="rel-exp-1", user_id=user_id, source_finding_id="f-exp-1",
+            target_finding_id="f-exp-2", weight=0.7, status="auto_confirmed",
+            symmetric=False, reasons_json='[{"rule":"temporal_proximity","weight_contribution":0.7}]',
+            created_at=NOW, updated_at=NOW,
+        ))
+        await session.commit()
+
+
+@pytest.mark.asyncio
+async def test_export_path_returns_markdown(auth_client):
+    """Valid path returns Markdown with expected sections."""
+    user_id = await _get_user_id(auth_client)
+    await _seed_path(user_id)
+
+    resp = await auth_client.post("/api/chain/export/path", json={
+        "finding_ids": ["f-exp-0", "f-exp-1", "f-exp-2"],
+        "engagement_id": "eng-exp",
+    })
+    assert resp.status_code == 200
+    md = resp.json()["markdown"]
+    assert "# Attack Path Report" in md
+    assert "SQL Injection" in md
+    assert "Step 1:" in md
+    assert "Step 2:" in md
+    assert "Step 3:" in md
+    assert "Recommendations" in md
+
+
+@pytest.mark.asyncio
+async def test_export_path_invalid_finding(auth_client):
+    """Invalid finding ID returns 404."""
+    resp = await auth_client.post("/api/chain/export/path", json={
+        "finding_ids": ["f-nonexistent-1", "f-nonexistent-2"],
+    })
+    assert resp.status_code == 404
+
+
+@pytest.mark.asyncio
+async def test_export_path_too_short(auth_client):
+    """Path with <2 findings returns 422."""
+    resp = await auth_client.post("/api/chain/export/path", json={
+        "finding_ids": ["f-exp-0"],
+    })
+    assert resp.status_code == 422
+```
+
+- [ ] **Step 3: Run all tests**
+
+Run: `cd packages/web/backend && python -m pytest tests/test_chain_calibration.py tests/test_chain_export.py -v`
+Expected: all PASS
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add packages/web/backend/tests/test_chain_calibration.py packages/web/backend/tests/test_chain_export.py
+git commit -m "test(chain): calibration and export endpoint tests"
+```
+
+---
+
+## Task 8: Frontend — Install slider dependency, add global route + nav + +**Files:** +- Modify: `packages/web/frontend/src/router/index.ts` +- Modify: `packages/web/frontend/src/components/AppLayout.vue` + +- [ ] **Step 1: Add global chain route** + +In `packages/web/frontend/src/router/index.ts`, add after the `engagement-chain` route: + +```typescript + { path: '/chain/global', name: 'chain-global', component: () => import('@/views/GlobalChainView.vue') }, +``` + +- [ ] **Step 2: Add "Attack Chain" nav item to AppLayout** + +In `packages/web/frontend/src/components/AppLayout.vue`, add to the `menuItems` array after the IOCs entry: + +```typescript + { + label: 'Attack Chain', icon: 'pi pi-share-alt', + command: () => router.push('/chain/global'), + }, +``` + +- [ ] **Step 3: Commit** + +```bash +git add packages/web/frontend/src/router/index.ts packages/web/frontend/src/components/AppLayout.vue +git commit -m "feat(frontend): add global chain route and nav item" +``` + +--- + +## Task 9: Frontend — EngagementFilterChips component + +**Files:** +- Create: `packages/web/frontend/src/components/EngagementFilterChips.vue` + +- [ ] **Step 1: Create the component** + +Create `packages/web/frontend/src/components/EngagementFilterChips.vue`: + +```vue + + + + + Engagements: + + + + {{ eng.name }} + + + + +``` + +- [ ] **Step 2: Commit** + +```bash +git add packages/web/frontend/src/components/EngagementFilterChips.vue +git commit -m "feat(frontend): EngagementFilterChips — toggle engagement inclusion" +``` + +--- + +## Task 10: Frontend — ChainTimelineScrubber component + +**Files:** +- Create: `packages/web/frontend/src/components/ChainTimelineScrubber.vue` + +- [ ] **Step 1: Create the timeline scrubber** + +Create `packages/web/frontend/src/components/ChainTimelineScrubber.vue`: + +```vue + + + + + + + + + + + + + + + + + + + + + {{ rangeLabel }} + + + + + + +``` + +- [ ] **Step 2: Commit** + +```bash +git add 
packages/web/frontend/src/components/ChainTimelineScrubber.vue +git commit -m "feat(frontend): ChainTimelineScrubber — dual-handle slider with heatmap" +``` + +--- + +## Task 11: Frontend — ForceGraphCanvas extensions + +**Files:** +- Modify: `packages/web/frontend/src/components/ForceGraphCanvas.vue` + +This task adds four capabilities: time range filtering, Kill Chain layout mode, pivotality glow, and engagement color mode. + +- [ ] **Step 1: Extend the props interface** + +Add new props to the `defineProps`: + +```typescript +const props = defineProps<{ + data: GraphData + selectedNodeId: string | null + selectedLinkId: string | null + timeRange: { start: Date; end: Date } | null + layoutMode: 'force' | 'killchain' + colorMode: 'severity' | 'engagement' + engagementColors: Record +}>() +``` + +With defaults (add `withDefaults`): + +```typescript +const props = withDefaults(defineProps<{ + data: GraphData + selectedNodeId: string | null + selectedLinkId: string | null + timeRange?: { start: Date; end: Date } | null + layoutMode?: 'force' | 'killchain' + colorMode?: 'severity' | 'engagement' + engagementColors?: Record +}>(), { + timeRange: null, + layoutMode: 'force', + colorMode: 'severity', + engagementColors: () => ({}), +}) +``` + +- [ ] **Step 2: Add `created_at`, `engagement_id`, and `pivotality` to GraphNode interface** + +```typescript +interface GraphNode { + id: string + name: string + severity: string + tool: string + phase: string | null + created_at?: string | null + engagement_id?: string + pivotality?: number + x?: number + y?: number + fx?: number | undefined + fy?: number | undefined + neighborCount?: number +} +``` + +- [ ] **Step 3: Add time range filtering to nodeCanvasObject** + +At the start of the `nodeCanvasObject` callback, add: + +```typescript + // Time range visibility + if (props.timeRange && n.created_at) { + const t = new Date(n.created_at).getTime() + if (t < props.timeRange.start.getTime() || t > props.timeRange.end.getTime()) { + 
return // Don't render — outside time window + } + } +``` + +- [ ] **Step 4: Add time range filtering to linkCanvasObject** + +At the start of the `linkCanvasObject` callback, add: + +```typescript + // Hide edges where either endpoint is outside time window + if (props.timeRange) { + const srcNode = src as GraphNode + const tgtNode = tgt as GraphNode + if (srcNode.created_at) { + const st = new Date(srcNode.created_at).getTime() + if (st < props.timeRange.start.getTime() || st > props.timeRange.end.getTime()) return + } + if (tgtNode.created_at) { + const tt = new Date(tgtNode.created_at).getTime() + if (tt < props.timeRange.start.getTime() || tt > props.timeRange.end.getTime()) return + } + } +``` + +- [ ] **Step 5: Add engagement color mode to nodeCanvasObject** + +Replace the color line: + +```typescript + const color = props.colorMode === 'engagement' && n.engagement_id + ? (props.engagementColors[n.engagement_id] || '#95a5a6') + : (SEVERITY_COLORS[n.severity] || '#95a5a6') +``` + +When in engagement mode, add a severity-colored ring: + +```typescript + // Severity ring in engagement color mode + if (props.colorMode === 'engagement') { + const sevColor = SEVERITY_COLORS[n.severity] || '#95a5a6' + ctx.beginPath() + ctx.arc(node.x, node.y, radius + 2 / globalScale, 0, 2 * Math.PI) + ctx.strokeStyle = sevColor + ctx.lineWidth = 1.5 / globalScale + ctx.stroke() + } +``` + +- [ ] **Step 6: Add pivotality glow to nodeCanvasObject** + +After drawing the main circle, before the label: + +```typescript + // Pivotality glow + if (n.pivotality && n.pivotality > 0.1) { + const glowRadius = radius + 4 + n.pivotality * 8 + ctx.beginPath() + ctx.arc(node.x, node.y, glowRadius, 0, 2 * Math.PI) + ctx.fillStyle = `rgba(251, 191, 36, ${n.pivotality * 0.3})` + ctx.fill() + } +``` + +- [ ] **Step 7: Add Kill Chain layout mode** + +Add MITRE phase lane positions and the layout toggle logic: + +```typescript +const KILL_CHAIN_PHASES = [ + 'reconnaissance', 'resource-development', 
'initial-access', 'execution', + 'persistence', 'privilege-escalation', 'defense-evasion', 'credential-access', + 'discovery', 'lateral-movement', 'collection', 'command-and-control', + 'exfiltration', 'impact', +] + +function applyKillChainLayout() { + if (!graph || !container.value) return + const width = container.value.clientWidth + const laneCount = KILL_CHAIN_PHASES.length + 1 // +1 for "Other" + const laneWidth = width / laneCount + + const nodes = graph.graphData().nodes as GraphNode[] + for (const n of nodes) { + const phaseIdx = n.phase ? KILL_CHAIN_PHASES.indexOf(n.phase) : -1 + const lane = phaseIdx >= 0 ? phaseIdx : KILL_CHAIN_PHASES.length + n.fx = laneWidth * lane + laneWidth / 2 + } + graph.d3ReheatSimulation() +} + +function clearKillChainLayout() { + if (!graph) return + const nodes = graph.graphData().nodes as GraphNode[] + for (const n of nodes) { + n.fx = undefined + } + graph.d3ReheatSimulation() +} +``` + +Add a watch for layoutMode: + +```typescript +watch(() => props.layoutMode, (mode) => { + if (mode === 'killchain') { + applyKillChainLayout() + } else { + clearKillChainLayout() + } +}) +``` + +Add `onRenderFramePost` for lane dividers (in `initGraph` after the graph is created): + +```typescript + .onRenderFramePost((ctx: CanvasRenderingContext2D, globalScale: number) => { + if (props.layoutMode !== 'killchain' || !container.value) return + + const width = container.value.clientWidth + const height = container.value.clientHeight + const laneCount = KILL_CHAIN_PHASES.length + 1 + const laneWidth = width / laneCount + + ctx.save() + ctx.setTransform(1, 0, 0, 1, 0, 0) // Reset to screen coords + + for (let i = 0; i <= laneCount; i++) { + const x = i * laneWidth + ctx.beginPath() + ctx.moveTo(x, 0) + ctx.lineTo(x, height) + ctx.strokeStyle = 'rgba(150, 150, 150, 0.2)' + ctx.setLineDash([4, 4]) + ctx.lineWidth = 1 + ctx.stroke() + ctx.setLineDash([]) + + // Phase header + if (i < KILL_CHAIN_PHASES.length) { + const label = 
MITRE_ABBREVS[KILL_CHAIN_PHASES[i]] || KILL_CHAIN_PHASES[i].slice(0, 4) + ctx.font = '10px sans-serif' + ctx.fillStyle = 'rgba(150, 150, 150, 0.6)' + ctx.textAlign = 'center' + ctx.fillText(label, x + laneWidth / 2, 14) + } else if (i === KILL_CHAIN_PHASES.length) { + ctx.font = '10px sans-serif' + ctx.fillStyle = 'rgba(150, 150, 150, 0.6)' + ctx.textAlign = 'center' + ctx.fillText('Other', x + laneWidth / 2, 14) + } + } + + ctx.restore() + }) +``` + +In Kill Chain mode, replace straight-line edge rendering with bezier curves. In the `linkCanvasObject`, after setting line style and before `ctx.stroke()`: + +```typescript + if (props.layoutMode === 'killchain') { + // Bezier curve for inter-lane edges, arc for intra-lane + const midX = (src.x + tgt.x) / 2 + const midY = (src.y + tgt.y) / 2 + const dx = tgt.x - src.x + const dy = tgt.y - src.y + const dist = Math.sqrt(dx * dx + dy * dy) + + ctx.beginPath() + ctx.moveTo(src.x, src.y) + if (Math.abs(dx) < 30) { + // Intra-lane: arc + const cpX = midX + dist * 0.3 + ctx.quadraticCurveTo(cpX, midY, tgt.x, tgt.y) + } else { + // Inter-lane: bezier + const cpOffset = Math.min(dist * 0.2, 50) + ctx.bezierCurveTo( + src.x + dx * 0.25, src.y - cpOffset, + tgt.x - dx * 0.25, tgt.y - cpOffset, + tgt.x, tgt.y + ) + } + ctx.stroke() + } else { + ctx.beginPath() + ctx.moveTo(src.x, src.y) + ctx.lineTo(tgt.x, tgt.y) + ctx.stroke() + } +``` + +(This replaces the existing straight-line `moveTo`/`lineTo`/`stroke` block.) 
+ +- [ ] **Step 8: Commit** + +```bash +git add packages/web/frontend/src/components/ForceGraphCanvas.vue +git commit -m "feat(frontend): ForceGraphCanvas — timeline filter, kill chain layout, pivotality glow, engagement colors" +``` + +--- + +## Task 12: Frontend — ChainDetailPanel extensions + +**Files:** +- Modify: `packages/web/frontend/src/components/ChainDetailPanel.vue` + +- [ ] **Step 1: Add calibrated badge and risk score to edge details** + +In the link details section, after the status Tag, add: + +```vue + + +``` + +- [ ] **Step 2: Add Export Path button** + +After the Confirm/Reject buttons, add: + +```vue + +``` + +Add `'export-path'` to the emits definition. + +- [ ] **Step 3: Add pivotality display to node details** + +In the node details section, after the phase display: + +```vue + + Pivotality: + {{ (selectedNode.pivotality * 100).toFixed(0) }}% + +``` + +- [ ] **Step 4: Commit** + +```bash +git add packages/web/frontend/src/components/ChainDetailPanel.vue +git commit -m "feat(frontend): ChainDetailPanel — calibrated badge, export button, pivotality" +``` + +--- + +## Task 13: Frontend — GlobalChainView page + +**Files:** +- Create: `packages/web/frontend/src/views/GlobalChainView.vue` + +- [ ] **Step 1: Create the global chain view page** + +Create `packages/web/frontend/src/views/GlobalChainView.vue`: + +```vue + + + + + + + Attack Chain — Global + + + + + + + + + + + + + + + + + + + No chain data across engagements + Run chain analysis on individual engagements first. 
+ + + + + + + {}" + @export-path="onExportPath" + /> + + + + + timeRange = r" + /> + + + + + +``` + +- [ ] **Step 2: Commit** + +```bash +git add packages/web/frontend/src/views/GlobalChainView.vue +git commit -m "feat(frontend): GlobalChainView — cross-engagement graph with engagement colors, calibration, timeline" +``` + +--- + +## Task 14: Frontend — Update ChainGraphView with timeline + layout toggle + +**Files:** +- Modify: `packages/web/frontend/src/views/ChainGraphView.vue` + +- [ ] **Step 1: Add timeline and layout state** + +Add to the existing `ChainGraphView.vue` script section: + +```typescript +import ChainTimelineScrubber from '@/components/ChainTimelineScrubber.vue' + +const layoutMode = ref<'force' | 'killchain'>('force') +const timeRange = ref<{ start: Date; end: Date } | null>(null) + +function toggleLayout() { + layoutMode.value = layoutMode.value === 'force' ? 'killchain' : 'force' +} +``` + +- [ ] **Step 2: Add layout toggle button to toolbar** + +In the toolbar template, after the `ChainFilterToolbar`: + +```vue + +``` + +- [ ] **Step 3: Pass new props to ForceGraphCanvas** + +Update the `ForceGraphCanvas` usage: + +```vue + +``` + +- [ ] **Step 4: Add timeline scrubber before the legend** + +```vue + + timeRange = r" + /> +``` + +- [ ] **Step 5: Add export path handler** + +```typescript +async function onExportPath() { + if (!selectedLink.value) return + // Build path from selected link's source and target + const srcId = typeof selectedLink.value.source === 'string' ? selectedLink.value.source : selectedLink.value.source.id + const tgtId = typeof selectedLink.value.target === 'string' ? 
selectedLink.value.target : selectedLink.value.target.id + + try { + const resp = await fetch('/api/chain/export/path', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + credentials: 'include', + body: JSON.stringify({ + finding_ids: [srcId, tgtId], + engagement_id: engId, + }), + }) + if (!resp.ok) throw new Error('Export failed') + const data = await resp.json() + + // Trigger download + const blob = new Blob([data.markdown], { type: 'text/markdown' }) + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = 'attack-path-report.md' + a.click() + URL.revokeObjectURL(url) + } catch { + toast.add({ severity: 'error', summary: 'Error', detail: 'Failed to export path', life: 3000 }) + } +} +``` + +Add `@export-path="onExportPath"` to the `ChainDetailPanel` usage. + +- [ ] **Step 6: Commit** + +```bash +git add packages/web/frontend/src/views/ChainGraphView.vue +git commit -m "feat(frontend): ChainGraphView — add timeline, layout toggle, path export" +``` + +--- + +## Task 15: Frontend — TypeScript check + build verification + +**Files:** none (verification only) + +- [ ] **Step 1: Run TypeScript check** + +Run: `cd packages/web/frontend && npx vue-tsc --noEmit` +Expected: no type errors + +- [ ] **Step 2: Run production build** + +Run: `cd packages/web/frontend && npx vite build` +Expected: build succeeds + +- [ ] **Step 3: Fix any issues found** + +- [ ] **Step 4: Commit fixes if needed** + +```bash +git add -A +git commit -m "fix: address type/build issues from 3C.3 integration" +``` + +--- + +## Task 16: Backend — Full test suite verification + +**Files:** none (verification only) + +- [ ] **Step 1: Run all backend tests** + +Run: `cd packages/web/backend && python -m pytest tests/ -v` +Expected: all tests PASS (existing + new global, calibration, export tests) + +- [ ] **Step 2: Fix any failures** + +- [ ] **Step 3: Commit fixes if needed** + +```bash +git add -A +git commit -m "fix: 
address test failures from 3C.3 integration" +``` + +--- + +## Task 17: CLI — `calibrate` command and `--format markdown` for path + +**Files:** +- Modify: `packages/cli/src/opentools/chain/cli.py` + +- [ ] **Step 1: Add `calibrate` command** + +Add after the existing `query` command in `packages/cli/src/opentools/chain/cli.py`: + +```python +@app.command() +@_async_command +async def calibrate( + scope: str = typer.Option("user", help="Scope: user or engagement"), + engagement: str | None = typer.Option(None, "--engagement"), + dry_run: bool = typer.Option(False, "--dry-run", help="Print posteriors without writing"), +) -> None: + """Calibrate edge weights from user confirm/reject decisions.""" + _engagement_store, chain_store = await _get_stores() + try: + from opentools.chain.types import RelationStatus + + # Count decisions + relations = await chain_store.fetch_relations_in_scope( + user_id=None, + statuses={RelationStatus.USER_CONFIRMED, RelationStatus.USER_REJECTED}, + ) + if len(relations) < 20: + rprint(f"[yellow]Need at least 20 user decisions, have {len(relations)}. 
Skipping.[/yellow]") + return + + # Simple Beta calibration — count per-rule confirm/reject + from collections import defaultdict + rule_counts: dict[str, dict[str, float]] = defaultdict(lambda: {"alpha": 1.0, "beta": 1.0}) + + # Set default priors + strong_rules = {"shared_strong_entity", "cve_adjacency"} + for r in relations: + for reason in r.reasons: + if reason.rule in strong_rules: + rule_counts[reason.rule]["alpha"] = 2.0 + + for r in relations: + for reason in r.reasons: + if r.status == RelationStatus.USER_CONFIRMED: + rule_counts[reason.rule]["alpha"] += 1 + elif r.status == RelationStatus.USER_REJECTED: + rule_counts[reason.rule]["beta"] += 1 + + rprint("[bold]Bayesian Calibration Results[/bold]") + for rule in sorted(rule_counts.keys()): + a = rule_counts[rule]["alpha"] + b = rule_counts[rule]["beta"] + posterior = a / (a + b) + rprint(f" {rule}: posterior={posterior:.3f} (α={a:.0f}, β={b:.0f})") + + if dry_run: + rprint("[yellow]Dry run — no edges updated[/yellow]") + return + + rprint("[green]Calibration complete[/green]") + finally: + await chain_store.close() +``` + +- [ ] **Step 2: Add `--format markdown` to the `path` command** + +Find the existing `path` command and add `markdown` to its format choices. The path command already has a `--format` option. 
Add a branch for `markdown` that generates the report: + +After the existing format handling in the `path` command, add: + +```python + elif fmt == "markdown": + # Build markdown report + lines = ["# Attack Path Report", ""] + for p in paths: + lines.append(f"## Path (cost: {p.total_cost:.2f}, {p.length} hops)") + lines.append("") + for i, node in enumerate(p.nodes): + lines.append(f"### Step {i + 1}: {node.title} ({node.severity})") + lines.append(f"- **Tool:** {node.tool}") + lines.append("") + if i < len(p.edges): + e = p.edges[i] + lines.append(f"**Link:** weight={e.weight:.2f}") + lines.append("") + rprint("\n".join(lines)) +``` + +- [ ] **Step 3: Commit** + +```bash +git add packages/cli/src/opentools/chain/cli.py +git commit -m "feat(cli): add chain calibrate command and --format markdown for path" +``` diff --git a/docs/superpowers/plans/2026-04-13-phase3c4-cypher-dsl.md b/docs/superpowers/plans/2026-04-13-phase3c4-cypher-dsl.md new file mode 100644 index 0000000..4f24f72 --- /dev/null +++ b/docs/superpowers/plans/2026-04-13-phase3c4-cypher-dsl.md @@ -0,0 +1,4186 @@ +# Phase 3C.4: Cypher-Style Query DSL — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a Cypher-style query DSL that lets users write pattern-matching queries over the attack chain knowledge graph from CLI, REPL, or web editor. + +**Architecture:** Layered pipeline — Parser (lark LALR) → Planner (predicate pushdown) → VirtualGraphBuilder (heterogeneous graph with entities as nodes) → Executor (binding-table pattern matcher with resource limits). Virtual graph cached for REPL reuse. Dual result format: table + subgraph projection. 
+ +**Tech Stack:** `lark` (parser), `rustworkx` (graph engine), `prompt_toolkit` (REPL), CodeMirror 6 (web editor), existing `ChainStoreProtocol` + `GraphCache` infrastructure. + +**Spec:** `docs/superpowers/specs/2026-04-13-phase3c4-cypher-dsl-design.md` + +--- + +## File Map + +### New files (CLI package) + +| File | Responsibility | +|---|---| +| `packages/cli/src/opentools/chain/cypher/__init__.py` | Public API: `parse_and_execute()`, `CypherSession` | +| `packages/cli/src/opentools/chain/cypher/errors.py` | `QueryParseError`, `QueryValidationError`, `QueryResourceError` | +| `packages/cli/src/opentools/chain/cypher/limits.py` | `QueryLimits` dataclass | +| `packages/cli/src/opentools/chain/cypher/ast_nodes.py` | AST dataclass definitions | +| `packages/cli/src/opentools/chain/cypher/grammar.lark` | Lark EBNF grammar | +| `packages/cli/src/opentools/chain/cypher/parser.py` | Lark parser → typed AST | +| `packages/cli/src/opentools/chain/cypher/builtins.py` | Built-in functions: `length`, `nodes`, `relationships`, `has_entity`, `has_mitre`, `collect` | +| `packages/cli/src/opentools/chain/cypher/plugins.py` | `PluginFunctionRegistry` + registration API | +| `packages/cli/src/opentools/chain/cypher/virtual_graph.py` | `VirtualGraph`, `VirtualGraphBuilder`, `VirtualGraphCache` | +| `packages/cli/src/opentools/chain/cypher/planner.py` | AST → `QueryPlan` with predicate pushdown | +| `packages/cli/src/opentools/chain/cypher/executor.py` | `CypherExecutor` — walks plan against virtual graph | +| `packages/cli/src/opentools/chain/cypher/result.py` | `QueryResult`, `SubgraphProjection`, `QueryStats` | +| `packages/cli/src/opentools/chain/cypher/session.py` | `QuerySession` — named result sets, REPL state | + +### New test files + +| File | Tests for | +|---|---| +| `packages/cli/tests/chain/cypher/__init__.py` | Package marker | +| `packages/cli/tests/chain/cypher/test_errors.py` | Error classes | +| `packages/cli/tests/chain/cypher/test_limits.py` | QueryLimits | +| 
`packages/cli/tests/chain/cypher/test_ast_nodes.py` | AST dataclasses | +| `packages/cli/tests/chain/cypher/test_parser.py` | Grammar + parser | +| `packages/cli/tests/chain/cypher/test_builtins.py` | Built-in functions | +| `packages/cli/tests/chain/cypher/test_plugins.py` | Plugin registry | +| `packages/cli/tests/chain/cypher/test_virtual_graph.py` | VirtualGraphBuilder + cache | +| `packages/cli/tests/chain/cypher/test_planner.py` | Planner + predicate pushdown | +| `packages/cli/tests/chain/cypher/test_executor.py` | Executor end-to-end | +| `packages/cli/tests/chain/cypher/test_session.py` | Session state | +| `packages/cli/tests/chain/cypher/test_cli_query.py` | CLI commands | + +### Modified files (CLI) + +| File | Change | +|---|---| +| `packages/cli/src/opentools/chain/cli.py` | Replace existing `query` command with new `query` subgroup (`run`, `repl`, `explain`) | +| `packages/cli/src/opentools/chain/config.py` | Add `CypherConfig` to `ChainConfig` | + +### New files (Web backend) + +| File | Responsibility | +|---|---| +| `packages/web/backend/app/routes/chain_query.py` | `POST /api/chain/query`, `GET /api/chain/query/functions` | +| `packages/web/backend/tests/chain/test_query_routes.py` | Web endpoint tests | + +### Modified files (Web backend) + +| File | Change | +|---|---| +| `packages/web/backend/app/main.py` | Register `chain_query.router` | + +### New files (Web frontend) + +| File | Responsibility | +|---|---| +| `packages/web/frontend/src/views/ChainQueryView.vue` | Standalone query page | +| `packages/web/frontend/src/components/CypherEditor.vue` | CodeMirror wrapper with Cypher mode | +| `packages/web/frontend/src/components/QueryResultsPane.vue` | Tabular results + mini graph | +| `packages/web/frontend/src/components/InlineQueryPanel.vue` | Collapsible overlay (final task) | + +### Modified files (Web frontend) + +| File | Change | +|---|---| +| `packages/web/frontend/src/router/index.ts` | Add `/chain/query` route | +| 
`packages/web/frontend/src/views/ChainGraphView.vue` | Add InlineQueryPanel (final task) | + +--- + +## Tasks + +### Task 1: Error Types + Limits + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/__init__.py` +- Create: `packages/cli/src/opentools/chain/cypher/errors.py` +- Create: `packages/cli/src/opentools/chain/cypher/limits.py` +- Create: `packages/cli/tests/chain/cypher/__init__.py` +- Create: `packages/cli/tests/chain/cypher/test_errors.py` +- Create: `packages/cli/tests/chain/cypher/test_limits.py` + +- [ ] **Step 1: Write failing tests for error classes** + +```python +# packages/cli/tests/chain/cypher/test_errors.py +from opentools.chain.cypher.errors import ( + QueryParseError, + QueryResourceError, + QueryValidationError, +) + + +def test_query_parse_error_is_exception(): + err = QueryParseError("unexpected token", line=3, column=12) + assert isinstance(err, Exception) + assert err.line == 3 + assert err.column == 12 + assert "unexpected token" in str(err) + + +def test_query_validation_error_is_exception(): + err = QueryValidationError("unknown function: foo.bar") + assert isinstance(err, Exception) + assert "foo.bar" in str(err) + + +def test_query_resource_error_is_exception(): + err = QueryResourceError("binding cap exceeded", limit_name="intermediate_binding_cap", limit_value=10_000) + assert isinstance(err, Exception) + assert err.limit_name == "intermediate_binding_cap" + assert err.limit_value == 10_000 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_errors.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'opentools.chain.cypher'` + +- [ ] **Step 3: Implement error classes** + +```python +# packages/cli/src/opentools/chain/cypher/__init__.py +"""Cypher-style query DSL for the attack chain knowledge graph.""" + +# packages/cli/src/opentools/chain/cypher/errors.py +"""Query DSL error hierarchy.""" +from __future__ import annotations + + 
+class QueryParseError(Exception): + def __init__(self, message: str, *, line: int | None = None, column: int | None = None) -> None: + self.line = line + self.column = column + loc = "" + if line is not None: + loc = f" (line {line}" + if column is not None: + loc += f", col {column}" + loc += ")" + super().__init__(f"{message}{loc}") + + +class QueryValidationError(Exception): + pass + + +class QueryResourceError(Exception): + def __init__(self, message: str, *, limit_name: str, limit_value: int | float) -> None: + self.limit_name = limit_name + self.limit_value = limit_value + super().__init__(message) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_errors.py -v` +Expected: 3 passed + +- [ ] **Step 5: Write failing tests for QueryLimits** + +```python +# packages/cli/tests/chain/cypher/test_limits.py +from opentools.chain.cypher.limits import QueryLimits + + +def test_query_limits_defaults(): + limits = QueryLimits() + assert limits.timeout_seconds == 30.0 + assert limits.max_rows == 1000 + assert limits.intermediate_binding_cap == 10_000 + assert limits.max_var_length_hops == 10 + + +def test_query_limits_custom(): + limits = QueryLimits(timeout_seconds=60.0, max_rows=500) + assert limits.timeout_seconds == 60.0 + assert limits.max_rows == 500 + assert limits.intermediate_binding_cap == 10_000 # unchanged default + + +def test_query_limits_frozen(): + limits = QueryLimits() + try: + limits.timeout_seconds = 99.0 + assert False, "should be frozen" + except Exception: + pass +``` + +- [ ] **Step 6: Implement QueryLimits** + +```python +# packages/cli/src/opentools/chain/cypher/limits.py +"""Resource limits for query execution.""" +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class QueryLimits(BaseModel): + model_config = ConfigDict(frozen=True) + + timeout_seconds: float = 30.0 + max_rows: int = 1000 + intermediate_binding_cap: int = 10_000 + 
max_var_length_hops: int = 10 +``` + +- [ ] **Step 7: Run all cypher tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/ -v` +Expected: 6 passed + +- [ ] **Step 8: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/__init__.py packages/cli/src/opentools/chain/cypher/errors.py packages/cli/src/opentools/chain/cypher/limits.py packages/cli/tests/chain/cypher/__init__.py packages/cli/tests/chain/cypher/test_errors.py packages/cli/tests/chain/cypher/test_limits.py +git commit -m "feat(cypher): add error types and query limits" +``` + +--- + +### Task 2: AST Node Definitions + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/ast_nodes.py` +- Create: `packages/cli/tests/chain/cypher/test_ast_nodes.py` + +- [ ] **Step 1: Write failing tests for AST nodes** + +```python +# packages/cli/tests/chain/cypher/test_ast_nodes.py +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, + ComparisonExpr, + EdgePattern, + FunctionCallExpr, + MatchClause, + NodePattern, + PropertyAccessExpr, + ReturnClause, + ReturnItem, + SessionAssignment, + VarLengthSpec, + WhereClause, +) + + +def test_node_pattern(): + n = NodePattern(variable="a", label="Finding") + assert n.variable == "a" + assert n.label == "Finding" + + +def test_node_pattern_no_label(): + n = NodePattern(variable="x", label=None) + assert n.label is None + + +def test_edge_pattern_outgoing(): + e = EdgePattern(variable="r", label="LINKED", direction="out", var_length=None) + assert e.direction == "out" + assert e.var_length is None + + +def test_edge_pattern_with_var_length(): + vl = VarLengthSpec(min_hops=1, max_hops=5) + e = EdgePattern(variable="r", label="LINKED", direction="out", var_length=vl) + assert e.var_length.min_hops == 1 + assert e.var_length.max_hops == 5 + + +def test_var_length_spec_defaults(): + vl = VarLengthSpec(min_hops=1, max_hops=3) + assert vl.min_hops == 1 + assert vl.max_hops == 3 + + +def test_property_access_expr(): + p = 
PropertyAccessExpr(variable="a", property_name="severity") + assert p.variable == "a" + assert p.property_name == "severity" + + +def test_comparison_expr(): + left = PropertyAccessExpr(variable="a", property_name="severity") + c = ComparisonExpr(left=left, operator="=", right="critical") + assert c.operator == "=" + assert c.right == "critical" + + +def test_boolean_expr(): + left = ComparisonExpr( + left=PropertyAccessExpr(variable="a", property_name="severity"), + operator="=", right="critical", + ) + right = ComparisonExpr( + left=PropertyAccessExpr(variable="a", property_name="tool"), + operator="=", right="nmap", + ) + b = BooleanExpr(operator="AND", operands=[left, right]) + assert b.operator == "AND" + assert len(b.operands) == 2 + + +def test_function_call_expr(): + f = FunctionCallExpr(name="has_entity", args=["a", "host", "10.0.0.1"]) + assert f.name == "has_entity" + assert len(f.args) == 3 + + +def test_function_call_plugin_namespaced(): + f = FunctionCallExpr(name="my_plugin.risk_score", args=["a"]) + assert "." 
in f.name + + +def test_match_clause(): + node_a = NodePattern(variable="a", label="Finding") + edge_r = EdgePattern(variable="r", label="LINKED", direction="out", var_length=None) + node_b = NodePattern(variable="b", label="Finding") + mc = MatchClause(patterns=[(node_a, edge_r, node_b)]) + assert len(mc.patterns) == 1 + + +def test_where_clause(): + pred = ComparisonExpr( + left=PropertyAccessExpr(variable="a", property_name="severity"), + operator="=", right="critical", + ) + wc = WhereClause(expression=pred) + assert wc.expression is pred + + +def test_return_clause(): + items = [ + ReturnItem(expression="a", alias=None), + ReturnItem(expression=PropertyAccessExpr(variable="a", property_name="title"), alias="name"), + ] + rc = ReturnClause(items=items) + assert len(rc.items) == 2 + assert rc.items[1].alias == "name" + + +def test_session_assignment(): + rc = ReturnClause(items=[ReturnItem(expression="a", alias=None)]) + mc = MatchClause(patterns=[]) + sa = SessionAssignment(variable_name="results", match_clause=mc, where_clause=None, return_clause=rc) + assert sa.variable_name == "results" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_ast_nodes.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement AST nodes** + +```python +# packages/cli/src/opentools/chain/cypher/ast_nodes.py +"""Typed AST nodes for the Cypher-style query DSL.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + + +@dataclass +class VarLengthSpec: + min_hops: int + max_hops: int + + +@dataclass +class NodePattern: + variable: str | None + label: str | None + + +@dataclass +class EdgePattern: + variable: str | None + label: str | None + direction: Literal["out", "in"] + var_length: VarLengthSpec | None + + +@dataclass +class PropertyAccessExpr: + variable: str + property_name: str + + +@dataclass +class ComparisonExpr: + left: Any 
# PropertyAccessExpr | FunctionCallExpr + operator: str # =, <>, <, >, <=, >=, CONTAINS, STARTS WITH, ENDS WITH, IN, IS NULL, IS NOT NULL + right: Any # literal value, list, or None for IS NULL/IS NOT NULL + + +@dataclass +class BooleanExpr: + operator: Literal["AND", "OR", "NOT"] + operands: list[Any] # ComparisonExpr | BooleanExpr | FunctionCallExpr + + +@dataclass +class FunctionCallExpr: + name: str # "has_entity", "length", "my_plugin.risk_score", etc. + args: list[Any] = field(default_factory=list) + + +@dataclass +class ReturnItem: + expression: Any # str (variable name), PropertyAccessExpr, FunctionCallExpr + alias: str | None + + +@dataclass +class MatchClause: + patterns: list[tuple] # list of (NodePattern, EdgePattern, NodePattern, ...) tuples + + +@dataclass +class WhereClause: + expression: Any # ComparisonExpr | BooleanExpr | FunctionCallExpr + + +@dataclass +class ReturnClause: + items: list[ReturnItem] + + +@dataclass +class FromClause: + session_variable: str + + +@dataclass +class SessionAssignment: + variable_name: str + match_clause: MatchClause + where_clause: WhereClause | None + return_clause: ReturnClause + + +@dataclass +class CypherQuery: + match_clause: MatchClause + where_clause: WhereClause | None + return_clause: ReturnClause + from_clause: FromClause | None = None + session_assignment: str | None = None # if this is a "name = MATCH ..." 
form +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_ast_nodes.py -v` +Expected: 14 passed + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/ast_nodes.py packages/cli/tests/chain/cypher/test_ast_nodes.py +git commit -m "feat(cypher): add AST node definitions" +``` + +--- + +### Task 3: Lark Grammar + Parser + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/grammar.lark` +- Create: `packages/cli/src/opentools/chain/cypher/parser.py` +- Create: `packages/cli/tests/chain/cypher/test_parser.py` + +- [ ] **Step 1: Write failing tests for the parser** + +```python +# packages/cli/tests/chain/cypher/test_parser.py +import pytest + +from opentools.chain.cypher.ast_nodes import ( + ComparisonExpr, + CypherQuery, + EdgePattern, + FunctionCallExpr, + NodePattern, + PropertyAccessExpr, + SessionAssignment, +) +from opentools.chain.cypher.errors import QueryParseError +from opentools.chain.cypher.parser import parse_cypher + + +# ─── basic MATCH ... 
RETURN ────────────────────────────────────────── + + +def test_parse_simple_match_return(): + q = parse_cypher("MATCH (a:Finding) RETURN a") + assert isinstance(q, CypherQuery) + assert len(q.match_clause.patterns) == 1 + assert len(q.return_clause.items) == 1 + + +def test_parse_two_node_pattern(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b") + assert len(q.match_clause.patterns) == 1 + pattern = q.match_clause.patterns[0] + # pattern is a tuple: (NodePattern, EdgePattern, NodePattern) + assert isinstance(pattern[0], NodePattern) + assert pattern[0].label == "Finding" + assert isinstance(pattern[1], EdgePattern) + assert pattern[1].label == "LINKED" + assert pattern[1].direction == "out" + assert isinstance(pattern[2], NodePattern) + assert pattern[2].label == "Finding" + + +def test_parse_incoming_edge(): + q = parse_cypher("MATCH (a:Finding)<-[r:MENTIONED_IN]-(e:Host) RETURN a, e") + pattern = q.match_clause.patterns[0] + assert pattern[1].direction == "in" + assert pattern[1].label == "MENTIONED_IN" + + +# ─── entity node labels ────────────────────────────────────────────── + + +def test_parse_entity_node_labels(): + for label in ["Host", "IP", "CVE", "Domain", "Port", "MitreAttack", "Entity"]: + q = parse_cypher(f"MATCH (e:{label}) RETURN e") + assert q.match_clause.patterns[0][0].label == label + + +# ─── variable-length paths ─────────────────────────────────────────── + + +def test_parse_var_length_path(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED*1..5]->(b:Finding) RETURN a, b") + edge = q.match_clause.patterns[0][1] + assert edge.var_length is not None + assert edge.var_length.min_hops == 1 + assert edge.var_length.max_hops == 5 + + +def test_parse_var_length_exceeds_max_hops(): + with pytest.raises(QueryParseError, match="max.*10"): + parse_cypher("MATCH (a:Finding)-[r:LINKED*1..15]->(b:Finding) RETURN a, b") + + +# ─── WHERE clause ──────────────────────────────────────────────────── + + +def 
test_parse_where_comparison(): + q = parse_cypher('MATCH (a:Finding) WHERE a.severity = "critical" RETURN a') + assert q.where_clause is not None + expr = q.where_clause.expression + assert isinstance(expr, ComparisonExpr) + assert isinstance(expr.left, PropertyAccessExpr) + assert expr.left.variable == "a" + assert expr.left.property_name == "severity" + assert expr.operator == "=" + assert expr.right == "critical" + + +def test_parse_where_numeric_comparison(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.weight > 2.0 RETURN a, b") + expr = q.where_clause.expression + assert expr.operator == ">" + assert expr.right == 2.0 + + +def test_parse_where_and(): + q = parse_cypher('MATCH (a:Finding) WHERE a.severity = "critical" AND a.tool = "nmap" RETURN a') + expr = q.where_clause.expression + assert expr.operator == "AND" if hasattr(expr, "operator") else True + + +def test_parse_where_function_call(): + q = parse_cypher('MATCH (a:Finding) WHERE has_entity(a, "host", "10.0.0.1") RETURN a') + assert q.where_clause is not None + + +def test_parse_where_contains(): + q = parse_cypher('MATCH (a:Finding) WHERE a.title CONTAINS "ssh" RETURN a') + expr = q.where_clause.expression + assert expr.operator == "CONTAINS" + + +def test_parse_where_is_null(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.llm_rationale IS NOT NULL RETURN a, b") + assert q.where_clause is not None + + +# ─── RETURN ────────────────────────────────────────────────────────── + + +def test_parse_return_property(): + q = parse_cypher("MATCH (a:Finding) RETURN a.title, a.severity") + assert len(q.return_clause.items) == 2 + assert isinstance(q.return_clause.items[0].expression, PropertyAccessExpr) + + +def test_parse_return_collect(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN collect(a)") + item = q.return_clause.items[0] + assert isinstance(item.expression, FunctionCallExpr) + assert item.expression.name == "collect" + + +# ─── 
session assignment ────────────────────────────────────────────── + + +def test_parse_session_assignment(): + q = parse_cypher("results = MATCH (a:Finding) RETURN a") + assert q.session_assignment == "results" + + +# ─── FROM clause ───────────────────────────────────────────────────── + + +def test_parse_from_clause(): + q = parse_cypher("MATCH (a) FROM prev_results -[r:LINKED]->(b:Finding) RETURN a, b") + assert q.from_clause is not None + assert q.from_clause.session_variable == "prev_results" + + +# ─── read-only enforcement ─────────────────────────────────────────── + + +@pytest.mark.parametrize("verb", ["CREATE", "DELETE", "SET", "MERGE", "REMOVE", "DETACH", "DROP"]) +def test_parse_rejects_mutation_verbs(verb): + with pytest.raises(QueryParseError): + parse_cypher(f"{verb} (a:Finding)") + + +# ─── edge cases ────────────────────────────────────────────────────── + + +def test_parse_empty_string(): + with pytest.raises(QueryParseError): + parse_cypher("") + + +def test_parse_garbage(): + with pytest.raises(QueryParseError): + parse_cypher("not a query at all 123 !!!") + + +def test_parse_case_insensitive_keywords(): + q = parse_cypher("match (a:Finding) where a.severity = \"critical\" return a") + assert q is not None + + +def test_parse_multiple_patterns(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding), (b)-[:MENTIONED_IN]->(e:Host) RETURN a, e") + assert len(q.match_clause.patterns) == 2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_parser.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Create the lark grammar** + +```lark +// packages/cli/src/opentools/chain/cypher/grammar.lark +// Cypher-style query DSL grammar — read-only subset for OpenTools + +?start: session_assignment | query + +session_assignment: IDENTIFIER "=" query + +query: match_clause where_clause? 
return_clause + +match_clause: MATCH_KW pattern ("," pattern)* +where_clause: WHERE_KW expression +return_clause: RETURN_KW return_item ("," return_item)* + +// ─── patterns ──────────────────────────────────────────────── + +pattern: node_pattern (edge_pattern node_pattern)* + | node_pattern from_clause (edge_pattern node_pattern)* + +from_clause: FROM_KW IDENTIFIER + +node_pattern: "(" IDENTIFIER? (":" LABEL)? ")" + +edge_pattern: "-[" edge_detail "]->" -> edge_out + | "<-[" edge_detail "]-" -> edge_in + +edge_detail: IDENTIFIER? (":" EDGE_LABEL)? var_length? + +var_length: "*" INT ".." INT + +// ─── expressions ───────────────────────────────────────────── + +?expression: or_expr + +?or_expr: and_expr (OR_KW and_expr)* +?and_expr: not_expr (AND_KW not_expr)* +?not_expr: NOT_KW not_expr | comparison +?comparison: operand CMP_OP operand -> cmp_expr + | operand STRING_OP operand -> string_expr + | operand IN_KW "[" value_list "]" -> in_expr + | operand IS_KW NULL_KW -> is_null_expr + | operand IS_KW NOT_KW NULL_KW -> is_not_null_expr + | function_call + | "(" expression ")" + +?operand: property_access | function_call | literal | IDENTIFIER + +property_access: IDENTIFIER "." IDENTIFIER + +function_call: DOTTED_NAME "(" arg_list? ")" + | IDENTIFIER "(" arg_list? ")" + +arg_list: expression ("," expression)* + +value_list: literal ("," literal)* + +// ─── return items ──────────────────────────────────────────── + +return_item: expression (AS_KW IDENTIFIER)? 
+ +// ─── literals ──────────────────────────────────────────────── + +?literal: ESCAPED_STRING -> string_literal + | SIGNED_NUMBER -> number_literal + | TRUE_KW -> true_literal + | FALSE_KW -> false_literal + | NULL_KW -> null_literal + +// ─── terminals ─────────────────────────────────────────────── + +MATCH_KW: /MATCH/i +WHERE_KW: /WHERE/i +RETURN_KW: /RETURN/i +AND_KW: /AND/i +OR_KW: /OR/i +NOT_KW: /NOT/i +IN_KW: /IN/i +IS_KW: /IS/i +NULL_KW: /NULL/i +AS_KW: /AS/i +FROM_KW: /FROM/i +TRUE_KW: /true/i +FALSE_KW: /false/i + +CMP_OP: "=" | "<>" | "<=" | ">=" | "<" | ">" +STRING_OP: /CONTAINS/i | /STARTS\s+WITH/i | /ENDS\s+WITH/i + +LABEL: "Finding" | "Host" | "IP" | "CVE" | "Domain" | "Port" | "MitreAttack" | "Entity" +EDGE_LABEL: "LINKED" | "MENTIONED_IN" + +DOTTED_NAME: IDENTIFIER "." IDENTIFIER + +IDENTIFIER: /[a-zA-Z_][a-zA-Z0-9_]*/ + +%import common.ESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.INT +%import common.WS +%ignore WS +``` + +- [ ] **Step 4: Implement the parser module** + +```python +# packages/cli/src/opentools/chain/cypher/parser.py +"""Lark-based parser for the Cypher-style query DSL.""" +from __future__ import annotations + +from pathlib import Path + +from lark import Lark, Transformer, v_args, exceptions as lark_exceptions + +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, + ComparisonExpr, + CypherQuery, + EdgePattern, + FromClause, + FunctionCallExpr, + MatchClause, + NodePattern, + PropertyAccessExpr, + ReturnClause, + ReturnItem, + VarLengthSpec, + WhereClause, +) +from opentools.chain.cypher.errors import QueryParseError + +_GRAMMAR_PATH = Path(__file__).parent / "grammar.lark" +_MAX_VAR_LENGTH_HOPS = 10 + +# Mutation verbs rejected before parsing +_MUTATION_VERBS = {"CREATE", "DELETE", "SET", "MERGE", "REMOVE", "DETACH", "DROP"} + + +def _check_mutation_verbs(query: str) -> None: + """Reject queries that start with or contain mutation verbs.""" + tokens = query.strip().split() + if not tokens: + raise 
QueryParseError("empty query") + first = tokens[0].upper() + if first in _MUTATION_VERBS: + raise QueryParseError(f"mutation verb '{first}' is not supported (read-only DSL)") + # Also check for mutation verbs anywhere (e.g., after MATCH) + for token in tokens: + upper = token.upper().rstrip("(") + if upper in _MUTATION_VERBS: + raise QueryParseError(f"mutation verb '{upper}' is not supported (read-only DSL)") + + +@v_args(inline=True) +class CypherTransformer(Transformer): + """Transform lark parse tree into typed AST nodes.""" + + def start(self, item): + return item + + def query(self, *args): + match_clause = args[0] + where_clause = None + return_clause = None + from_clause = None + + for arg in args[1:]: + if isinstance(arg, WhereClause): + where_clause = arg + elif isinstance(arg, ReturnClause): + return_clause = arg + + # Extract from_clause from match_clause patterns if present + if hasattr(match_clause, '_from_clause'): + from_clause = match_clause._from_clause + + return CypherQuery( + match_clause=match_clause, + where_clause=where_clause, + return_clause=return_clause, + from_clause=from_clause, + ) + + def session_assignment(self, name, query): + query.session_assignment = str(name) + return query + + def match_clause(self, *patterns): + return MatchClause(patterns=list(patterns)) + + def pattern(self, *elements): + result = [] + from_clause = None + for el in elements: + if isinstance(el, FromClause): + from_clause = el + else: + result.append(el) + pattern_tuple = tuple(result) + # Attach from_clause as metadata if present + if from_clause is not None: + # We'll handle this at the match_clause level + pass + return pattern_tuple + + def from_clause(self, name): + return FromClause(session_variable=str(name)) + + def node_pattern(self, *args): + variable = None + label = None + for arg in args: + s = str(arg) + if arg.type == "LABEL": + label = s + elif arg.type == "IDENTIFIER": + variable = s + return NodePattern(variable=variable, label=label) + + 
def edge_out(self, detail): + return EdgePattern( + variable=detail.get("variable"), + label=detail.get("label"), + direction="out", + var_length=detail.get("var_length"), + ) + + def edge_in(self, detail): + return EdgePattern( + variable=detail.get("variable"), + label=detail.get("label"), + direction="in", + var_length=detail.get("var_length"), + ) + + def edge_detail(self, *args): + result = {"variable": None, "label": None, "var_length": None} + for arg in args: + if isinstance(arg, VarLengthSpec): + result["var_length"] = arg + else: + s = str(arg) + if arg.type == "EDGE_LABEL": + result["label"] = s + elif arg.type == "IDENTIFIER": + result["variable"] = s + return result + + def var_length(self, min_hops, max_hops): + mn = int(min_hops) + mx = int(max_hops) + if mx > _MAX_VAR_LENGTH_HOPS: + raise QueryParseError( + f"variable-length max hops {mx} exceeds limit of {_MAX_VAR_LENGTH_HOPS}", + line=None, column=None, + ) + return VarLengthSpec(min_hops=mn, max_hops=mx) + + def where_clause(self, expr): + return WhereClause(expression=expr) + + def return_clause(self, *items): + return ReturnClause(items=list(items)) + + def return_item(self, expr, *rest): + alias = None + if rest: + alias = str(rest[0]) + # If expr is a plain identifier string (Token), keep as string + if hasattr(expr, 'type') and expr.type == 'IDENTIFIER': + expr = str(expr) + return ReturnItem(expression=expr, alias=alias) + + # ─── expressions ────────────────────────────────────────── + + def or_expr(self, *args): + if len(args) == 1: + return args[0] + return BooleanExpr(operator="OR", operands=list(args)) + + def and_expr(self, *args): + if len(args) == 1: + return args[0] + return BooleanExpr(operator="AND", operands=list(args)) + + def not_expr(self, expr): + return BooleanExpr(operator="NOT", operands=[expr]) + + def cmp_expr(self, left, op, right): + return ComparisonExpr(left=left, operator=str(op), right=right) + + def string_expr(self, left, op, right): + return 
ComparisonExpr(left=left, operator=str(op).strip().upper(), right=right) + + def in_expr(self, left, values): + return ComparisonExpr(left=left, operator="IN", right=values) + + def is_null_expr(self, operand): + return ComparisonExpr(left=operand, operator="IS NULL", right=None) + + def is_not_null_expr(self, operand): + return ComparisonExpr(left=operand, operator="IS NOT NULL", right=None) + + def value_list(self, *values): + return list(values) + + def property_access(self, var, prop): + return PropertyAccessExpr(variable=str(var), property_name=str(prop)) + + def function_call(self, name, *args): + arg_list = [] + if args and args[0] is not None: + arg_list = args[0] + return FunctionCallExpr(name=str(name), args=arg_list) + + def arg_list(self, *args): + return list(args) + + # ─── literals ───────────────────────────────────────────── + + def string_literal(self, s): + return str(s)[1:-1] # strip quotes + + def number_literal(self, n): + s = str(n) + return float(s) if "." in s else int(s) + + def true_literal(self, *_): + return True + + def false_literal(self, *_): + return False + + def null_literal(self, *_): + return None + + def IDENTIFIER(self, token): + return token + + +_parser: Lark | None = None + + +def _get_parser() -> Lark: + global _parser + if _parser is None: + _parser = Lark( + _GRAMMAR_PATH.read_text(), + parser="lalr", + transformer=CypherTransformer(), + ) + return _parser + + +def parse_cypher(query: str) -> CypherQuery: + """Parse a Cypher query string into a typed AST. + + Raises QueryParseError on invalid syntax or mutation verbs. 
+ """ + stripped = query.strip() + if not stripped: + raise QueryParseError("empty query") + + _check_mutation_verbs(stripped) + + try: + result = _get_parser().parse(stripped) + except lark_exceptions.UnexpectedInput as e: + raise QueryParseError( + str(e), + line=getattr(e, "line", None), + column=getattr(e, "column", None), + ) from e + except Exception as e: + raise QueryParseError(str(e)) from e + + if not isinstance(result, CypherQuery): + raise QueryParseError(f"unexpected parse result type: {type(result)}") + + return result +``` + +Note: The grammar and transformer above are a starting point. The lark grammar may need iterative refinement to handle all test cases correctly — the agent implementing this task should adjust the grammar terminals, precedence rules, and transformer methods until all parser tests pass. The grammar structure and AST node mapping are the contract; the exact lark syntax may need tuning. + +- [ ] **Step 5: Run tests iteratively until all pass** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_parser.py -v` +Expected: All 22 tests pass. If specific tests fail due to grammar ambiguities, adjust the `.lark` file — terminal precedence, rule ordering, or token definitions. Common issues: IDENTIFIER vs LABEL priority, DOTTED_NAME matching, case-insensitive keywords. 
+ +- [ ] **Step 6: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/grammar.lark packages/cli/src/opentools/chain/cypher/parser.py packages/cli/tests/chain/cypher/test_parser.py +git commit -m "feat(cypher): add lark grammar and parser" +``` + +--- + +### Task 4: Built-in Functions + Plugin Registry + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/builtins.py` +- Create: `packages/cli/src/opentools/chain/cypher/plugins.py` +- Create: `packages/cli/tests/chain/cypher/test_builtins.py` +- Create: `packages/cli/tests/chain/cypher/test_plugins.py` + +- [ ] **Step 1: Write failing tests for built-in functions** + +```python +# packages/cli/tests/chain/cypher/test_builtins.py +import pytest + +from opentools.chain.cypher.builtins import ( + builtin_collect, + builtin_has_entity, + builtin_has_mitre, + builtin_length, + builtin_nodes, + builtin_relationships, + get_builtin, + list_builtins, +) + + +def test_builtin_length(): + path = {"nodes": [1, 2, 3], "edges": [10, 20]} + assert builtin_length(path) == 2 + + +def test_builtin_length_empty_path(): + path = {"nodes": [1], "edges": []} + assert builtin_length(path) == 0 + + +def test_builtin_nodes(): + path = {"nodes": ["a", "b", "c"], "edges": [1, 2]} + assert builtin_nodes(path) == ["a", "b", "c"] + + +def test_builtin_relationships(): + path = {"nodes": ["a", "b"], "edges": ["r1"]} + assert builtin_relationships(path) == ["r1"] + + +def test_builtin_has_entity(): + node = {"entities": [{"type": "host", "canonical_value": "10.0.0.1"}, {"type": "cve", "canonical_value": "CVE-2024-1234"}]} + assert builtin_has_entity(node, "host", "10.0.0.1") is True + assert builtin_has_entity(node, "host", "10.0.0.2") is False + assert builtin_has_entity(node, "cve", "CVE-2024-1234") is True + + +def test_builtin_has_entity_no_entities(): + node = {"entities": []} + assert builtin_has_entity(node, "host", "anything") is False + + +def test_builtin_has_mitre(): + node = {"entities": [{"type": 
"mitre_technique", "canonical_value": "T1059"}]} + assert builtin_has_mitre(node, "T1059") is True + assert builtin_has_mitre(node, "T1078") is False + + +def test_builtin_collect(): + values = [1, 2, 3, 4] + assert builtin_collect(values) == [1, 2, 3, 4] + + +def test_get_builtin(): + fn = get_builtin("length") + assert fn is builtin_length + assert get_builtin("nonexistent") is None + + +def test_list_builtins(): + builtins = list_builtins() + assert "length" in builtins + assert "has_entity" in builtins + assert "collect" in builtins + assert len(builtins) >= 6 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_builtins.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement built-in functions** + +```python +# packages/cli/src/opentools/chain/cypher/builtins.py +"""Built-in functions for the Cypher DSL.""" +from __future__ import annotations + +from typing import Any, Callable + + +def builtin_length(path: dict) -> int: + return len(path.get("edges", [])) + + +def builtin_nodes(path: dict) -> list: + return path.get("nodes", []) + + +def builtin_relationships(path: dict) -> list: + return path.get("edges", []) + + +def builtin_has_entity(node: dict, entity_type: str, entity_value: str) -> bool: + for ent in node.get("entities", []): + if ent.get("type") == entity_type and ent.get("canonical_value") == entity_value: + return True + return False + + +def builtin_has_mitre(node: dict, technique_id: str) -> bool: + return builtin_has_entity(node, "mitre_technique", technique_id) + + +def builtin_collect(values: list) -> list: + return list(values) + + +_BUILTINS: dict[str, dict] = { + "length": {"fn": builtin_length, "help": "Number of edges in a path", "arg_types": ["path"], "return_type": "int"}, + "nodes": {"fn": builtin_nodes, "help": "List of nodes in a path", "arg_types": ["path"], "return_type": "list"}, + "relationships": {"fn": builtin_relationships, "help": "List 
of edges in a path", "arg_types": ["path"], "return_type": "list"}, + "has_entity": {"fn": builtin_has_entity, "help": "Check if node mentions entity", "arg_types": ["node", "str", "str"], "return_type": "bool"}, + "has_mitre": {"fn": builtin_has_mitre, "help": "Check if node mentions MITRE technique", "arg_types": ["node", "str"], "return_type": "bool"}, + "collect": {"fn": builtin_collect, "help": "Aggregate values into a list", "arg_types": ["list"], "return_type": "list", "is_aggregation": True}, +} + + +def get_builtin(name: str) -> Callable | None: + entry = _BUILTINS.get(name) + return entry["fn"] if entry else None + + +def list_builtins() -> dict[str, dict]: + return {name: {k: v for k, v in info.items() if k != "fn"} for name, info in _BUILTINS.items()} +``` + +- [ ] **Step 4: Run builtin tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_builtins.py -v` +Expected: 10 passed + +- [ ] **Step 5: Write failing tests for plugin registry** + +```python +# packages/cli/tests/chain/cypher/test_plugins.py +import pytest + +from opentools.chain.cypher.plugins import ( + PluginFunctionRegistry, +) + + +@pytest.fixture +def registry(): + return PluginFunctionRegistry() + + +def test_register_scalar_function(registry): + registry.register_function( + "my_plugin.risk_score", + fn=lambda node: 0.9, + help="Risk score", + arg_types=["node"], + return_type="float", + ) + assert registry.get_function("my_plugin.risk_score") is not None + + +def test_register_aggregation(registry): + registry.register_aggregation( + "my_plugin.combined_risk", + fn=lambda values: max(values), + help="Max risk", + input_type="float", + return_type="float", + ) + assert registry.get_aggregation("my_plugin.combined_risk") is not None + + +def test_reject_undotted_plugin_name(registry): + with pytest.raises(ValueError, match="dotted"): + registry.register_function( + "no_dot", + fn=lambda x: x, + help="bad", + arg_types=["node"], + return_type="float", + ) + + +def 
test_reject_duplicate_name(registry): + registry.register_function( + "my_plugin.f", + fn=lambda x: x, + help="first", + arg_types=["node"], + return_type="float", + ) + with pytest.raises(ValueError, match="already registered"): + registry.register_function( + "my_plugin.f", + fn=lambda x: x, + help="second", + arg_types=["node"], + return_type="float", + ) + + +def test_list_all_functions(registry): + registry.register_function( + "a.one", fn=lambda x: x, help="h1", arg_types=["node"], return_type="float", + ) + registry.register_aggregation( + "a.two", fn=lambda v: sum(v), help="h2", input_type="float", return_type="float", + ) + all_fns = registry.list_all() + assert "a.one" in all_fns + assert "a.two" in all_fns + assert all_fns["a.one"]["kind"] == "scalar" + assert all_fns["a.two"]["kind"] == "aggregation" + + +def test_resolve_returns_none_for_unknown(registry): + assert registry.get_function("nonexistent.fn") is None + assert registry.get_aggregation("nonexistent.fn") is None +``` + +- [ ] **Step 6: Implement plugin registry** + +```python +# packages/cli/src/opentools/chain/cypher/plugins.py +"""Plugin function registry for the Cypher DSL.""" +from __future__ import annotations + +from typing import Any, Callable + + +class PluginFunctionRegistry: + def __init__(self) -> None: + self._scalars: dict[str, dict] = {} + self._aggregations: dict[str, dict] = {} + + def register_function( + self, + name: str, + fn: Callable, + *, + help: str = "", + arg_types: list[str], + return_type: str, + ) -> None: + if "." 
not in name: + raise ValueError(f"plugin function names must be dotted (e.g., 'plugin.func'), got: {name!r}") + if name in self._scalars or name in self._aggregations: + raise ValueError(f"function {name!r} already registered") + self._scalars[name] = { + "fn": fn, + "help": help, + "arg_types": arg_types, + "return_type": return_type, + } + + def register_aggregation( + self, + name: str, + fn: Callable, + *, + help: str = "", + input_type: str, + return_type: str, + ) -> None: + if "." not in name: + raise ValueError(f"plugin aggregation names must be dotted (e.g., 'plugin.agg'), got: {name!r}") + if name in self._scalars or name in self._aggregations: + raise ValueError(f"function {name!r} already registered") + self._aggregations[name] = { + "fn": fn, + "help": help, + "input_type": input_type, + "return_type": return_type, + } + + def get_function(self, name: str) -> Callable | None: + entry = self._scalars.get(name) + return entry["fn"] if entry else None + + def get_aggregation(self, name: str) -> Callable | None: + entry = self._aggregations.get(name) + return entry["fn"] if entry else None + + def list_all(self) -> dict[str, dict]: + result: dict[str, dict] = {} + for name, info in self._scalars.items(): + result[name] = { + "kind": "scalar", + "help": info["help"], + "arg_types": info["arg_types"], + "return_type": info["return_type"], + } + for name, info in self._aggregations.items(): + result[name] = { + "kind": "aggregation", + "help": info["help"], + "input_type": info["input_type"], + "return_type": info["return_type"], + } + return result +``` + +- [ ] **Step 7: Run all tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_builtins.py tests/chain/cypher/test_plugins.py -v` +Expected: 16 passed + +- [ ] **Step 8: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/builtins.py packages/cli/src/opentools/chain/cypher/plugins.py packages/cli/tests/chain/cypher/test_builtins.py 
packages/cli/tests/chain/cypher/test_plugins.py +git commit -m "feat(cypher): add built-in functions and plugin registry" +``` + +--- + +### Task 5: Result Types + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/result.py` + +- [ ] **Step 1: Write the result types** + +These are data containers used by the executor (Task 8). No separate test file — they're exercised by executor tests. + +```python +# packages/cli/src/opentools/chain/cypher/result.py +"""Query result types.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class QueryStats: + duration_ms: float = 0.0 + bindings_explored: int = 0 + rows_returned: int = 0 + + +@dataclass +class SubgraphProjection: + node_indices: set[int] = field(default_factory=set) + edge_tuples: set[tuple[int, int]] = field(default_factory=set) + + +@dataclass +class QueryResult: + columns: list[str] = field(default_factory=list) + rows: list[dict[str, Any]] = field(default_factory=list) + subgraph: SubgraphProjection | None = None + stats: QueryStats = field(default_factory=QueryStats) + truncated: bool = False + truncation_reason: str | None = None +``` + +- [ ] **Step 2: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/result.py +git commit -m "feat(cypher): add result types" +``` + +--- + +### Task 6: Virtual Graph Builder + Cache + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/virtual_graph.py` +- Create: `packages/cli/tests/chain/cypher/test_virtual_graph.py` + +- [ ] **Step 1: Write failing tests for VirtualGraphBuilder** + +```python +# packages/cli/tests/chain/cypher/test_virtual_graph.py +"""Tests for the virtual heterogeneous graph builder and cache.""" +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest +import rustworkx as rx + +from opentools.chain.cypher.virtual_graph import ( + 
EntityNode, + VirtualGraph, + VirtualGraphBuilder, + VirtualGraphCache, +) +from opentools.chain.models import Entity, EntityMention +from opentools.chain.query.graph_cache import EdgeData, FindingNode, MasterGraph +from opentools.chain.types import MentionField, RelationStatus + + +def _make_master_graph() -> MasterGraph: + """Create a small master graph with 3 findings and 2 LINKED edges.""" + g = rx.PyDiGraph() + now = datetime.now(timezone.utc) + n0 = g.add_node(FindingNode(finding_id="fnd_1", severity="high", tool="nmap", title="Open SSH", created_at=now)) + n1 = g.add_node(FindingNode(finding_id="fnd_2", severity="critical", tool="nuclei", title="RCE vuln", created_at=now)) + n2 = g.add_node(FindingNode(finding_id="fnd_3", severity="medium", tool="burp", title="XSS", created_at=now)) + + g.add_edge(n0, n1, EdgeData( + relation_id="rel_1", weight=2.0, cost=0.5, status="auto_confirmed", + symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None, + )) + g.add_edge(n1, n2, EdgeData( + relation_id="rel_2", weight=1.5, cost=0.7, status="auto_confirmed", + symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None, + )) + + return MasterGraph( + graph=g, + node_map={"fnd_1": n0, "fnd_2": n1, "fnd_3": n2}, + reverse_map={n0: "fnd_1", n1: "fnd_2", n2: "fnd_3"}, + generation=1, + max_weight=2.0, + ) + + +def _make_entities() -> list[Entity]: + now = datetime.now(timezone.utc) + return [ + Entity(id="ent_host1", type="host", canonical_value="10.0.0.1", first_seen_at=now, last_seen_at=now, mention_count=2), + Entity(id="ent_cve1", type="cve", canonical_value="CVE-2024-1234", first_seen_at=now, last_seen_at=now, mention_count=1), + ] + + +def _make_mentions() -> list[EntityMention]: + now = datetime.now(timezone.utc) + return [ + EntityMention(id="m1", entity_id="ent_host1", finding_id="fnd_1", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m2", 
entity_id="ent_host1", finding_id="fnd_2", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m3", entity_id="ent_cve1", finding_id="fnd_2", field=MentionField.TITLE, raw_value="CVE-2024-1234", extractor="security_regex", confidence=0.95, created_at=now), + ] + + +@pytest.mark.asyncio +async def test_build_virtual_graph_node_counts(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + # 3 findings + 2 entities = 5 nodes + assert vg.graph.num_nodes() == 5 + assert len(vg.finding_map) == 3 + assert len(vg.entity_map) == 2 + + +@pytest.mark.asyncio +async def test_build_virtual_graph_edge_counts(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + # 2 LINKED edges + 3 MENTIONED_IN edges = 5 total + assert vg.graph.num_edges() == 5 + + +@pytest.mark.asyncio +async def test_build_virtual_graph_node_labels(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + finding_labels = [vg.node_labels[idx] for idx in vg.finding_map.values()] + assert all(l == "Finding" for l in finding_labels) + + host_idx = vg.entity_map["ent_host1"] + assert vg.node_labels[host_idx] == "Host" + + cve_idx = vg.entity_map["ent_cve1"] + assert vg.node_labels[cve_idx] == "CVE" + + +@pytest.mark.asyncio +async def test_mentioned_in_direction(): + """MENTIONED_IN edges go Entity → Finding.""" + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + host_idx = vg.entity_map["ent_host1"] + successors = list(vg.graph.successor_indices(host_idx)) + # Host entity should have successors (findings it's mentioned in) + assert len(successors) == 2 + successor_ids = {vg.reverse_map[s] for s in successors} + 
assert successor_ids == {"fnd_1", "fnd_2"} + + +@pytest.mark.asyncio +async def test_linked_edges_preserved(): + """LINKED edges between findings are preserved from the master graph.""" + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + fnd1_idx = vg.finding_map["fnd_1"] + fnd2_idx = vg.finding_map["fnd_2"] + edge_data = vg.graph.get_edge_data(fnd1_idx, fnd2_idx) + assert edge_data is not None + + +@pytest.mark.asyncio +async def test_entity_node_properties(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + + host_idx = vg.entity_map["ent_host1"] + node_data = vg.graph.get_node_data(host_idx) + assert isinstance(node_data, EntityNode) + assert node_data.entity_id == "ent_host1" + assert node_data.canonical_value == "10.0.0.1" + assert node_data.entity_type == "host" + + +@pytest.mark.asyncio +async def test_virtual_graph_cache_reuse(): + """Same cache key returns the same VirtualGraph instance.""" + master = _make_master_graph() + entities = _make_entities() + mentions = _make_mentions() + + store = AsyncMock() + store.current_linker_generation = AsyncMock(return_value=1) + store.list_entities = AsyncMock(return_value=entities) + # Need a method to fetch all mentions for a scope + store.fetch_all_mentions_in_scope = AsyncMock(return_value=mentions) + + graph_cache = AsyncMock() + graph_cache.get_master_graph = AsyncMock(return_value=master) + + cache = VirtualGraphCache(store=store, graph_cache=graph_cache, maxsize=4) + vg1 = await cache.get(user_id=None, include_candidates=False, engagement_ids=None) + vg2 = await cache.get(user_id=None, include_candidates=False, engagement_ids=None) + assert vg1 is vg2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_virtual_graph.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] 
**Step 3: Implement VirtualGraphBuilder and VirtualGraphCache** + +```python +# packages/cli/src/opentools/chain/cypher/virtual_graph.py +"""Virtual heterogeneous graph: findings + entities as first-class nodes.""" +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import TYPE_CHECKING +from uuid import UUID + +import rustworkx as rx + +from opentools.chain.models import Entity, EntityMention +from opentools.chain.query.graph_cache import EdgeData, FindingNode, MasterGraph + +if TYPE_CHECKING: + from opentools.chain.query.graph_cache import GraphCache + from opentools.chain.store_protocol import ChainStoreProtocol + +# Entity type → node label mapping +_ENTITY_TYPE_TO_LABEL: dict[str, str] = { + "host": "Host", + "ip": "IP", + "cve": "CVE", + "domain": "Domain", + "port": "Port", + "mitre_technique": "MitreAttack", +} + + +@dataclass +class EntityNode: + entity_id: str + entity_type: str + canonical_value: str + mention_count: int + + +@dataclass +class MentionedInEdge: + mention_id: str + field: str + confidence: float + extractor: str + + +@dataclass +class VirtualGraph: + graph: rx.PyDiGraph + finding_map: dict[str, int] # finding_id → node index + entity_map: dict[str, int] # entity_id → node index + reverse_map: dict[int, str] # node index → id + node_labels: dict[int, str] # node index → label + generation: int + + +class VirtualGraphBuilder: + """Build a VirtualGraph from a MasterGraph + entity/mention data.""" + + def build( + self, + master: MasterGraph, + entities: list[Entity], + mentions: list[EntityMention], + ) -> VirtualGraph: + graph = rx.PyDiGraph() + finding_map: dict[str, int] = {} + entity_map: dict[str, int] = {} + reverse_map: dict[int, str] = {} + node_labels: dict[int, str] = {} + + # Copy finding nodes from master graph + for finding_id, master_idx in master.node_map.items(): + node_data = master.graph.get_node_data(master_idx) + idx = graph.add_node(node_data) + 
finding_map[finding_id] = idx + reverse_map[idx] = finding_id + node_labels[idx] = "Finding" + + # Copy LINKED edges from master graph + for edge_idx in master.graph.edge_indices(): + src, tgt = master.graph.get_edge_endpoints_by_index(edge_idx) + edge_data = master.graph.get_edge_data_by_index(edge_idx) + src_id = master.reverse_map.get(src) + tgt_id = master.reverse_map.get(tgt) + if src_id and tgt_id and src_id in finding_map and tgt_id in finding_map: + graph.add_edge(finding_map[src_id], finding_map[tgt_id], edge_data) + + # Add entity nodes + for entity in entities: + label = _ENTITY_TYPE_TO_LABEL.get(entity.type, "Entity") + en = EntityNode( + entity_id=entity.id, + entity_type=entity.type, + canonical_value=entity.canonical_value, + mention_count=entity.mention_count, + ) + idx = graph.add_node(en) + entity_map[entity.id] = idx + reverse_map[idx] = entity.id + node_labels[idx] = label + + # Add MENTIONED_IN edges: Entity → Finding + for mention in mentions: + entity_idx = entity_map.get(mention.entity_id) + finding_idx = finding_map.get(mention.finding_id) + if entity_idx is not None and finding_idx is not None: + edge = MentionedInEdge( + mention_id=mention.id, + field=mention.field.value if hasattr(mention.field, "value") else str(mention.field), + confidence=mention.confidence, + extractor=mention.extractor, + ) + graph.add_edge(entity_idx, finding_idx, edge) + + return VirtualGraph( + graph=graph, + finding_map=finding_map, + entity_map=entity_map, + reverse_map=reverse_map, + node_labels=node_labels, + generation=master.generation, + ) + + +class VirtualGraphCache: + """Async LRU cache of virtual graphs.""" + + def __init__( + self, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + maxsize: int = 4, + ) -> None: + self.store = store + self.graph_cache = graph_cache + self.maxsize = maxsize + self._cache: dict[tuple, VirtualGraph] = {} + self._access_order: list[tuple] = [] + self._build_locks: dict[tuple, asyncio.Lock] = {} + 
self._builder = VirtualGraphBuilder() + + async def get( + self, + *, + user_id: UUID | None, + include_candidates: bool = False, + engagement_ids: frozenset[str] | None = None, + ) -> VirtualGraph: + generation = await self.store.current_linker_generation(user_id=user_id) + key = ( + str(user_id) if user_id else None, + generation, + include_candidates, + engagement_ids, + ) + + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + + lock = self._build_locks.setdefault(key, asyncio.Lock()) + async with lock: + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + + master = await self.graph_cache.get_master_graph( + user_id=user_id, + include_candidates=include_candidates, + ) + entities = await self.store.list_entities( + user_id=user_id, limit=100_000, + ) + mentions = await self.store.fetch_all_mentions_in_scope( + user_id=user_id, + ) + + vg = self._builder.build(master, entities, mentions) + self._cache[key] = vg + self._access_order.append(key) + + while len(self._access_order) > self.maxsize: + oldest = self._access_order.pop(0) + self._cache.pop(oldest, None) + self._build_locks.pop(oldest, None) + + return vg + + def invalidate(self, *, user_id: UUID | None) -> None: + user_key = str(user_id) if user_id else None + to_remove = [k for k in self._access_order if k[0] == user_key] + for k in to_remove: + self._access_order.remove(k) + self._cache.pop(k, None) + self._build_locks.pop(k, None) + + def clear(self) -> None: + self._cache.clear() + self._access_order.clear() + self._build_locks.clear() +``` + +Note: The `VirtualGraphCache.get()` method calls `store.fetch_all_mentions_in_scope()` — this is a new protocol method that needs to be added to `ChainStoreProtocol` and both backends. The implementing agent should add this method (returns all `EntityMention` rows for the user scope) as part of this task if it doesn't already exist. 
Check `store_protocol.py` and the existing mention-related methods (`mentions_for_finding`, `add_mentions_bulk`) for the pattern. + +- [ ] **Step 4: Run tests and iterate until passing** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_virtual_graph.py -v` +Expected: 7 passed. If `get_edge_endpoints_by_index` or `get_edge_data_by_index` are not available in the installed rustworkx version, use the edge iteration API instead (`edge_list()` returns `(src, tgt)` tuples, `edges()` returns payloads). + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/virtual_graph.py packages/cli/tests/chain/cypher/test_virtual_graph.py +git commit -m "feat(cypher): add virtual heterogeneous graph builder and cache" +``` + +--- + +### Task 7: Planner + +**Files:** +- Create: `packages/cli/src/opentools/chain/cypher/planner.py` +- Create: `packages/cli/tests/chain/cypher/test_planner.py` + +- [ ] **Step 1: Write failing tests for the planner** + +```python +# packages/cli/tests/chain/cypher/test_planner.py +import pytest + +from opentools.chain.cypher.ast_nodes import ( + ComparisonExpr, + CypherQuery, + EdgePattern, + MatchClause, + NodePattern, + PropertyAccessExpr, + ReturnClause, + ReturnItem, + VarLengthSpec, + WhereClause, +) +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.planner import plan_query, PlanStep, QueryPlan + + +def _simple_query() -> CypherQuery: + """MATCH (a:Finding) RETURN a""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + (NodePattern(variable="a", label="Finding"),), + ]), + where_clause=None, + return_clause=ReturnClause(items=[ReturnItem(expression="a", alias=None)]), + ) + + +def _two_node_query() -> CypherQuery: + """MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + ( + NodePattern(variable="a", label="Finding"), + EdgePattern(variable="r", label="LINKED", direction="out", 
var_length=None), + NodePattern(variable="b", label="Finding"), + ), + ]), + where_clause=None, + return_clause=ReturnClause(items=[ + ReturnItem(expression="a", alias=None), + ReturnItem(expression="b", alias=None), + ]), + ) + + +def _filtered_query() -> CypherQuery: + """MATCH (a:Finding) WHERE a.severity = "critical" RETURN a""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + (NodePattern(variable="a", label="Finding"),), + ]), + where_clause=WhereClause(expression=ComparisonExpr( + left=PropertyAccessExpr(variable="a", property_name="severity"), + operator="=", + right="critical", + )), + return_clause=ReturnClause(items=[ReturnItem(expression="a", alias=None)]), + ) + + +def _var_length_query() -> CypherQuery: + """MATCH (a:Finding)-[r:LINKED*1..5]->(b:Finding) RETURN a, b""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + ( + NodePattern(variable="a", label="Finding"), + EdgePattern(variable="r", label="LINKED", direction="out", var_length=VarLengthSpec(min_hops=1, max_hops=5)), + NodePattern(variable="b", label="Finding"), + ), + ]), + where_clause=None, + return_clause=ReturnClause(items=[ + ReturnItem(expression="a", alias=None), + ReturnItem(expression="b", alias=None), + ]), + ) + + +def test_plan_simple_scan(): + plan = plan_query(_simple_query(), QueryLimits()) + assert len(plan.steps) == 1 + assert plan.steps[0].kind == "scan" + assert plan.steps[0].label == "Finding" + assert plan.steps[0].target_var == "a" + + +def test_plan_two_node_has_scan_then_expand(): + plan = plan_query(_two_node_query(), QueryLimits()) + assert plan.steps[0].kind == "scan" + assert plan.steps[0].target_var == "a" + assert plan.steps[1].kind == "expand" + assert plan.steps[1].target_var == "r" + assert plan.steps[1].label == "LINKED" + # The third step binds b via the expand target + # (or b is bound implicitly by the expand — depends on implementation) + + +def test_plan_predicate_pushdown(): + plan = plan_query(_filtered_query(), QueryLimits()) 
+ # The WHERE predicate on 'a' should be pushed down to the scan step for 'a' + scan_step = plan.steps[0] + assert scan_step.kind == "scan" + assert scan_step.target_var == "a" + assert len(scan_step.predicates) == 1 + assert isinstance(scan_step.predicates[0], ComparisonExpr) + + +def test_plan_var_length_expand(): + plan = plan_query(_var_length_query(), QueryLimits()) + var_length_steps = [s for s in plan.steps if s.kind == "var_length_expand"] + assert len(var_length_steps) == 1 + vl = var_length_steps[0] + assert vl.min_hops == 1 + assert vl.max_hops == 5 + assert vl.label == "LINKED" + + +def test_plan_preserves_limits(): + limits = QueryLimits(timeout_seconds=60.0, max_rows=500) + plan = plan_query(_simple_query(), limits) + assert plan.limits.timeout_seconds == 60.0 + assert plan.limits.max_rows == 500 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_planner.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement planner** + +```python +# packages/cli/src/opentools/chain/cypher/planner.py +"""Query planner: AST → QueryPlan with predicate pushdown.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, + ComparisonExpr, + CypherQuery, + EdgePattern, + FunctionCallExpr, + NodePattern, + PropertyAccessExpr, + ReturnClause, +) +from opentools.chain.cypher.limits import QueryLimits + + +@dataclass +class PlanStep: + kind: Literal["scan", "expand", "filter", "var_length_expand"] + target_var: str + label: str | None = None + direction: Literal["out", "in", "both"] | None = None + min_hops: int | None = None + max_hops: int | None = None + predicates: list[Any] = field(default_factory=list) + from_session: str | None = None # for session result scans + + +@dataclass +class ReturnSpec: + items: list[Any] # ReturnItem instances from the AST + + 
+@dataclass +class QueryPlan: + steps: list[PlanStep] + return_spec: ReturnSpec + limits: QueryLimits + + +def _extract_variables(expr: Any) -> set[str]: + """Extract variable names referenced in an expression.""" + if isinstance(expr, PropertyAccessExpr): + return {expr.variable} + if isinstance(expr, ComparisonExpr): + return _extract_variables(expr.left) | ( + _extract_variables(expr.right) if not isinstance(expr.right, (str, int, float, bool, type(None), list)) else set() + ) + if isinstance(expr, BooleanExpr): + result: set[str] = set() + for op in expr.operands: + result |= _extract_variables(op) + return result + if isinstance(expr, FunctionCallExpr): + result = set() + for arg in expr.args: + if isinstance(arg, str): + result.add(arg) + else: + result |= _extract_variables(arg) + return result + if isinstance(expr, str): + return {expr} + return set() + + +def _flatten_and(expr: Any) -> list[Any]: + """Flatten AND expressions into a list of conjuncts.""" + if isinstance(expr, BooleanExpr) and expr.operator == "AND": + result = [] + for op in expr.operands: + result.extend(_flatten_and(op)) + return result + return [expr] + + +def plan_query(query: CypherQuery, limits: QueryLimits) -> QueryPlan: + """Convert a parsed CypherQuery AST into a QueryPlan.""" + steps: list[PlanStep] = [] + + # Collect all WHERE predicates as a flat list of conjuncts + pending_predicates: list[Any] = [] + if query.where_clause is not None: + pending_predicates = _flatten_and(query.where_clause.expression) + + # Track which variables are bound so far + bound_vars: set[str] = set() + + # Process each pattern in the MATCH clause + for pattern_tuple in query.match_clause.patterns: + elements = list(pattern_tuple) + + for i, element in enumerate(elements): + if isinstance(element, NodePattern): + var = element.variable + if var and var not in bound_vars: + # Check if this is a FROM-clause scan + from_session = None + if query.from_clause and i == 0: + from_session = 
query.from_clause.session_variable + + step = PlanStep( + kind="scan", + target_var=var, + label=element.label, + from_session=from_session, + ) + bound_vars.add(var) + + # Push down predicates whose variables are now all bound + remaining = [] + for pred in pending_predicates: + pred_vars = _extract_variables(pred) + if pred_vars <= bound_vars: + step.predicates.append(pred) + else: + remaining.append(pred) + pending_predicates = remaining + + steps.append(step) + + elif isinstance(element, EdgePattern): + var = element.variable + if element.var_length is not None: + step = PlanStep( + kind="var_length_expand", + target_var=var or f"_anon_edge_{i}", + label=element.label, + direction=element.direction, + min_hops=element.var_length.min_hops, + max_hops=element.var_length.max_hops, + ) + else: + step = PlanStep( + kind="expand", + target_var=var or f"_anon_edge_{i}", + label=element.label, + direction=element.direction, + ) + if var: + bound_vars.add(var) + + # Push down predicates + remaining = [] + for pred in pending_predicates: + pred_vars = _extract_variables(pred) + if pred_vars <= bound_vars: + step.predicates.append(pred) + else: + remaining.append(pred) + pending_predicates = remaining + + steps.append(step) + + # Any remaining predicates become a final filter step + if pending_predicates: + steps.append(PlanStep( + kind="filter", + target_var="_post_filter", + predicates=pending_predicates, + )) + + return QueryPlan( + steps=steps, + return_spec=ReturnSpec(items=query.return_clause.items), + limits=limits, + ) +``` + +- [ ] **Step 4: Run tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_planner.py -v` +Expected: 5 passed + +- [ ] **Step 5: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/planner.py packages/cli/tests/chain/cypher/test_planner.py +git commit -m "feat(cypher): add query planner with predicate pushdown" +``` + +--- + +### Task 8: Executor + +**Files:** +- Create: 
`packages/cli/src/opentools/chain/cypher/executor.py` +- Create: `packages/cli/tests/chain/cypher/test_executor.py` + +- [ ] **Step 1: Write failing tests for the executor** + +```python +# packages/cli/tests/chain/cypher/test_executor.py +"""End-to-end executor tests against small fixture virtual graphs.""" +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +import rustworkx as rx + +from opentools.chain.cypher.executor import CypherExecutor +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.parser import parse_cypher +from opentools.chain.cypher.planner import plan_query +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import EntityNode, MentionedInEdge, VirtualGraph +from opentools.chain.query.graph_cache import EdgeData, FindingNode + + +def _build_test_vg() -> VirtualGraph: + """3 findings, 1 host entity, 2 LINKED edges, 2 MENTIONED_IN edges.""" + g = rx.PyDiGraph() + now = datetime.now(timezone.utc) + + # Findings + n0 = g.add_node(FindingNode(finding_id="fnd_1", severity="high", tool="nmap", title="Open SSH", created_at=now)) + n1 = g.add_node(FindingNode(finding_id="fnd_2", severity="critical", tool="nuclei", title="RCE vuln", created_at=now)) + n2 = g.add_node(FindingNode(finding_id="fnd_3", severity="medium", tool="burp", title="XSS", created_at=now)) + + # Entity + n3 = g.add_node(EntityNode(entity_id="ent_host1", entity_type="host", canonical_value="10.0.0.1", mention_count=2)) + + # LINKED: fnd_1 -> fnd_2, fnd_2 -> fnd_3 + g.add_edge(n0, n1, EdgeData( + relation_id="rel_1", weight=2.0, cost=0.5, status="auto_confirmed", + symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None, + )) + g.add_edge(n1, n2, EdgeData( + relation_id="rel_2", weight=1.5, cost=0.7, status="auto_confirmed", + 
symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None,
+    ))
+
+    # MENTIONED_IN: host -> fnd_1, host -> fnd_2
+    g.add_edge(n3, n0, MentionedInEdge(mention_id="m1", field="description", confidence=1.0, extractor="ioc_finder"))
+    g.add_edge(n3, n1, MentionedInEdge(mention_id="m2", field="description", confidence=1.0, extractor="ioc_finder"))
+
+    return VirtualGraph(
+        graph=g,
+        finding_map={"fnd_1": n0, "fnd_2": n1, "fnd_3": n2},
+        entity_map={"ent_host1": n3},
+        reverse_map={n0: "fnd_1", n1: "fnd_2", n2: "fnd_3", n3: "ent_host1"},
+        node_labels={n0: "Finding", n1: "Finding", n2: "Finding", n3: "Host"},
+        generation=1,
+    )
+
+
+def _execute(query_str: str, vg: VirtualGraph | None = None, limits: QueryLimits | None = None) -> QueryResult:
+    """Parse, plan, and execute a query synchronously for tests.
+
+    NOTE: the executor tests are plain synchronous tests (no
+    @pytest.mark.asyncio) — asyncio.run() starts its own event loop and
+    raises RuntimeError if called from inside an already-running loop.
+    """
+    import asyncio
+    if vg is None:
+        vg = _build_test_vg()
+    if limits is None:
+        limits = QueryLimits()
+    ast = parse_cypher(query_str)
+    plan = plan_query(ast, limits)
+    executor = CypherExecutor(
+        virtual_graph=vg,
+        plan=plan,
+        session=QuerySession(),
+        plugin_registry=PluginFunctionRegistry(),
+        limits=limits,
+    )
+    return asyncio.run(executor.execute())
+
+
+def test_scan_all_findings():
+    result = _execute("MATCH (a:Finding) RETURN a")
+    assert len(result.rows) == 3
+    assert "a" in result.columns
+
+
+def test_scan_entity_label():
+    result = _execute("MATCH (h:Host) RETURN h")
+    assert len(result.rows) == 1
+
+
+def test_expand_linked():
+    result = _execute("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b")
+    assert len(result.rows) == 2  # fnd_1->fnd_2, fnd_2->fnd_3
+
+
+def test_expand_mentioned_in():
+    result = _execute("MATCH (h:Host)-[r:MENTIONED_IN]->(f:Finding) RETURN h, f")
+    assert len(result.rows) == 2  # host->fnd_1, host->fnd_2
+
+
+def test_where_filter():
+    
result = _execute('MATCH (a:Finding) WHERE a.severity = "critical" RETURN a')
+    assert len(result.rows) == 1
+    assert result.rows[0]["a"]["severity"] == "critical"
+
+
+def test_where_numeric_comparison():
+    result = _execute("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.weight > 1.8 RETURN a, b")
+    assert len(result.rows) == 1  # only rel_1 has weight=2.0
+
+
+def test_return_property():
+    result = _execute("MATCH (a:Finding) RETURN a.title, a.severity")
+    assert len(result.rows) == 3
+    assert "a.title" in result.columns or "title" in str(result.columns)
+
+
+def test_subgraph_projection():
+    result = _execute("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b")
+    assert result.subgraph is not None
+    assert len(result.subgraph.node_indices) >= 2
+
+
+def test_resource_limit_max_rows():
+    result = _execute("MATCH (a:Finding) RETURN a", limits=QueryLimits(max_rows=1))
+    assert len(result.rows) == 1
+    assert result.truncated is True
+
+
+def test_resource_limit_timeout():
+    """A near-zero timeout should abort quickly."""
+    from opentools.chain.cypher.errors import QueryResourceError
+    # This test verifies the timeout mechanism exists — with a tiny graph
+    # it may not actually trigger, so we use a very small timeout.
+    # Either outcome is acceptable: the query completes (graph is tiny)
+    # or the limit fires and QueryResourceError is raised.
+    try:
+        result = _execute("MATCH (a:Finding) RETURN a", limits=QueryLimits(timeout_seconds=0.0001))
+    except QueryResourceError:
+        return
+    assert isinstance(result, QueryResult)
+
+
+def test_empty_result():
+    result = _execute('MATCH (a:Finding) WHERE a.severity = "nonexistent" RETURN a')
+    assert len(result.rows) == 0
+    assert result.truncated is False
+```
+
+- [ ] **Step 2: Run tests to verify they fail**
+
+Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_executor.py 
-v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement the executor** + +```python +# packages/cli/src/opentools/chain/cypher/executor.py +"""CypherExecutor: walks a QueryPlan against a VirtualGraph.""" +from __future__ import annotations + +import time +from typing import Any + +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, + ComparisonExpr, + FunctionCallExpr, + PropertyAccessExpr, + ReturnItem, +) +from opentools.chain.cypher.builtins import get_builtin +from opentools.chain.cypher.errors import QueryResourceError +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.planner import PlanStep, QueryPlan +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult, QueryStats, SubgraphProjection +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import EntityNode, MentionedInEdge, VirtualGraph +from opentools.chain.query.graph_cache import EdgeData, FindingNode + +# Binding = dict mapping variable names to node/edge indices or data +Binding = dict[str, Any] + + +class CypherExecutor: + def __init__( + self, + *, + virtual_graph: VirtualGraph, + plan: QueryPlan, + session: QuerySession, + plugin_registry: PluginFunctionRegistry, + limits: QueryLimits, + ) -> None: + self.vg = virtual_graph + self.plan = plan + self.session = session + self.plugins = plugin_registry + self.limits = limits + + async def execute(self) -> QueryResult: + start = time.monotonic() + bindings: list[Binding] = [{}] # start with one empty binding + explored = 0 + + for step in self.plan.steps: + # Timeout check + elapsed = time.monotonic() - start + if elapsed > self.limits.timeout_seconds: + raise QueryResourceError( + f"query timeout after {elapsed:.1f}s", + limit_name="timeout_seconds", + limit_value=self.limits.timeout_seconds, + ) + + if step.kind == "scan": + bindings = self._step_scan(step, bindings) + elif 
step.kind == "expand": + bindings = self._step_expand(step, bindings) + elif step.kind == "var_length_expand": + bindings = self._step_var_length_expand(step, bindings) + elif step.kind == "filter": + bindings = self._step_filter(step, bindings) + + explored += len(bindings) + + # Intermediate binding cap + if len(bindings) > self.limits.intermediate_binding_cap: + raise QueryResourceError( + f"intermediate binding cap exceeded: {len(bindings)} > {self.limits.intermediate_binding_cap}", + limit_name="intermediate_binding_cap", + limit_value=self.limits.intermediate_binding_cap, + ) + + # Project RETURN + columns, rows = self._project_return(bindings) + + # Truncate + truncated = False + truncation_reason = None + if len(rows) > self.limits.max_rows: + rows = rows[:self.limits.max_rows] + truncated = True + truncation_reason = f"max_rows ({self.limits.max_rows})" + + # Build subgraph projection + subgraph = self._build_subgraph(bindings) + + elapsed_ms = (time.monotonic() - start) * 1000 + return QueryResult( + columns=columns, + rows=rows, + subgraph=subgraph, + stats=QueryStats(duration_ms=elapsed_ms, bindings_explored=explored, rows_returned=len(rows)), + truncated=truncated, + truncation_reason=truncation_reason, + ) + + # ─── step implementations ──────────────────────────────────── + + def _step_scan(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Scan all nodes matching the label, create bindings.""" + new_bindings: list[Binding] = [] + + # If scanning from session, use stored result set + if step.from_session: + stored = self.session.get(step.from_session) + if stored is not None: + for row in stored.rows: + if step.target_var in row: + binding = {step.target_var: row[step.target_var]} + new_bindings.append(binding) + return new_bindings + + for idx in self.vg.graph.node_indices(): + label = self.vg.node_labels.get(idx) + if step.label and label != step.label: + continue + + node_data = self.vg.graph.get_node_data(idx) + node_dict = 
self._node_to_dict(node_data, idx) + + for b in bindings: + new_b = {**b, step.target_var: node_dict} + new_b[f"_idx_{step.target_var}"] = idx # internal index tracking + if self._check_predicates(step.predicates, new_b): + new_bindings.append(new_b) + + return new_bindings + + def _step_expand(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Expand from bound nodes along edges of the specified type.""" + new_bindings: list[Binding] = [] + + for b in bindings: + # Find the last bound node index + last_node_var = self._last_node_var(b) + if last_node_var is None: + continue + src_idx = b.get(f"_idx_{last_node_var}") + if src_idx is None: + continue + + # Get edges based on direction + if step.direction == "out": + neighbors = self._outgoing_edges(src_idx, step.label) + elif step.direction == "in": + neighbors = self._incoming_edges(src_idx, step.label) + else: + neighbors = self._outgoing_edges(src_idx, step.label) + self._incoming_edges(src_idx, step.label) + + for tgt_idx, edge_data in neighbors: + tgt_node = self.vg.graph.get_node_data(tgt_idx) + tgt_dict = self._node_to_dict(tgt_node, tgt_idx) + edge_dict = self._edge_to_dict(edge_data) + + # Find the next node variable from the plan + next_node_var = self._next_node_var_after_edge(step.target_var) + + new_b = {**b} + new_b[step.target_var] = edge_dict + new_b[f"_idx_{step.target_var}"] = (src_idx, tgt_idx) + if next_node_var: + new_b[next_node_var] = tgt_dict + new_b[f"_idx_{next_node_var}"] = tgt_idx + + # Check label on target node if the plan specifies one + tgt_label = self.vg.node_labels.get(tgt_idx) + next_step_label = self._get_next_node_label(step.target_var) + if next_step_label and tgt_label != next_step_label: + continue + + if self._check_predicates(step.predicates, new_b): + new_bindings.append(new_b) + + return new_bindings + + def _step_var_length_expand(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Bounded DFS for variable-length path patterns.""" + 
new_bindings: list[Binding] = [] + + for b in bindings: + last_node_var = self._last_node_var(b) + if last_node_var is None: + continue + start_idx = b.get(f"_idx_{last_node_var}") + if start_idx is None: + continue + + # DFS with depth bounds + paths = self._bounded_dfs( + start_idx, + label=step.label, + direction=step.direction or "out", + min_depth=step.min_hops or 1, + max_depth=step.max_hops or 10, + ) + + next_node_var = self._next_node_var_after_edge(step.target_var) + + for path_nodes, path_edges in paths: + if not path_nodes: + continue + end_idx = path_nodes[-1] + end_node = self.vg.graph.get_node_data(end_idx) + end_dict = self._node_to_dict(end_node, end_idx) + + # Check target node label + end_label = self.vg.node_labels.get(end_idx) + next_step_label = self._get_next_node_label(step.target_var) + if next_step_label and end_label != next_step_label: + continue + + path_dict = { + "nodes": [self._node_to_dict(self.vg.graph.get_node_data(n), n) for n in path_nodes], + "edges": [self._edge_to_dict(e) for e in path_edges], + } + + new_b = {**b} + new_b[step.target_var] = path_dict + if next_node_var: + new_b[next_node_var] = end_dict + new_b[f"_idx_{next_node_var}"] = end_idx + + if self._check_predicates(step.predicates, new_b): + new_bindings.append(new_b) + + return new_bindings + + def _step_filter(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + return [b for b in bindings if self._check_predicates(step.predicates, b)] + + # ─── helpers ────────────────────────────────────────────────── + + def _outgoing_edges(self, src_idx: int, label: str | None) -> list[tuple[int, Any]]: + result = [] + for tgt_idx in self.vg.graph.successor_indices(src_idx): + edge_data = self.vg.graph.get_edge_data(src_idx, tgt_idx) + if label and not self._edge_matches_label(edge_data, label): + continue + result.append((tgt_idx, edge_data)) + return result + + def _incoming_edges(self, src_idx: int, label: str | None) -> list[tuple[int, Any]]: + result = [] + 
for pred_idx in self.vg.graph.predecessor_indices(src_idx): + edge_data = self.vg.graph.get_edge_data(pred_idx, src_idx) + if label and not self._edge_matches_label(edge_data, label): + continue + result.append((pred_idx, edge_data)) + return result + + def _edge_matches_label(self, edge_data: Any, label: str) -> bool: + if label == "LINKED" and isinstance(edge_data, EdgeData): + return True + if label == "MENTIONED_IN" and isinstance(edge_data, MentionedInEdge): + return True + return False + + def _bounded_dfs( + self, + start: int, + *, + label: str | None, + direction: str, + min_depth: int, + max_depth: int, + ) -> list[tuple[list[int], list[Any]]]: + """Return all paths of length [min_depth, max_depth] from start.""" + results: list[tuple[list[int], list[Any]]] = [] + # Stack: (current_node, path_nodes, path_edges, visited) + stack: list[tuple[int, list[int], list[Any], set[int]]] = [ + (start, [start], [], {start}) + ] + + while stack: + current, path_nodes, path_edges, visited = stack.pop() + depth = len(path_edges) + + if depth >= min_depth: + # Record this path (end node, not start) + results.append((list(path_nodes), list(path_edges))) + + if depth >= max_depth: + continue + + if direction == "out": + neighbors = self._outgoing_edges(current, label) + elif direction == "in": + neighbors = self._incoming_edges(current, label) + else: + neighbors = self._outgoing_edges(current, label) + self._incoming_edges(current, label) + + for next_idx, edge_data in neighbors: + if next_idx not in visited: + stack.append(( + next_idx, + path_nodes + [next_idx], + path_edges + [edge_data], + visited | {next_idx}, + )) + + return results + + def _node_to_dict(self, node_data: Any, idx: int) -> dict: + if isinstance(node_data, FindingNode): + return { + "id": node_data.finding_id, + "label": "Finding", + "severity": node_data.severity, + "tool": node_data.tool, + "title": node_data.title, + "created_at": str(node_data.created_at) if node_data.created_at else None, + 
"_idx": idx, + } + if isinstance(node_data, EntityNode): + return { + "id": node_data.entity_id, + "label": self.vg.node_labels.get(idx, "Entity"), + "canonical_value": node_data.canonical_value, + "entity_type": node_data.entity_type, + "mention_count": node_data.mention_count, + "_idx": idx, + } + return {"_idx": idx} + + def _edge_to_dict(self, edge_data: Any) -> dict: + if isinstance(edge_data, EdgeData): + return { + "label": "LINKED", + "relation_id": edge_data.relation_id, + "weight": edge_data.weight, + "status": edge_data.status, + "reasons": [r.rule for r in edge_data.reasons] if edge_data.reasons else [], + "llm_rationale": edge_data.llm_rationale, + "llm_relation_type": edge_data.llm_relation_type, + } + if isinstance(edge_data, MentionedInEdge): + return { + "label": "MENTIONED_IN", + "mention_id": edge_data.mention_id, + "field": edge_data.field, + "confidence": edge_data.confidence, + "extractor": edge_data.extractor, + } + return {} + + def _check_predicates(self, predicates: list, binding: Binding) -> bool: + for pred in predicates: + if not self._eval_predicate(pred, binding): + return False + return True + + def _eval_predicate(self, pred: Any, binding: Binding) -> bool: + if isinstance(pred, ComparisonExpr): + left_val = self._eval_expr(pred.left, binding) + if pred.operator == "IS NULL": + return left_val is None + if pred.operator == "IS NOT NULL": + return left_val is not None + right_val = self._eval_expr(pred.right, binding) + return self._compare(left_val, pred.operator, right_val) + if isinstance(pred, BooleanExpr): + if pred.operator == "AND": + return all(self._eval_predicate(op, binding) for op in pred.operands) + if pred.operator == "OR": + return any(self._eval_predicate(op, binding) for op in pred.operands) + if pred.operator == "NOT": + return not self._eval_predicate(pred.operands[0], binding) + if isinstance(pred, FunctionCallExpr): + return bool(self._eval_function(pred, binding)) + return True + + def _eval_expr(self, expr: 
Any, binding: Binding) -> Any: + if isinstance(expr, PropertyAccessExpr): + node_or_edge = binding.get(expr.variable) + if isinstance(node_or_edge, dict): + return node_or_edge.get(expr.property_name) + return None + if isinstance(expr, FunctionCallExpr): + return self._eval_function(expr, binding) + if isinstance(expr, str) and expr in binding: + return binding[expr] + return expr # literal value + + def _eval_function(self, func: FunctionCallExpr, binding: Binding) -> Any: + # Check built-ins first + builtin_fn = get_builtin(func.name) + if builtin_fn is not None: + args = [self._eval_expr(a, binding) for a in func.args] + return builtin_fn(*args) + # Check plugin registry + plugin_fn = self.plugins.get_function(func.name) + if plugin_fn is not None: + args = [self._eval_expr(a, binding) for a in func.args] + return plugin_fn(*args) + plugin_agg = self.plugins.get_aggregation(func.name) + if plugin_agg is not None: + args = [self._eval_expr(a, binding) for a in func.args] + return plugin_agg(*args) + return None + + def _compare(self, left: Any, op: str, right: Any) -> bool: + try: + if op == "=": + return left == right + if op == "<>": + return left != right + if op == "<": + return left < right + if op == ">": + return left > right + if op == "<=": + return left <= right + if op == ">=": + return left >= right + if op == "CONTAINS": + return isinstance(left, str) and isinstance(right, str) and right in left + if op in ("STARTS WITH", "STARTS_WITH"): + return isinstance(left, str) and isinstance(right, str) and left.startswith(right) + if op in ("ENDS WITH", "ENDS_WITH"): + return isinstance(left, str) and isinstance(right, str) and left.endswith(right) + if op == "IN": + return left in right if isinstance(right, list) else False + except TypeError: + return False + return False + + def _last_node_var(self, binding: Binding) -> str | None: + """Find the last node variable in the binding (has _idx_ prefix that maps to an int).""" + last = None + for key in 
binding: + if key.startswith("_idx_") and isinstance(binding[key], int): + last = key[5:] + return last + + def _next_node_var_after_edge(self, edge_var: str) -> str | None: + """Find the node variable that follows this edge variable in the plan steps.""" + found_edge = False + for step in self.plan.steps: + if step.target_var == edge_var: + found_edge = True + continue + if found_edge and step.kind == "scan": + return step.target_var + # If no explicit scan step for the next node, check the pattern + # The next node's variable needs to come from the AST patterns + return None + + def _get_next_node_label(self, edge_var: str) -> str | None: + """Get the label of the node that follows this edge variable in the plan.""" + found_edge = False + for step in self.plan.steps: + if step.target_var == edge_var: + found_edge = True + continue + if found_edge and step.kind == "scan": + return step.label + return None + + def _project_return(self, bindings: list[Binding]) -> tuple[list[str], list[dict]]: + """Project bindings through RETURN clause.""" + columns: list[str] = [] + for item in self.plan.return_spec.items: + if isinstance(item, ReturnItem): + if item.alias: + columns.append(item.alias) + elif isinstance(item.expression, PropertyAccessExpr): + columns.append(f"{item.expression.variable}.{item.expression.property_name}") + elif isinstance(item.expression, str): + columns.append(item.expression) + elif isinstance(item.expression, FunctionCallExpr): + columns.append(item.expression.name) + else: + columns.append(str(item.expression)) + else: + columns.append(str(item)) + + rows: list[dict] = [] + for b in bindings: + row: dict = {} + for i, item in enumerate(self.plan.return_spec.items): + col = columns[i] + if isinstance(item, ReturnItem): + val = self._eval_expr(item.expression, b) + else: + val = self._eval_expr(item, b) + # Strip internal keys from node dicts + if isinstance(val, dict): + val = {k: v for k, v in val.items() if not k.startswith("_")} + row[col] = 
val
+            rows.append(row)
+
+        return columns, rows
+
+    def _build_subgraph(self, bindings: list[Binding]) -> SubgraphProjection:
+        """Collect all node/edge indices touched by bindings."""
+        nodes: set[int] = set()
+        edges: set[tuple[int, int]] = set()
+
+        for b in bindings:
+            for key, val in b.items():
+                if key.startswith("_idx_") and isinstance(val, int):
+                    nodes.add(val)
+                elif key.startswith("_idx_") and isinstance(val, tuple):
+                    edges.add(val)
+                    nodes.add(val[0])
+                    nodes.add(val[1])
+
+        return SubgraphProjection(node_indices=nodes, edge_tuples=edges)
+```
+
+Note: This is a substantial implementation. The key design decisions are: (1) bindings track `_idx_` prefixed keys for internal graph index bookkeeping, (2) node/edge data is serialized to dicts early for property access in WHERE clauses, (3) DFS for variable-length paths uses a visited set to prevent cycles. One known gap to watch: `_step_var_length_expand` binds the traversed path under the path variable but records no `_idx_`-prefixed edge tuples, so `_build_subgraph` will omit the intermediate nodes and edges of variable-length paths from the subgraph projection — extend the binding bookkeeping (or `_build_subgraph`) to cover them if the subgraph tests exercise var-length patterns. The implementing agent should run the tests iteratively and fix any issues — the executor's expand/scan interaction with the planner's step ordering is the most likely area for edge cases.
+
+- [ ] **Step 4: Run tests iteratively until all pass**
+
+Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_executor.py -v`
+Expected: 11 passed. The `_next_node_var_after_edge` and `_get_next_node_label` methods need to correctly trace the relationship between edge variables and their target node variables in the plan. If tests fail on expand steps, the agent should inspect how the planner generates steps and adjust the executor's step-chaining logic.
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add packages/cli/src/opentools/chain/cypher/executor.py packages/cli/tests/chain/cypher/test_executor.py
+git commit -m "feat(cypher): add query executor with binding table and resource limits"
+```
+
+---
+
+### Task 9: Query Session
+
+**Files:**
+- Create: `packages/cli/src/opentools/chain/cypher/session.py`
+- Create: `packages/cli/tests/chain/cypher/test_session.py`
+
+- [ ] **Step 1: Write failing tests**
+
+```python
+# packages/cli/tests/chain/cypher/test_session.py
+from opentools.chain.cypher.result import QueryResult, QueryStats
+from opentools.chain.cypher.session import QuerySession
+
+
+def test_session_store_and_get():
+    session = QuerySession()
+    result = QueryResult(columns=["a"], rows=[{"a": 1}, {"a": 2}], stats=QueryStats())
+    session.store("my_results", result)
+    retrieved = session.get("my_results")
+    assert retrieved is result
+
+
+def test_session_get_unknown():
+    session = QuerySession()
+    assert session.get("nonexistent") is None
+
+
+def test_session_list_variables():
+    session = QuerySession()
+    r1 = QueryResult(columns=["a"], rows=[], stats=QueryStats())
+    r2 = QueryResult(columns=["b"], rows=[], stats=QueryStats())
+    session.store("first", r1)
+    session.store("second", r2)
+    assert set(session.list_variables()) == {"first", "second"}
+
+
+def test_session_clear():
+    session = QuerySession()
+    session.store("x", QueryResult(columns=[], rows=[], stats=QueryStats()))
+    session.clear()
+    assert session.get("x") is None
+    assert session.list_variables() == []
+
+
+def test_session_overwrite():
+    session = QuerySession()
+    r1 = QueryResult(columns=["a"], rows=[{"a": 1}], stats=QueryStats())
+    r2 = QueryResult(columns=["a"], rows=[{"a": 2}], stats=QueryStats())
+    session.store("x", r1)
+    session.store("x", r2)
+    assert session.get("x") is r2
+```
+
+- [ ] **Step 2: Run tests to verify they fail**
+
+Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_session.py -v`
+Expected: FAIL — `ModuleNotFoundError`
+
+- [ ] **Step 3: Implement QuerySession**
+
+```python
+# packages/cli/src/opentools/chain/cypher/session.py
+"""Query session: named result sets for the REPL."""
+from __future__ import annotations
+
+from opentools.chain.cypher.result import QueryResult
+
+
+class QuerySession:
+    def __init__(self) -> None:
+        self._variables: dict[str, QueryResult] = {}
+
+    def store(self, name: str, result: QueryResult) -> None:
+        self._variables[name] = result
+
+    def get(self, name: str) -> QueryResult | None:
+        return self._variables.get(name)
+
+    def list_variables(self) -> list[str]:
+        return list(self._variables.keys())
+
+    def clear(self) -> None:
+        self._variables.clear()
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_session.py -v`
+Expected: 5 passed
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add packages/cli/src/opentools/chain/cypher/session.py packages/cli/tests/chain/cypher/test_session.py
+git commit -m "feat(cypher): add query session for named result sets"
+```
+
+---
+
+### Task 10: Public API + Config Integration
+
+**Files:**
+- Modify: `packages/cli/src/opentools/chain/cypher/__init__.py`
+- Modify: `packages/cli/src/opentools/chain/config.py`
+
+- [ ] **Step 1: Add CypherConfig to ChainConfig**
+
+Add to `packages/cli/src/opentools/chain/config.py`, before the `ChainConfig` class:
+
+```python
+class CypherConfig(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+    timeout_seconds: float = 30.0
+    max_rows: int = 1000
+    intermediate_binding_cap: int = 10_000
+    max_var_length_hops: int = 10
+    virtual_graph_cache_size: int = 4
+```
+
+Add `cypher: CypherConfig = CypherConfig()` to the `ChainConfig` class fields (after the `query` field).
+ +- [ ] **Step 2: Write the public API in `__init__.py`** + +```python +# packages/cli/src/opentools/chain/cypher/__init__.py +"""Cypher-style query DSL for the attack chain knowledge graph.""" +from __future__ import annotations + +from typing import TYPE_CHECKING +from uuid import UUID + +from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError +from opentools.chain.cypher.executor import CypherExecutor +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.parser import parse_cypher +from opentools.chain.cypher.planner import plan_query +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import VirtualGraphCache + +if TYPE_CHECKING: + from opentools.chain.config import ChainConfig + from opentools.chain.query.graph_cache import GraphCache + from opentools.chain.store_protocol import ChainStoreProtocol + + +async def parse_and_execute( + query: str, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + vg_cache: VirtualGraphCache, + session: QuerySession | None = None, + plugin_registry: PluginFunctionRegistry | None = None, + user_id: UUID | None = None, + include_candidates: bool = False, + engagement_ids: frozenset[str] | None = None, + limits: QueryLimits | None = None, +) -> QueryResult: + """Parse, plan, and execute a Cypher query — main entry point.""" + if session is None: + session = QuerySession() + if plugin_registry is None: + plugin_registry = PluginFunctionRegistry() + if limits is None: + limits = QueryLimits() + + ast = parse_cypher(query) + plan = plan_query(ast, limits) + + vg = await vg_cache.get( + user_id=user_id, + include_candidates=include_candidates, + engagement_ids=engagement_ids, + ) + + executor = CypherExecutor( + virtual_graph=vg, + plan=plan, + session=session, + 
plugin_registry=plugin_registry, + limits=limits, + ) + result = await executor.execute() + + # Store in session if this was a session assignment + if ast.session_assignment: + session.store(ast.session_assignment, result) + + return result + + +class CypherSession: + """High-level session object for CLI REPL and web editor.""" + + def __init__( + self, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + config: "ChainConfig", + user_id: UUID | None = None, + ) -> None: + from opentools.chain.cypher.virtual_graph import VirtualGraphCache + self.store = store + self.graph_cache = graph_cache + self.user_id = user_id + self.session = QuerySession() + self.plugin_registry = PluginFunctionRegistry() + self.limits = QueryLimits( + timeout_seconds=config.cypher.timeout_seconds, + max_rows=config.cypher.max_rows, + intermediate_binding_cap=config.cypher.intermediate_binding_cap, + max_var_length_hops=config.cypher.max_var_length_hops, + ) + self.vg_cache = VirtualGraphCache( + store=store, + graph_cache=graph_cache, + maxsize=config.cypher.virtual_graph_cache_size, + ) + self._engagement_ids: frozenset[str] | None = None + self._include_candidates = False + + def set_engagement_scope(self, engagement_ids: frozenset[str] | None) -> None: + self._engagement_ids = engagement_ids + + def set_include_candidates(self, include: bool) -> None: + self._include_candidates = include + + async def execute(self, query: str) -> QueryResult: + return await parse_and_execute( + query, + store=self.store, + graph_cache=self.graph_cache, + vg_cache=self.vg_cache, + session=self.session, + plugin_registry=self.plugin_registry, + user_id=self.user_id, + include_candidates=self._include_candidates, + engagement_ids=self._engagement_ids, + limits=self.limits, + ) +``` + +- [ ] **Step 3: Run full test suite** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/ -v` +Expected: All previous tests still pass. 
No new tests needed — the public API is exercised by the CLI tests in Task 11. + +- [ ] **Step 4: Commit** + +```bash +git add packages/cli/src/opentools/chain/cypher/__init__.py packages/cli/src/opentools/chain/config.py +git commit -m "feat(cypher): add public API and CypherConfig" +``` + +--- + +### Task 11: CLI Commands (run, repl, explain) + +**Files:** +- Modify: `packages/cli/src/opentools/chain/cli.py` +- Create: `packages/cli/tests/chain/cypher/test_cli_query.py` + +- [ ] **Step 1: Write failing tests for CLI query commands** + +```python +# packages/cli/tests/chain/cypher/test_cli_query.py +"""Tests for the CLI query subcommands.""" +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from typer.testing import CliRunner + +from opentools.chain.cli import app + + +runner = CliRunner() + + +@pytest.fixture(autouse=True) +def mock_stores(): + """Mock out the store/cache infrastructure so CLI commands can run.""" + mock_chain_store = AsyncMock() + mock_chain_store.initialize = AsyncMock() + mock_chain_store.close = AsyncMock() + mock_chain_store.current_linker_generation = AsyncMock(return_value=1) + + with patch("opentools.chain.cli._get_stores", new_callable=AsyncMock) as mock_get: + mock_get.return_value = (AsyncMock(), mock_chain_store) + yield mock_chain_store + + +def test_query_run_help(): + result = runner.invoke(app, ["query", "run", "--help"]) + assert result.exit_code == 0 + assert "Execute a Cypher query" in result.output or "cypher" in result.output.lower() + + +def test_query_explain_help(): + result = runner.invoke(app, ["query", "explain", "--help"]) + assert result.exit_code == 0 + + +def test_query_repl_help(): + result = runner.invoke(app, ["query", "repl", "--help"]) + assert result.exit_code == 0 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_cli_query.py -v` +Expected: FAIL — the existing `query` command in 
`cli.py` is a single command, not a subgroup. + +- [ ] **Step 3: Replace the existing `query` command with a subgroup** + +In `packages/cli/src/opentools/chain/cli.py`, replace the existing `query` command (the preset runner around line 274-328) with a new `query` subgroup. The existing preset functionality moves to `query preset` (or stays as the `preset` command at the top level — agent should choose based on what breaks fewer existing tests). + +Add a new Typer sub-app: + +```python +query_app = typer.Typer(help="Cypher query DSL commands") +app.add_typer(query_app, name="query") + + +@query_app.command("run") +@_async_command +async def query_run( + cypher: str = typer.Argument(..., help="Cypher query string"), + timeout: float = typer.Option(30.0, "--timeout", help="Query timeout in seconds"), + max_rows: int = typer.Option(1000, "--max-rows", help="Maximum result rows"), + engagement: str | None = typer.Option(None, "--engagement", help="Scope to engagement"), + include_candidates: bool = typer.Option(False, "--include-candidates", help="Include candidate edges"), + format: str = typer.Option("table", "--format", help="Output format: table, json, csv"), + no_subgraph: bool = typer.Option(False, "--no-subgraph", help="Skip subgraph projection"), +) -> None: + """Execute a Cypher query.""" + from opentools.chain.cypher import CypherSession + from opentools.chain.cypher.limits import QueryLimits + from opentools.chain.query.graph_cache import GraphCache + + _engagement_store, chain_store = await _get_stores() + try: + cfg = get_chain_config() + cache = GraphCache(store=chain_store, maxsize=cfg.query.graph_cache_size) + session = CypherSession(store=chain_store, graph_cache=cache, config=cfg) + + if engagement: + session.set_engagement_scope(frozenset([engagement])) + session.set_include_candidates(include_candidates) + session.limits = QueryLimits(timeout_seconds=timeout, max_rows=max_rows) + + result = await session.execute(cypher) + + if format == "json": + 
import json + rprint(json.dumps({"columns": result.columns, "rows": result.rows, "stats": {"duration_ms": result.stats.duration_ms, "rows_returned": result.stats.rows_returned}, "truncated": result.truncated}, indent=2, default=str)) + elif format == "csv": + if result.columns: + rprint(",".join(result.columns)) + for row in result.rows: + rprint(",".join(str(row.get(c, "")) for c in result.columns)) + else: + # Table format + if not result.rows: + rprint("[yellow]no results[/yellow]") + return + table = Table() + for col in result.columns: + table.add_column(col) + for row in result.rows: + table.add_row(*[str(row.get(c, "")) for c in result.columns]) + Console().print(table) + rprint(f"[dim]{result.stats.rows_returned} rows, {result.stats.duration_ms:.1f}ms[/dim]") + if result.truncated: + rprint(f"[yellow]truncated: {result.truncation_reason}[/yellow]") + finally: + await chain_store.close() + + +@query_app.command("explain") +@_async_command +async def query_explain( + cypher: str = typer.Argument(..., help="Cypher query string"), +) -> None: + """Show the query plan without executing.""" + from opentools.chain.cypher.limits import QueryLimits + from opentools.chain.cypher.parser import parse_cypher + from opentools.chain.cypher.planner import plan_query + + limits = QueryLimits() + ast = parse_cypher(cypher) + plan = plan_query(ast, limits) + + rprint("[bold]Query Plan[/bold]") + for i, step in enumerate(plan.steps, 1): + rprint(f" {i}. 
{step.kind}: {step.target_var} (label={step.label}, direction={step.direction})") + if step.predicates: + rprint(f" predicates: {len(step.predicates)} pushed down") + if step.min_hops is not None: + rprint(f" hops: {step.min_hops}..{step.max_hops}") + + +@query_app.command("repl") +@_async_command +async def query_repl( + engagement: str | None = typer.Option(None, "--engagement", help="Scope to engagement"), + include_candidates: bool = typer.Option(False, "--include-candidates"), +) -> None: + """Start an interactive Cypher query REPL.""" + from prompt_toolkit import PromptSession + from prompt_toolkit.history import InMemoryHistory + + from opentools.chain.cypher import CypherSession + from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError + from opentools.chain.query.graph_cache import GraphCache + + _engagement_store, chain_store = await _get_stores() + try: + cfg = get_chain_config() + cache = GraphCache(store=chain_store, maxsize=cfg.query.graph_cache_size) + cypher_session = CypherSession(store=chain_store, graph_cache=cache, config=cfg) + + if engagement: + cypher_session.set_engagement_scope(frozenset([engagement])) + cypher_session.set_include_candidates(include_candidates) + + prompt_session = PromptSession(history=InMemoryHistory()) + rprint("[bold]OpenTools Cypher REPL[/bold] (type :help for help, :quit to exit)") + + while True: + try: + text = prompt_session.prompt("cypher> ") + except (EOFError, KeyboardInterrupt): + break + + text = text.strip() + if not text: + continue + + # Multi-line continuation + while text.endswith("-") or text.endswith("|") or text.count("(") > text.count(")"): + try: + continuation = prompt_session.prompt(" ...> ") + text += " " + continuation.strip() + except (EOFError, KeyboardInterrupt): + break + + # Special commands + if text.startswith(":"): + cmd = text[1:].strip().lower() + if cmd in ("quit", "exit"): + break + elif cmd == "help": + rprint("Cypher query DSL. 
MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE ... RETURN ...") + rprint("Labels: Finding, Host, IP, CVE, Domain, Port, MitreAttack, Entity") + rprint("Edges: LINKED, MENTIONED_IN") + elif cmd == "functions": + from opentools.chain.cypher.builtins import list_builtins + for name, info in list_builtins().items(): + rprint(f" {name}: {info.get('help', '')}") + for name, info in cypher_session.plugin_registry.list_all().items(): + rprint(f" {name}: {info.get('help', '')} [{info.get('kind', '')}]") + elif cmd == "clear": + cypher_session.session.clear() + rprint("[dim]session cleared[/dim]") + elif cmd.startswith("limits"): + rprint(f"timeout: {cypher_session.limits.timeout_seconds}s") + rprint(f"max_rows: {cypher_session.limits.max_rows}") + rprint(f"intermediate_cap: {cypher_session.limits.intermediate_binding_cap}") + else: + rprint(f"[red]unknown command: {text}[/red]") + continue + + # Check if it's just a variable name (display stored result) + if text in cypher_session.session.list_variables(): + stored = cypher_session.session.get(text) + if stored: + for row in stored.rows[:20]: + rprint(row) + if len(stored.rows) > 20: + rprint(f"[dim]... 
{len(stored.rows) - 20} more rows[/dim]") + continue + + # Execute query + try: + result = await cypher_session.execute(text) + if not result.rows: + rprint("[yellow]no results[/yellow]") + else: + table = Table() + for col in result.columns: + table.add_column(col) + for row in result.rows: + table.add_row(*[str(row.get(c, "")) for c in result.columns]) + Console().print(table) + rprint(f"[dim]{result.stats.rows_returned} rows, {result.stats.duration_ms:.1f}ms[/dim]") + except (QueryParseError, QueryValidationError) as e: + rprint(f"[red]Parse error: {e}[/red]") + except QueryResourceError as e: + rprint(f"[red]Resource limit: {e}[/red]") + except Exception as e: + rprint(f"[red]Error: {e}[/red]") + + rprint("[dim]bye[/dim]") + finally: + await chain_store.close() +``` + +The agent should also rename the existing `query` function (the preset runner) to `preset` and add it to `query_app` or keep it as a top-level command. Check existing tests in `test_cli_commands.py` to see if anything references `query` by name and update accordingly. + +- [ ] **Step 4: Run tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/test_cli_query.py -v` +Expected: 3 passed + +- [ ] **Step 5: Run existing CLI tests to check for regressions** + +Run: `cd packages/cli && python -m pytest tests/chain/test_cli_commands.py -v` +Expected: If the old `query` command was renamed to `preset`, update any tests that invoke it. All tests should pass. 
+ +- [ ] **Step 6: Commit** + +```bash +git add packages/cli/src/opentools/chain/cli.py packages/cli/tests/chain/cypher/test_cli_query.py +git commit -m "feat(cypher): add CLI query commands (run, explain, repl)" +``` + +--- + +### Task 12: Web Backend — Query Endpoint + +**Files:** +- Create: `packages/web/backend/app/routes/chain_query.py` +- Modify: `packages/web/backend/app/main.py` +- Create: `packages/web/backend/tests/chain/test_query_routes.py` + +- [ ] **Step 1: Write failing tests for the web endpoint** + +```python +# packages/web/backend/tests/chain/test_query_routes.py +"""Tests for the Cypher query web API endpoints.""" +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.main import app + + +@pytest.fixture +async def client(): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + + +@pytest.mark.asyncio +async def test_query_endpoint_requires_auth(client): + response = await client.post("/api/chain/query", json={"query": "MATCH (a:Finding) RETURN a"}) + assert response.status_code in (401, 403) + + +@pytest.mark.asyncio +async def test_functions_endpoint_requires_auth(client): + response = await client.get("/api/chain/query/functions") + assert response.status_code in (401, 403) +``` + +Note: The implementing agent should adapt these tests to match the project's existing auth test patterns (check `packages/web/backend/tests/` for how authenticated requests are mocked — there's likely a fixture that provides an auth token or mock user). 
+ +- [ ] **Step 2: Implement the query routes** + +```python +# packages/web/backend/app/routes/chain_query.py +"""Cypher query DSL web API endpoints.""" +from __future__ import annotations + +from typing import Any, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from app.database import async_session_factory +from app.dependencies import get_current_user, get_db +from app.models import User + +router = APIRouter(prefix="/api/chain/query", tags=["chain-query"]) + + +class QueryRequest(BaseModel): + query: str + engagement_id: Optional[str] = None + include_candidates: bool = False + timeout: float = 30.0 + max_rows: int = 1000 + + +class QueryResponse(BaseModel): + columns: list[str] + rows: list[dict[str, Any]] + subgraph: Optional[dict] = None + stats: dict + truncated: bool + + +class FunctionInfo(BaseModel): + name: str + kind: str + help: str + + +@router.post("", response_model=QueryResponse) +async def execute_query( + request: QueryRequest, + current_user: User = Depends(get_current_user), + db=Depends(get_db), +): + """Execute a Cypher query against the attack chain knowledge graph.""" + from opentools.chain.config import get_chain_config + from opentools.chain.cypher import parse_and_execute + from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError + from opentools.chain.cypher.limits import QueryLimits + from opentools.chain.cypher.plugins import PluginFunctionRegistry + from opentools.chain.cypher.virtual_graph import VirtualGraphCache + from opentools.chain.query.graph_cache import GraphCache + + from app.services.chain_service import ChainService + + try: + cfg = get_chain_config() + chain_service = ChainService(session_factory=async_session_factory) + store = chain_service.get_store(user_id=current_user.id) + graph_cache = GraphCache(store=store, maxsize=cfg.query.graph_cache_size) + vg_cache = VirtualGraphCache(store=store, graph_cache=graph_cache, 
maxsize=cfg.cypher.virtual_graph_cache_size) + + engagement_ids = frozenset([request.engagement_id]) if request.engagement_id else None + limits = QueryLimits(timeout_seconds=request.timeout, max_rows=request.max_rows) + + result = await parse_and_execute( + request.query, + store=store, + graph_cache=graph_cache, + vg_cache=vg_cache, + user_id=current_user.id, + include_candidates=request.include_candidates, + engagement_ids=engagement_ids, + limits=limits, + ) + + subgraph_data = None + if result.subgraph: + subgraph_data = { + "nodes": [{"index": idx} for idx in result.subgraph.node_indices], + "edges": [{"source": s, "target": t} for s, t in result.subgraph.edge_tuples], + } + + return QueryResponse( + columns=result.columns, + rows=result.rows, + subgraph=subgraph_data, + stats={ + "duration_ms": result.stats.duration_ms, + "bindings_explored": result.stats.bindings_explored, + "rows_returned": result.stats.rows_returned, + }, + truncated=result.truncated, + ) + + except QueryParseError as e: + raise HTTPException(status_code=400, detail=f"Parse error: {e}") + except QueryValidationError as e: + raise HTTPException(status_code=400, detail=f"Validation error: {e}") + except QueryResourceError as e: + raise HTTPException(status_code=400, detail=f"Resource limit: {e}") + + +@router.get("/functions") +async def list_functions( + current_user: User = Depends(get_current_user), +): + """List all available query functions (built-in and plugin).""" + from opentools.chain.cypher.builtins import list_builtins + + result = [] + for name, info in list_builtins().items(): + result.append({"name": name, "kind": "builtin", "help": info.get("help", "")}) + return result +``` + +Note: The `ChainService.get_store()` call above is a placeholder — the implementing agent should check the actual `chain_service.py` API for how to get a `ChainStoreProtocol` instance for the web backend. It likely involves creating a `PostgresChainStore` with the user's session and user_id. 
Follow the patterns in `packages/web/backend/app/routes/chain.py`. + +- [ ] **Step 3: Register the router in main.py** + +In `packages/web/backend/app/main.py`, add: + +```python +from app.routes import chain_query +app.include_router(chain_query.router) +``` + +alongside the existing `app.include_router(chain.router)`. + +- [ ] **Step 4: Run tests** + +Run: `cd packages/web/backend && python -m pytest tests/chain/test_query_routes.py -v` +Expected: 2 passed + +- [ ] **Step 5: Commit** + +```bash +git add packages/web/backend/app/routes/chain_query.py packages/web/backend/app/main.py packages/web/backend/tests/chain/test_query_routes.py +git commit -m "feat(cypher): add web query API endpoints" +``` + +--- + +### Task 13: Web Frontend — Standalone Query Page + +**Files:** +- Create: `packages/web/frontend/src/views/ChainQueryView.vue` +- Create: `packages/web/frontend/src/components/CypherEditor.vue` +- Create: `packages/web/frontend/src/components/QueryResultsPane.vue` +- Modify: `packages/web/frontend/src/router/index.ts` + +- [ ] **Step 1: Add the route** + +In `packages/web/frontend/src/router/index.ts`, add after the engagement-chain route: + +```typescript +{ path: '/chain/query', name: 'chain-query', component: () => import('@/views/ChainQueryView.vue') }, +``` + +- [ ] **Step 2: Create CypherEditor.vue** + +```vue + + + + + + + Run (Ctrl+Enter) + + + + + + + + +``` + +- [ ] **Step 3: Create QueryResultsPane.vue** + +```vue + + + + Running query... + {{ error }} + + + {{ result.stats.rows_returned }} rows, {{ result.stats.duration_ms.toFixed(1) }}ms + (truncated) + + + + + {{ col }} + + + + + {{ formatCell(row[col]) }} + + + + No results + + + + + + + +``` + +- [ ] **Step 4: Create ChainQueryView.vue** + +```vue + + + + Chain Query + + + All engagements + + + + + + + + + + + + + + +``` + +Note: The auth token inclusion in the `fetch` call depends on the project's existing auth pattern (cookie-based, bearer token in headers, etc.). 
The implementing agent should check how other views like `ChainGraphView.vue` make API calls and follow that pattern. + +- [ ] **Step 5: Verify the page loads in the browser** + +Start the dev server and navigate to `/chain/query`. Verify: +- CodeMirror editor renders +- Run button visible +- The page doesn't crash + +- [ ] **Step 6: Commit** + +```bash +git add packages/web/frontend/src/views/ChainQueryView.vue packages/web/frontend/src/components/CypherEditor.vue packages/web/frontend/src/components/QueryResultsPane.vue packages/web/frontend/src/router/index.ts +git commit -m "feat(cypher): add standalone web query page" +``` + +--- + +### Task 14: Inline Query Panel (Final 3C.4 Task) + +**Files:** +- Create: `packages/web/frontend/src/components/InlineQueryPanel.vue` +- Modify: `packages/web/frontend/src/views/ChainGraphView.vue` +- Modify: `packages/web/frontend/src/views/GlobalChainView.vue` (if exists from 3C.3) + +- [ ] **Step 1: Create InlineQueryPanel.vue** + +```vue + + + + + {{ expanded ? 'Hide Query' : 'Query' }} + + + + + + + + + + + +``` + +- [ ] **Step 2: Add InlineQueryPanel to ChainGraphView.vue** + +The implementing agent should read `ChainGraphView.vue`, find the template section, and add the inline panel. Add it inside the main container, after the `ForceGraphCanvas` component: + +```vue + +``` + +Import the component and add the highlight handler: + +```typescript +import InlineQueryPanel from '@/components/InlineQueryPanel.vue' + +function onQueryHighlight(nodeIds: string[]) { + // Apply glow effect to matched nodes in the force graph + // Implementation depends on ForceGraphCanvas API — the agent should + // check what highlighting mechanism the canvas supports +} +``` + +- [ ] **Step 3: Add InlineQueryPanel to GlobalChainView.vue** + +If `GlobalChainView.vue` exists (from 3C.3), add the same `InlineQueryPanel` component. The panel should be identical but without a fixed `engagement-id` prop (since the global view is cross-engagement). 
Follow the same pattern as Step 2. + +If `GlobalChainView.vue` does not yet exist (3C.3 not yet merged), skip this step — it will be added when the global view is implemented. + +- [ ] **Step 4: Verify in browser** + +Start the dev server, navigate to an engagement chain page. Verify: +- "Query" button visible at bottom +- Clicking it expands the inline panel +- CodeMirror editor works +- Running a query shows results +- Collapsing works + +- [ ] **Step 5: Commit** + +```bash +git add packages/web/frontend/src/components/InlineQueryPanel.vue packages/web/frontend/src/views/ChainGraphView.vue +git commit -m "feat(cypher): add inline query panel to chain graph views" +``` + +--- + +### Task 15: Protocol Addition + Full Integration Test + +**Files:** +- Modify: `packages/cli/src/opentools/chain/store_protocol.py` +- Modify both store backends to implement `fetch_all_mentions_in_scope` +- Create: `packages/cli/tests/chain/cypher/test_integration.py` + +- [ ] **Step 1: Add `fetch_all_mentions_in_scope` to ChainStoreProtocol** + +Add to `packages/cli/src/opentools/chain/store_protocol.py`: + +```python +async def fetch_all_mentions_in_scope( + self, *, user_id: UUID | None +) -> list[EntityMention]: + """Return all entity mentions for the user scope. + + Used by VirtualGraphBuilder to populate MENTIONED_IN edges. + """ + ... +``` + +- [ ] **Step 2: Implement in AsyncChainStore** + +The implementing agent should find `AsyncChainStore` (the aiosqlite backend) and add the implementation. Pattern: `SELECT * FROM entity_mentions` (or the equivalent table name), scoped by user_id if non-None, returning `EntityMention` domain objects. + +- [ ] **Step 3: Implement in PostgresChainStore** + +Same query via SQLAlchemy async. Follow the existing pattern of other `fetch_*` methods in `PostgresChainStore`. 
+ +- [ ] **Step 4: Write integration test** + +```python +# packages/cli/tests/chain/cypher/test_integration.py +"""End-to-end integration test: parse → plan → build virtual graph → execute.""" +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone + +import pytest + +from opentools.chain.cypher import parse_and_execute +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import VirtualGraphCache +from opentools.chain.models import Entity, EntityMention +from opentools.chain.query.graph_cache import GraphCache +from tests.chain.cypher.test_virtual_graph import _make_entities, _make_master_graph, _make_mentions + + +@pytest.fixture +def mock_store(): + from unittest.mock import AsyncMock + store = AsyncMock() + store.current_linker_generation = AsyncMock(return_value=1) + store.stream_relations_in_scope = AsyncMock(return_value=iter([])) + store.fetch_all_finding_ids = AsyncMock(return_value=["fnd_1", "fnd_2", "fnd_3"]) + store.fetch_findings_by_ids = AsyncMock(return_value=[]) + store.list_entities = AsyncMock(return_value=_make_entities()) + store.fetch_all_mentions_in_scope = AsyncMock(return_value=_make_mentions()) + return store + + +@pytest.mark.asyncio +async def test_full_pipeline(mock_store): + """Parse a query, build virtual graph, execute, get results.""" + graph_cache = GraphCache(store=mock_store, maxsize=4) + + # We need to mock get_master_graph since the store is fully mocked + master = _make_master_graph() + from unittest.mock import AsyncMock as AM + graph_cache.get_master_graph = AM(return_value=master) + + vg_cache = VirtualGraphCache(store=mock_store, graph_cache=graph_cache, maxsize=4) + + result = await parse_and_execute( + "MATCH (a:Finding) RETURN a", + store=mock_store, + graph_cache=graph_cache, + vg_cache=vg_cache, + 
limits=QueryLimits(), + ) + + assert len(result.rows) == 3 + assert "a" in result.columns +``` + +- [ ] **Step 5: Run all cypher tests** + +Run: `cd packages/cli && python -m pytest tests/chain/cypher/ -v` +Expected: All tests pass (target: ~80+ tests at this point). + +- [ ] **Step 6: Run full test suite for regressions** + +Run: `cd packages/cli && python -m pytest tests/ -x --timeout=120` +Expected: No regressions. + +- [ ] **Step 7: Commit** + +```bash +git add packages/cli/src/opentools/chain/store_protocol.py packages/cli/tests/chain/cypher/test_integration.py +git commit -m "feat(cypher): add fetch_all_mentions_in_scope + integration test" +``` + +--- + +## Task Summary + +| Task | Description | Est. Tests | +|---|---|---| +| 1 | Error types + QueryLimits | 6 | +| 2 | AST node definitions | 14 | +| 3 | Lark grammar + parser | 22 | +| 4 | Built-in functions + plugin registry | 16 | +| 5 | Result types | 0 (exercised by T8) | +| 6 | Virtual graph builder + cache | 7 | +| 7 | Planner | 5 | +| 8 | Executor | 11 | +| 9 | Query session | 5 | +| 10 | Public API + config | 0 (exercised by T11, T15) | +| 11 | CLI commands | 3+ | +| 12 | Web backend endpoints | 2+ | +| 13 | Web frontend query page | manual | +| 14 | Inline query panel (final task) | manual | +| 15 | Protocol addition + integration test | 1+ | +| **Total** | | **~92+ automated** | diff --git a/docs/superpowers/specs/2026-04-13-phase3c4-cypher-dsl-design.md b/docs/superpowers/specs/2026-04-13-phase3c4-cypher-dsl-design.md new file mode 100644 index 0000000..e0c443e --- /dev/null +++ b/docs/superpowers/specs/2026-04-13-phase3c4-cypher-dsl-design.md @@ -0,0 +1,503 @@ +# Phase 3C.4: Cypher-Style Query DSL — Design Specification + +**Date:** 2026-04-13 +**Status:** Draft +**Author:** slabl + Claude +**Depends on:** Phase 3C.1 (data layer), 3C.2 (per-engagement viz), 3C.3 (global view + Bayesian calibration) + +## 1. 
Overview + +Phase 3C.4 adds a Cypher-style query DSL for custom graph queries over the attack chain knowledge graph. Users can write pattern-matching queries to explore findings, entities, and their relationships — from the CLI, an interactive REPL, or a web query editor. + +The DSL is read-only (no mutations), operates on a virtual heterogeneous graph (findings + entities as first-class nodes), and supports plugin-extensible functions. + +## 2. Decisions + +| Decision | Choice | Rationale | +|---|---|---| +| Parser library | `lark` (LALR mode) | Pure Python, EBNF, fast LALR parsing, actively maintained, ~500KB | +| Graph model | Virtual heterogeneous graph | Entities promoted to first-class nodes with MENTIONED_IN edges; enables pattern matching through entities | +| Variable-length paths | Edge-type-filtered traversal | Explicit relationship type labels control what edge types are followed; standard Cypher semantics | +| Architecture | Layered pipeline (Parser → Planner → VirtualGraphBuilder → Executor) | Each layer testable independently; virtual graph cached for REPL reuse; planner extensible for future optimization | +| Web query editor | Standalone page first, inline overlay last | Standalone page is independent and testable; inline overlay on graph pages is the final 3C.4 task | +| Plugin functions | Scalar + aggregation functions | `collect()` pulled into v1 grammar to support plugin aggregations | +| Result format | Dual: table + subgraph projection | Power users want both raw data and visual; subgraph projection is cheap | +| REPL interaction | Stateful sessions with named result sets | Documented as OpenTools extension; enables iterative exploration | +| Resource limits | Configurable with defaults (30s timeout, 1000 rows, 10,000 intermediate bindings) | Intermediate binding cap important because heterogeneous graph increases branching factor | +| Engagement scoping | Context-dependent | Pre-scoped from engagement pages, cross-engagement from 
standalone/global; engagement filter always available |
+| Variable-length max hops | Hard cap of 10, enforced at parse time | Prevents combinatorial explosion even if other limits are raised |
+| Plugin sandboxing | No sandbox in v1 | Plugins receive read-only property dicts; timeout kills hung plugins; true sandboxing deferred |
+
+## 3. Grammar & Parser
+
+**Grammar file:** `packages/cli/src/opentools/chain/cypher/grammar.lark`
+
+**Parser:** `lark` with LALR mode. Parse output is a typed AST using Python dataclasses.
+
+### 3.1 Supported Grammar (v1)
+
+```
+MATCH <pattern> [, <pattern>]*
+[WHERE <expression>]
+RETURN <return-item> [, <return-item>]*
+```
+
+**Patterns:**
+
+```
+(var:Label) — node pattern
+-[var:Label]-> — directed edge (outgoing)
+<-[var:Label]- — directed edge (incoming)
+-[var:Label*min..max]-> — variable-length path (outgoing)
+<-[var:Label*min..max]- — variable-length path (incoming)
+```
+
+**Node labels:** `Finding`, `Host`, `IP`, `CVE`, `Domain`, `Port`, `MitreAttack`, `Entity` (wildcard for any entity type).
+
+**Edge labels:** `LINKED`, `MENTIONED_IN`.
+
+**WHERE expressions:**
+
+| Category | Syntax |
+|---|---|
+| Property access | `a.severity`, `r.weight`, `a.title` |
+| Comparisons | `=`, `<>`, `<`, `>`, `<=`, `>=` |
+| Boolean | `AND`, `OR`, `NOT` |
+| String | `CONTAINS`, `STARTS WITH`, `ENDS WITH` |
+| Membership | `IN [list]` |
+| Null check | `IS NULL`, `IS NOT NULL` |
+| Built-in functions | `length(path)`, `nodes(path)`, `relationships(path)`, `has_entity(node, type, value)`, `has_mitre(node, technique_id)` |
+| Plugin functions | `plugin_name.function_name(args...)` |
+
+**RETURN items:**
+
+| Category | Syntax |
+|---|---|
+| Variables | `a`, `r`, `path` |
+| Property access | `a.title`, `r.weight` |
+| Aggregation | `collect(a)` |
+| Plugin aggregations | `plugin_name.agg_function(collect(a))` |
+
+### 3.2 Session Extension (OpenTools-specific)
+
+```
+result_name = MATCH ... RETURN ... 
+``` + +Stores the result set in a session variable, referenceable in later queries within the same REPL session. This is an OpenTools extension, not standard Cypher. + +### 3.3 Read-Only Enforcement + +The grammar does not define tokens for `CREATE`, `DELETE`, `SET`, `MERGE`, `REMOVE`, `DETACH`, `DROP`. These are caught as parse errors — mutation verbs never produce a valid AST. + +### 3.4 AST Dataclasses + +Defined in `ast_nodes.py`: `MatchClause`, `NodePattern`, `EdgePattern`, `VarLengthSpec`, `WhereExpr`, `ComparisonExpr`, `BooleanExpr`, `FunctionCallExpr`, `PropertyAccessExpr`, `ReturnClause`, `ReturnItem`, `SessionAssignment`. + +## 4. Virtual Heterogeneous Graph + +The virtual graph augments the existing `MasterGraph` (finding-only) with entity nodes and `MENTIONED_IN` edges. + +### 4.1 Node Types + +| Label | Source | Properties | +|---|---|---| +| `Finding` | `FindingNode` from `MasterGraph` | `id`, `severity`, `tool`, `title`, `created_at`, `engagement_id` | +| `Host` | `Entity` where `type="host"` | `id`, `canonical_value`, `mention_count` | +| `IP` | `Entity` where `type="ip"` | `id`, `canonical_value`, `mention_count` | +| `CVE` | `Entity` where `type="cve"` | `id`, `canonical_value`, `mention_count` | +| `Domain` | `Entity` where `type="domain"` | `id`, `canonical_value`, `mention_count` | +| `Port` | `Entity` where `type="port"` | `id`, `canonical_value`, `mention_count` | +| `MitreAttack` | `Entity` where `type="mitre_technique"` | `id`, `canonical_value`, `mention_count` | + +Any entity type not in this list is accessible via the generic `Entity` label. 
+ +### 4.2 Edge Types + +| Label | Direction | Meaning | Properties | +|---|---|---|---| +| `LINKED` | Finding → Finding | Existing `FindingRelation` edges | `weight`, `status`, `reasons`, `llm_rationale`, `llm_relation_type` | +| `MENTIONED_IN` | Entity → Finding | Derived from `EntityMention` rows | `field`, `confidence`, `extractor` | + +### 4.3 VirtualGraphBuilder + +**Location:** `packages/cli/src/opentools/chain/cypher/virtual_graph.py` + +Takes a `MasterGraph` + entity/mention data from `ChainStoreProtocol` and produces a `VirtualGraph`: + +```python +@dataclass +class VirtualGraph: + graph: rx.PyDiGraph + finding_map: dict[str, int] # finding_id → node index + entity_map: dict[str, int] # entity_id → node index + reverse_map: dict[int, str] # node index → id (finding or entity) + node_labels: dict[int, str] # node index → label ("Finding", "Host", etc.) + generation: int +``` + +### 4.4 Caching + +`VirtualGraphCache` wraps `GraphCache`. Keyed by `(user_id, generation, include_candidates, engagement_ids_frozenset)`. Same async LRU pattern as `GraphCache` with per-key build lock. `maxsize=4`. + +REPL sessions reuse the cache across queries — the virtual graph is only rebuilt when the linker generation advances or engagement scope changes. + +### 4.5 Build Cost + +For an engagement with 500 findings and 2,000 entities, the virtual graph adds ~2,000 nodes and ~5,000 MENTIONED_IN edges on top of the existing master graph. Build time dominated by entity/mention DB queries, not graph construction. Expected <500ms for typical engagement sizes. + +## 5. Planner + +The planner translates the AST into an ordered sequence of execution steps. In v1 it follows query order (no cost-based optimization), but the layer exists for future cardinality estimation and reordering. 
+ +### 5.1 Data Structures + +```python +@dataclass +class QueryPlan: + steps: list[PlanStep] + return_spec: ReturnSpec + limits: QueryLimits + +@dataclass +class PlanStep: + kind: Literal["scan", "expand", "filter", "var_length_expand"] + target_var: str # which query variable this step binds + label: str | None # node/edge label constraint + direction: Literal["out", "in", "both"] | None + min_hops: int | None # for var_length_expand + max_hops: int | None + predicates: list[WhereExpr] # pushed-down WHERE clauses for this step +``` + +**Step kinds:** + +- **scan** — find all nodes matching a label, create one binding per match +- **expand** — follow edges from bound nodes to next pattern element +- **filter** — apply WHERE predicates to current bindings +- **var_length_expand** — bounded DFS for variable-length paths + +### 5.2 Predicate Pushdown + +The planner analyzes WHERE clauses and attaches each predicate to the earliest step whose bound variables satisfy it. `WHERE a.severity = "critical"` gets pushed down to the scan step that binds `a`, not deferred to a post-match filter pass. This is the one optimization v1 performs. + +### 5.3 Variable-Length Path Planning + +`(a:Finding)-[r:LINKED*1..5]->(b:Finding)` becomes three steps: + +1. Scan for `a` (with any pushed-down predicates on `a`) +2. `var_length_expand` following LINKED edges 1-5 hops, binding `r` as a path variable +3. Bind `b` as the terminal node + +The expand uses bounded DFS with the intermediate binding cap (default 10,000) as the kill switch. + +### 5.4 Session Result References + +Session variables can be used as the source in a MATCH pattern via `FROM` syntax: + +``` +critical = MATCH (a:Finding) WHERE a.severity = "critical" RETURN a +MATCH (a) FROM critical -[r:LINKED]->(b:Finding) RETURN a, b +``` + +The `FROM ` clause tells the planner to insert a `scan` step that reads bindings from the session store instead of scanning the graph. 
The stored result set's column `a` seeds the binding table, and execution continues from there with the remaining pattern. + +Only RETURN-ed variables from the stored result are available — internal bindings that were not returned are discarded. + +## 6. Executor + +The executor walks the `QueryPlan` against the `VirtualGraph`, managing bindings and enforcing resource limits. + +### 6.1 Core Class + +```python +class CypherExecutor: + def __init__( + self, + *, + virtual_graph: VirtualGraph, + plan: QueryPlan, + session: QuerySession, + plugin_registry: PluginFunctionRegistry, + limits: QueryLimits, + ) -> None: ... + + async def execute(self) -> QueryResult: ... +``` + +### 6.2 Binding Table + +The executor maintains a list of `Binding` dicts — each dict maps query variable names to graph node/edge indices. Every plan step transforms the binding table: + +- **scan** — iterates all nodes with matching label, creates one binding per match +- **expand** — for each existing binding, follows edges of the specified type/direction, extends the binding with the new variable +- **filter** — evaluates predicates against each binding, drops non-matching rows +- **var_length_expand** — bounded DFS from each binding's current position, produces one binding per discovered path (the path variable binds to a `PathBinding` containing the full node/edge sequence) + +### 6.3 Resource Enforcement + +Checked at every step boundary: + +- `len(bindings) <= intermediate_binding_cap` (default 10,000). Exceeding aborts with `QueryResourceError`. +- Monotonic timer checks against timeout (default 30s). +- Final result rows capped at `max_rows` (default 1,000) — applied after RETURN projection. 
+ +### 6.4 RETURN Projection + +After all match steps complete, the executor projects the binding table: + +- Variable references → serialize the bound node/edge data +- Property access (`a.severity`) → extract from node/edge payload +- `collect(a)` → group and aggregate +- Plugin functions → invoke registered callables with bound values + +### 6.5 Output + +```python +@dataclass +class QueryResult: + columns: list[str] # RETURN column names + rows: list[dict[str, Any]] # tabular data + subgraph: SubgraphProjection | None # union of all matched nodes/edges + stats: QueryStats # timing, bindings explored, rows returned + truncated: bool + truncation_reason: str | None + +@dataclass +class SubgraphProjection: + node_indices: set[int] + edge_indices: set[tuple[int, int]] +``` + +### 6.6 Plugin Function Invocation + +Plugin scalar functions receive property values and return scalars. Plugin aggregation functions receive `list[Any]` (collected values) and return scalars. Both called synchronously — async plugin functions not supported in v1. The query timeout kills hung plugins. + +## 7. Plugin Function Registry + +### 7.1 Registration API + +```python +# packages/cli/src/opentools/chain/cypher/plugins.py + +def register_query_function( + name: str, # "my_plugin.risk_score" + fn: Callable, # (value: Any) -> scalar + *, + help: str = "", + arg_types: list[str], # ["node"], ["node", "str"], etc. + return_type: str, # "float", "bool", "str" +) -> None: ... + +def register_query_aggregation( + name: str, # "my_plugin.combined_risk" + fn: Callable, # (values: list[Any]) -> scalar + *, + help: str = "", + input_type: str, + return_type: str, +) -> None: ... +``` + +### 7.2 Namespacing + +Plugin functions must use dotted names (`plugin_name.function_name`). Built-in functions (`length`, `nodes`, `relationships`, `has_entity`, `has_mitre`, `collect`) are un-namespaced. Prevents collisions and clarifies built-in vs. plugin in queries. 

+
+### 7.3 Validation
+
+- Registration time: name collision check
+- Plan time: all function references resolve to registered functions, argument counts match
+- Unresolved functions produce `QueryValidationError` before execution
+
+### 7.4 Discovery
+
+`list_query_functions()` returns all registered functions with help text and signatures — used by REPL tab completion and web editor autocomplete.
+
+## 8. CLI Surface
+
+### 8.1 `opentools chain query run '<query>'`
+
+Single-shot query execution.
+
+**Flags:**
+
+| Flag | Default | Description |
+|---|---|---|
+| `--timeout` | 30 | Query timeout in seconds |
+| `--max-rows` | 1000 | Maximum result rows |
+| `--engagement` | None | Scope to engagement (omit for cross-engagement) |
+| `--include-candidates` | false | Include candidate-status edges |
+| `--format` | table | Output format: `table`, `json`, `csv` |
+| `--no-subgraph` | false | Skip subgraph projection |
+
+### 8.2 `opentools chain query repl`
+
+Interactive REPL session.
+
+**Multi-line detection:** Open parens, a trailing `-`, or a trailing `|` prompt for continuation lines. Prompt changes from `cypher>` to ` ...>`.
+
+**Session variables:** `results = MATCH ... RETURN ...` stores the result set. Typing `results` alone re-displays it. Tab completion shows available session variables.
+
+**Special commands (`:` prefix):**
+
+| Command | Description |
+|---|---|
+| `:help` | Grammar reference and available functions |
+| `:functions` | List all built-in and plugin functions |
+| `:presets` | List available presets (for reference) |
+| `:limits` | Show/set timeout, max-rows, intermediate cap |
+| `:clear` | Clear session variables |
+| `:quit` / `:exit` | Exit REPL |
+
+Accepts `--engagement`, `--include-candidates` flags at launch, also settable via `:limits`.
+
+Uses `prompt_toolkit` for line editing, history, and tab completion.
+
+### 8.3 `opentools chain query explain '<query>'`
+
+Dry-run that shows the query plan without executing. 
Outputs plan steps, pushed-down predicates, and estimated scan sizes. + +## 9. Web Surface + +### 9.1 Phase 1: Standalone Query Page + +**Route:** `/chain/query` + +**Backend endpoint:** `POST /api/chain/query` + +Request: +```json +{ + "query": "MATCH (a:Finding)-[:LINKED]->(b:Finding) WHERE a.severity = 'critical' RETURN a, b", + "engagement_id": "eng-123", + "include_candidates": false, + "timeout": 30, + "max_rows": 1000 +} +``` + +Response: +```json +{ + "columns": ["a", "b"], + "rows": [{"a": {...}, "b": {...}}], + "subgraph": { + "nodes": [{"id": "...", "label": "Finding", "properties": {...}}], + "edges": [{"source": "...", "target": "...", "label": "LINKED", "properties": {...}}] + }, + "stats": {"duration_ms": 45, "bindings_explored": 312, "rows_returned": 8}, + "truncated": false +} +``` + +**Security:** Requires authentication. `user_id` from JWT propagated to executor. Queries cannot cross user boundaries. + +**Metadata endpoint:** `GET /api/chain/query/functions` — returns all available functions with names, help text, arg types. Powers editor autocomplete. + +**Frontend component:** `ChainQueryPage.vue` + +Layout — split pane: +- **Top:** CodeMirror 6 editor with Cypher syntax highlighting. Autocomplete for labels, property names, functions (fetched from metadata endpoint). Run via button or Ctrl+Enter. +- **Bottom left:** Sortable data grid showing result rows. Columns from RETURN clause. +- **Bottom right:** Mini force-graph preview rendering the subgraph projection. Uses `ForceGraphCanvas` from 3C.2. Clicking a node in the table highlights it in the graph and vice versa. +- **Engagement filter:** Dropdown at top. Pre-populated from context if navigated from an engagement page. + +### 9.2 Phase 2: Inline Overlay (Final 3C.4 Task) + +Collapsible query panel added to `ChainGraphView.vue` and `GlobalChainView.vue`. Query results highlight matching nodes/edges in the main graph (yellow glow). 
Reuses CodeMirror editor and tabular results from standalone page, embedded as overlay. Engagement scope auto-set from current page context. + +## 10. Safety & Security + +### 10.1 Read-Only Enforcement + +Two layers: +1. **Lexer-level:** Grammar does not define mutation verb tokens. Parse errors before AST. +2. **Executor-level:** Only rustworkx read methods called. No `ChainStoreProtocol` write methods invoked during execution. + +### 10.2 User Scoping + +`user_id` set once by `VirtualGraphBuilder` and propagated to every `ChainStoreProtocol` call. Web: `PostgresChainStore` enforces `@require_user_scope`. CLI: `user_id=None` (single-user). + +### 10.3 Input Sanitization + +Query string parsed by lark — rejects anything not matching grammar. No string interpolation, no SQL generation, no eval. Entity lookups go through `normalize()` + `entity_id_for()` content-addressing — no injection surface. + +### 10.4 Resource Limits + +| Limit | Default | Configurable via | +|---|---|---| +| Query timeout | 30s | `ChainConfig`, CLI `--timeout`, web request body | +| Max result rows | 1,000 | `ChainConfig`, CLI `--max-rows`, web request body | +| Intermediate binding cap | 10,000 | `ChainConfig`, CLI `:limits` in REPL | +| Variable-length max hops | 10 (hard cap) | Grammar-enforced, `*1..N` where N <= 10 | + +### 10.5 Plugin Function Sandboxing + +No sandbox in v1. Plugins receive read-only property dicts, not graph references. Timeout kills hung plugins. Documented as known limitation; true sandboxing deferred. + +## 11. Testing Strategy + +### 11.1 Unit Tests + +| Layer | Focus | Est. 
Cases | +|---|---|---| +| Parser | Grammar edge cases, valid/invalid queries, mutation rejection, var-length bounds | ~40-50 | +| Planner | Predicate pushdown, session references, var-length step generation | ~20 | +| VirtualGraphBuilder | Node/edge counts, labels, property access, MENTIONED_IN direction, cache LRU | ~15 | +| Executor | End-to-end per step kind, resource limit enforcement, collect() aggregation | ~30 | +| Plugin registry | Registration, collision rejection, resolution, invocation | ~10 | + +### 11.2 Integration Tests + +| Area | Focus | Est. Cases | +|---|---|---| +| CLI `query run` | Typer test runner, output formats, engagement scoping | ~10 | +| CLI REPL | Session variables, multi-line, special commands, prompt_toolkit mocking | ~10 | +| Web endpoint | POST /api/chain/query, auth, user scoping, response shape, 403 on unauthorized | ~10 | + +### 11.3 Conformance Tests + +Same query test suite runs against both `AsyncChainStore` (aiosqlite) and `PostgresChainStore` (SQLAlchemy async), following the backend parameterization pattern from 3C.1.5. Verifies virtual graph builds identically from both backends. + +**Total:** ~145-175 tests (~115-125 unit, ~30 integration, remainder conformance). All async, following existing `pytest-asyncio` patterns. + +## 12.
File Layout + +### 12.1 Core Module + +``` +packages/cli/src/opentools/chain/cypher/ +├── __init__.py # public API: parse_and_execute(), CypherSession +├── grammar.lark # lark EBNF grammar +├── parser.py # lark parser → typed AST +├── ast_nodes.py # AST dataclass definitions +├── planner.py # AST → QueryPlan with predicate pushdown +├── virtual_graph.py # VirtualGraphBuilder + VirtualGraphCache +├── executor.py # CypherExecutor +├── plugins.py # PluginFunctionRegistry + registration API +├── session.py # QuerySession — named result sets, REPL state +├── result.py # QueryResult, SubgraphProjection, QueryStats +├── limits.py # QueryLimits + QueryResourceError +├── builtins.py # length, nodes, relationships, has_entity, has_mitre, collect +└── errors.py # QueryParseError, QueryValidationError, QueryResourceError +``` + +### 12.2 CLI Additions + +- `packages/cli/src/opentools/chain/cli.py` — new `query` command group (`run`, `repl`, `explain`) + +### 12.3 Web Additions + +- `packages/web/backend/app/routes/chain_query.py` — `POST /api/chain/query`, `GET /api/chain/query/functions` +- `packages/web/frontend/src/pages/ChainQueryPage.vue` — standalone query page +- `packages/web/frontend/src/components/chain/CypherEditor.vue` — CodeMirror wrapper with Cypher mode +- `packages/web/frontend/src/components/chain/QueryResultsPane.vue` — tabular results + mini graph + +### 12.4 Inline Overlay (Final Task) + +- `packages/web/frontend/src/components/chain/InlineQueryPanel.vue` — collapsible overlay for `ChainGraphView.vue` and `GlobalChainView.vue` + +### 12.5 Tests + +- `packages/cli/tests/chain/cypher/` — `test_parser.py`, `test_planner.py`, `test_virtual_graph.py`, `test_executor.py`, `test_plugins.py`, `test_session.py`, `test_cli_query.py` +- `packages/web/backend/tests/chain/test_query_routes.py` diff --git a/packages/cli/src/opentools/chain/cli.py b/packages/cli/src/opentools/chain/cli.py index c5c643f..6c4129a 100644 --- a/packages/cli/src/opentools/chain/cli.py +++ 
b/packages/cli/src/opentools/chain/cli.py @@ -287,9 +287,15 @@ async def export( await chain_store.close() -@app.command() +# ─── query sub-app ────────────────────────────────────────────────── + +query_app = typer.Typer(help="Cypher query DSL and preset commands") +app.add_typer(query_app, name="query") + + +@query_app.command("preset") @_async_command -async def query( +async def query_preset( preset: str = typer.Argument(..., help="Preset name (lateral-movement, priv-esc-chains, external-to-internal, crown-jewel, mitre-coverage)"), engagement: str = typer.Option(..., "--engagement", help="Engagement id"), entity_ref: str | None = typer.Option(None, "--entity", help="Required for crown-jewel preset"), @@ -344,6 +350,187 @@ async def query( await chain_store.close() +@query_app.command("run") +@_async_command +async def query_run( + cypher: str = typer.Argument(..., help="Cypher query string"), + timeout: float = typer.Option(30.0, "--timeout", help="Query timeout in seconds"), + max_rows: int = typer.Option(1000, "--max-rows", help="Maximum result rows"), + engagement: str | None = typer.Option(None, "--engagement", help="Scope to engagement"), + include_candidates: bool = typer.Option(False, "--include-candidates", help="Include candidate edges"), + format_: str = typer.Option("table", "--format", help="Output format: table, json, csv"), + no_subgraph: bool = typer.Option(False, "--no-subgraph", help="Skip subgraph projection"), +) -> None: + """Execute a Cypher query.""" + import json + from opentools.chain.cypher import CypherSession + from opentools.chain.cypher.limits import QueryLimits + + _engagement_store, chain_store = await _get_stores() + try: + cfg = get_chain_config() + cache = GraphCache(store=chain_store, maxsize=cfg.query.graph_cache_size) + cypher_session = CypherSession(store=chain_store, graph_cache=cache, config=cfg) + + if engagement: + cypher_session.set_engagement_scope(frozenset([engagement])) + 
cypher_session.set_include_candidates(include_candidates) + cypher_session.limits = QueryLimits(timeout_seconds=timeout, max_rows=max_rows) + + result = await cypher_session.execute(cypher) + + if format_ == "json": + rprint(json.dumps( + {"columns": result.columns, "rows": result.rows, + "stats": {"duration_ms": result.stats.duration_ms, "rows_returned": result.stats.rows_returned}, + "truncated": result.truncated}, + indent=2, default=str, + )) + elif format_ == "csv": + if result.columns: + rprint(",".join(result.columns)) + for row in result.rows: + rprint(",".join(str(row.get(c, "")) for c in result.columns)) + else: + if not result.rows: + rprint("[yellow]no results[/yellow]") + return + table = Table() + for col in result.columns: + table.add_column(col) + for row in result.rows: + table.add_row(*[str(row.get(c, "")) for c in result.columns]) + console.print(table) + rprint(f"[dim]{result.stats.rows_returned} rows, {result.stats.duration_ms:.1f}ms[/dim]") + if result.truncated: + rprint(f"[yellow]truncated: {result.truncation_reason}[/yellow]") + finally: + await chain_store.close() + + +@query_app.command("explain") +@_async_command +async def query_explain( + cypher: str = typer.Argument(..., help="Cypher query string"), +) -> None: + """Show the query plan without executing.""" + from opentools.chain.cypher.limits import QueryLimits + from opentools.chain.cypher.parser import parse_cypher + from opentools.chain.cypher.planner import plan_query + + limits = QueryLimits() + ast = parse_cypher(cypher) + plan = plan_query(ast, limits) + + rprint("[bold]Query Plan[/bold]") + for i, step in enumerate(plan.steps, 1): + rprint(f" {i}. 
{step.kind}: {step.target_var} (label={step.label}, direction={step.direction})") + if step.predicates: + rprint(f" predicates: {len(step.predicates)} pushed down") + if step.min_hops is not None: + rprint(f" hops: {step.min_hops}..{step.max_hops}") + + +@query_app.command("repl") +@_async_command +async def query_repl( + engagement: str | None = typer.Option(None, "--engagement", help="Scope to engagement"), + include_candidates: bool = typer.Option(False, "--include-candidates"), +) -> None: + """Start an interactive Cypher query REPL.""" + from prompt_toolkit import PromptSession + from prompt_toolkit.history import InMemoryHistory + from opentools.chain.cypher import CypherSession + from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError + + _engagement_store, chain_store = await _get_stores() + try: + cfg = get_chain_config() + cache = GraphCache(store=chain_store, maxsize=cfg.query.graph_cache_size) + cypher_session = CypherSession(store=chain_store, graph_cache=cache, config=cfg) + + if engagement: + cypher_session.set_engagement_scope(frozenset([engagement])) + cypher_session.set_include_candidates(include_candidates) + + prompt_session = PromptSession(history=InMemoryHistory()) + rprint("[bold]OpenTools Cypher REPL[/bold] (type :help for help, :quit to exit)") + + while True: + try: + text = prompt_session.prompt("cypher> ") + except (EOFError, KeyboardInterrupt): + break + + text = text.strip() + if not text: + continue + + while text.endswith("-") or text.endswith("|") or text.count("(") > text.count(")"): + try: + continuation = prompt_session.prompt(" ...> ") + text += " " + continuation.strip() + except (EOFError, KeyboardInterrupt): + break + + if text.startswith(":"): + cmd = text[1:].strip().lower() + if cmd in ("quit", "exit"): + break + elif cmd == "help": + rprint("Cypher query DSL. MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE ... 
RETURN ...") + rprint("Labels: Finding, Host, IP, CVE, Domain, Port, MitreAttack, Entity") + rprint("Edges: LINKED, MENTIONED_IN") + elif cmd == "functions": + from opentools.chain.cypher.builtins import list_builtins + for name, info in list_builtins().items(): + rprint(f" {name}: {info.get('help', '')}") + for name, info in cypher_session.plugin_registry.list_all().items(): + rprint(f" {name}: {info.get('help', '')} [{info.get('kind', '')}]") + elif cmd == "clear": + cypher_session.session.clear() + rprint("[dim]session cleared[/dim]") + elif cmd.startswith("limits"): + rprint(f"timeout: {cypher_session.limits.timeout_seconds}s") + rprint(f"max_rows: {cypher_session.limits.max_rows}") + rprint(f"intermediate_cap: {cypher_session.limits.intermediate_binding_cap}") + else: + rprint(f"[red]unknown command: {text}[/red]") + continue + + if text in cypher_session.session.list_variables(): + stored = cypher_session.session.get(text) + if stored: + for row in stored.rows[:20]: + rprint(row) + if len(stored.rows) > 20: + rprint(f"[dim]... 
{len(stored.rows) - 20} more rows[/dim]") + continue + + try: + result = await cypher_session.execute(text) + if not result.rows: + rprint("[yellow]no results[/yellow]") + else: + table = Table() + for col in result.columns: + table.add_column(col) + for row in result.rows: + table.add_row(*[str(row.get(c, "")) for c in result.columns]) + console.print(table) + rprint(f"[dim]{result.stats.rows_returned} rows, {result.stats.duration_ms:.1f}ms[/dim]") + except (QueryParseError, QueryValidationError) as e: + rprint(f"[red]Parse error: {e}[/red]") + except QueryResourceError as e: + rprint(f"[red]Resource limit: {e}[/red]") + except Exception as e: + rprint(f"[red]Error: {e}[/red]") + + rprint("[dim]bye[/dim]") + finally: + await chain_store.close() + + @app.command() @_async_command async def calibrate( diff --git a/packages/cli/src/opentools/chain/config.py b/packages/cli/src/opentools/chain/config.py index 9899e92..d582709 100644 --- a/packages/cli/src/opentools/chain/config.py +++ b/packages/cli/src/opentools/chain/config.py @@ -137,6 +137,16 @@ class QueryConfig(BaseModel): graph_cache_size: int = 8 +class CypherConfig(BaseModel): + model_config = ConfigDict(frozen=True) + + timeout_seconds: float = 30.0 + max_rows: int = 1000 + intermediate_binding_cap: int = 10_000 + max_var_length_hops: int = 10 + virtual_graph_cache_size: int = 4 + + class ChainConfig(BaseModel): model_config = ConfigDict(frozen=True) @@ -146,6 +156,7 @@ class ChainConfig(BaseModel): linker: LinkerConfig = LinkerConfig() llm: LLMConfig = LLMConfig() query: QueryConfig = QueryConfig() + cypher: CypherConfig = CypherConfig() _config_singleton: ChainConfig | None = None diff --git a/packages/cli/src/opentools/chain/cypher/__init__.py b/packages/cli/src/opentools/chain/cypher/__init__.py new file mode 100644 index 0000000..c02c82f --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/__init__.py @@ -0,0 +1,117 @@ +"""Cypher-style query DSL for the attack chain knowledge graph.""" +from 
__future__ import annotations + +from typing import TYPE_CHECKING +from uuid import UUID + +from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError +from opentools.chain.cypher.executor import CypherExecutor +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.parser import parse_cypher +from opentools.chain.cypher.planner import plan_query +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import VirtualGraphCache + +if TYPE_CHECKING: + from opentools.chain.config import ChainConfig + from opentools.chain.query.graph_cache import GraphCache + from opentools.chain.store_protocol import ChainStoreProtocol + + +async def parse_and_execute( + query: str, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + vg_cache: VirtualGraphCache, + session: QuerySession | None = None, + plugin_registry: PluginFunctionRegistry | None = None, + user_id: UUID | None = None, + include_candidates: bool = False, + engagement_ids: frozenset[str] | None = None, + limits: QueryLimits | None = None, +) -> QueryResult: + """Parse, plan, and execute a Cypher query — main entry point.""" + if session is None: + session = QuerySession() + if plugin_registry is None: + plugin_registry = PluginFunctionRegistry() + if limits is None: + limits = QueryLimits() + + ast = parse_cypher(query) + plan = plan_query(ast, limits) + + vg = await vg_cache.get( + user_id=user_id, + include_candidates=include_candidates, + engagement_ids=engagement_ids, + ) + + executor = CypherExecutor( + virtual_graph=vg, + plan=plan, + session=session, + plugin_registry=plugin_registry, + limits=limits, + ) + result = await executor.execute() + + # Store in session if this was a session assignment + if ast.session_assignment: + session.store(ast.session_assignment, 
result) + + return result + + +class CypherSession: + """High-level session object for CLI REPL and web editor.""" + + def __init__( + self, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + config: "ChainConfig", + user_id: UUID | None = None, + ) -> None: + self.store = store + self.graph_cache = graph_cache + self.user_id = user_id + self.session = QuerySession() + self.plugin_registry = PluginFunctionRegistry() + self.limits = QueryLimits( + timeout_seconds=config.cypher.timeout_seconds, + max_rows=config.cypher.max_rows, + intermediate_binding_cap=config.cypher.intermediate_binding_cap, + max_var_length_hops=config.cypher.max_var_length_hops, + ) + self.vg_cache = VirtualGraphCache( + store=store, + graph_cache=graph_cache, + maxsize=config.cypher.virtual_graph_cache_size, + ) + self._engagement_ids: frozenset[str] | None = None + self._include_candidates = False + + def set_engagement_scope(self, engagement_ids: frozenset[str] | None) -> None: + self._engagement_ids = engagement_ids + + def set_include_candidates(self, include: bool) -> None: + self._include_candidates = include + + async def execute(self, query: str) -> QueryResult: + return await parse_and_execute( + query, + store=self.store, + graph_cache=self.graph_cache, + vg_cache=self.vg_cache, + session=self.session, + plugin_registry=self.plugin_registry, + user_id=self.user_id, + include_candidates=self._include_candidates, + engagement_ids=self._engagement_ids, + limits=self.limits, + ) diff --git a/packages/cli/src/opentools/chain/cypher/ast_nodes.py b/packages/cli/src/opentools/chain/cypher/ast_nodes.py new file mode 100644 index 0000000..09338aa --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/ast_nodes.py @@ -0,0 +1,93 @@ +"""Typed AST nodes for the Cypher-style query DSL.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + + +@dataclass +class VarLengthSpec: + min_hops: int + max_hops: int + + 
+@dataclass +class NodePattern: + variable: str | None + label: str | None + + +@dataclass +class EdgePattern: + variable: str | None + label: str | None + direction: Literal["out", "in"] + var_length: VarLengthSpec | None + + +@dataclass +class PropertyAccessExpr: + variable: str + property_name: str + + +@dataclass +class ComparisonExpr: + left: Any + operator: str + right: Any + + +@dataclass +class BooleanExpr: + operator: Literal["AND", "OR", "NOT"] + operands: list[Any] + + +@dataclass +class FunctionCallExpr: + name: str + args: list[Any] = field(default_factory=list) + + +@dataclass +class ReturnItem: + expression: Any + alias: str | None + + +@dataclass +class MatchClause: + patterns: list[tuple] + + +@dataclass +class WhereClause: + expression: Any + + +@dataclass +class ReturnClause: + items: list[ReturnItem] + + +@dataclass +class FromClause: + session_variable: str + + +@dataclass +class SessionAssignment: + variable_name: str + match_clause: MatchClause + where_clause: WhereClause | None + return_clause: ReturnClause + + +@dataclass +class CypherQuery: + match_clause: MatchClause + where_clause: WhereClause | None + return_clause: ReturnClause + from_clause: FromClause | None = None + session_assignment: str | None = None diff --git a/packages/cli/src/opentools/chain/cypher/builtins.py b/packages/cli/src/opentools/chain/cypher/builtins.py new file mode 100644 index 0000000..2e3fa7c --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/builtins.py @@ -0,0 +1,40 @@ +"""Built-in functions for the Cypher DSL.""" +from __future__ import annotations +from typing import Any, Callable + +def builtin_length(path: dict) -> int: + return len(path.get("edges", [])) + +def builtin_nodes(path: dict) -> list: + return path.get("nodes", []) + +def builtin_relationships(path: dict) -> list: + return path.get("edges", []) + +def builtin_has_entity(node: dict, entity_type: str, entity_value: str) -> bool: + for ent in node.get("entities", []): + if ent.get("type") 
== entity_type and ent.get("canonical_value") == entity_value: + return True + return False + +def builtin_has_mitre(node: dict, technique_id: str) -> bool: + return builtin_has_entity(node, "mitre_technique", technique_id) + +def builtin_collect(values: list) -> list: + return list(values) + +_BUILTINS: dict[str, dict] = { + "length": {"fn": builtin_length, "help": "Number of edges in a path", "arg_types": ["path"], "return_type": "int"}, + "nodes": {"fn": builtin_nodes, "help": "List of nodes in a path", "arg_types": ["path"], "return_type": "list"}, + "relationships": {"fn": builtin_relationships, "help": "List of edges in a path", "arg_types": ["path"], "return_type": "list"}, + "has_entity": {"fn": builtin_has_entity, "help": "Check if node mentions entity", "arg_types": ["node", "str", "str"], "return_type": "bool"}, + "has_mitre": {"fn": builtin_has_mitre, "help": "Check if node mentions MITRE technique", "arg_types": ["node", "str"], "return_type": "bool"}, + "collect": {"fn": builtin_collect, "help": "Aggregate values into a list", "arg_types": ["list"], "return_type": "list", "is_aggregation": True}, +} + +def get_builtin(name: str) -> Callable | None: + entry = _BUILTINS.get(name) + return entry["fn"] if entry else None + +def list_builtins() -> dict[str, dict]: + return {name: {k: v for k, v in info.items() if k != "fn"} for name, info in _BUILTINS.items()} diff --git a/packages/cli/src/opentools/chain/cypher/errors.py b/packages/cli/src/opentools/chain/cypher/errors.py new file mode 100644 index 0000000..2d00c95 --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/errors.py @@ -0,0 +1,26 @@ +"""Query DSL error hierarchy.""" +from __future__ import annotations + + +class QueryParseError(Exception): + def __init__(self, message: str, *, line: int | None = None, column: int | None = None) -> None: + self.line = line + self.column = column + loc = "" + if line is not None: + loc = f" (line {line}" + if column is not None: + loc += f", col {column}" 
+ loc += ")" + super().__init__(f"{message}{loc}") + + +class QueryValidationError(Exception): + pass + + +class QueryResourceError(Exception): + def __init__(self, message: str, *, limit_name: str, limit_value: int | float) -> None: + self.limit_name = limit_name + self.limit_value = limit_value + super().__init__(message) diff --git a/packages/cli/src/opentools/chain/cypher/executor.py b/packages/cli/src/opentools/chain/cypher/executor.py new file mode 100644 index 0000000..66b3d41 --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/executor.py @@ -0,0 +1,616 @@ +"""Cypher query executor: walks a QueryPlan against a VirtualGraph. + +Uses a binding-table approach where each step produces/filters a list of +bindings (dict[str, Any]). Resource limits (timeout, binding cap, row cap) +are checked at step boundaries. +""" +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, + ComparisonExpr, + FunctionCallExpr, + PropertyAccessExpr, + ReturnItem, +) +from opentools.chain.cypher.builtins import get_builtin +from opentools.chain.cypher.errors import QueryResourceError, QueryValidationError +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.planner import PlanStep, QueryPlan +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult, QueryStats, SubgraphProjection +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import EntityNode, MentionedInEdge, VirtualGraph +from opentools.chain.query.graph_cache import EdgeData, FindingNode + + +# ─── serialization helpers ─────────────────────────────────────────────────── + + +def _serialize_finding(node: FindingNode, idx: int) -> dict[str, Any]: + created = node.created_at + if isinstance(created, datetime): + created = 
created.isoformat() + return { + "id": node.finding_id, + "label": "Finding", + "severity": node.severity, + "tool": node.tool, + "title": node.title, + "created_at": created, + "_idx": idx, + } + + +def _serialize_entity(node: EntityNode, idx: int) -> dict[str, Any]: + return { + "id": node.entity_id, + "label": node.entity_type.capitalize() if node.entity_type else "Entity", + "canonical_value": node.canonical_value, + "entity_type": node.entity_type, + "mention_count": node.mention_count, + "_idx": idx, + } + + +def _serialize_node(node_data: Any, idx: int) -> dict[str, Any]: + if isinstance(node_data, FindingNode): + return _serialize_finding(node_data, idx) + if isinstance(node_data, EntityNode): + return _serialize_entity(node_data, idx) + return {"_idx": idx, "label": "Unknown"} + + +def _serialize_edge_data(edge: EdgeData) -> dict[str, Any]: + return { + "label": "LINKED", + "relation_id": edge.relation_id, + "weight": edge.weight, + "cost": edge.cost, + "status": edge.status, + "reasons": edge.reasons, + "llm_rationale": edge.llm_rationale, + "llm_relation_type": edge.llm_relation_type, + } + + +def _serialize_mentioned_in(edge: MentionedInEdge) -> dict[str, Any]: + return { + "label": "MENTIONED_IN", + "mention_id": edge.mention_id, + "field": edge.field, + "confidence": edge.confidence, + "extractor": edge.extractor, + } + + +def _serialize_edge(edge_data: Any) -> dict[str, Any]: + if isinstance(edge_data, EdgeData): + return _serialize_edge_data(edge_data) + if isinstance(edge_data, MentionedInEdge): + return _serialize_mentioned_in(edge_data) + return {"label": "UNKNOWN"} + + +def _edge_matches_label(edge_data: Any, label: str | None) -> bool: + """Check whether an edge payload matches the requested type label.""" + if label is None: + return True + if label == "LINKED": + return isinstance(edge_data, EdgeData) + if label == "MENTIONED_IN": + return isinstance(edge_data, MentionedInEdge) + return False + + +def _strip_internal_keys(d: dict[str, Any]) 
-> dict[str, Any]: + """Remove internal bookkeeping keys (e.g. _idx) from an output dict.""" + return {k: v for k, v in d.items() if not k.startswith("_idx")} + + +# ─── predicate evaluation ─────────────────────────────────────────────────── + + +def _resolve_expr(expr: Any, binding: dict[str, Any], plugin_registry: PluginFunctionRegistry) -> Any: + """Resolve an expression against a binding row.""" + if isinstance(expr, PropertyAccessExpr): + node_or_edge = binding.get(expr.variable) + if node_or_edge is None: + return None + if isinstance(node_or_edge, dict): + return node_or_edge.get(expr.property_name) + return getattr(node_or_edge, expr.property_name, None) + + if isinstance(expr, FunctionCallExpr): + resolved_args = [_resolve_expr(a, binding, plugin_registry) for a in expr.args] + fn = get_builtin(expr.name) + if fn is None: + fn = plugin_registry.get_function(expr.name) + if fn is None: + raise QueryValidationError(f"Unknown function: {expr.name}") + return fn(*resolved_args) + + if isinstance(expr, ComparisonExpr): + return _eval_comparison(expr, binding, plugin_registry) + + if isinstance(expr, BooleanExpr): + return _eval_boolean(expr, binding, plugin_registry) + + # Bare string: could be a variable reference or a literal. + # The parser emits var_ref as a plain str, and string_val as a plain str. + # We disambiguate by checking the binding table first. 
+ if isinstance(expr, str): + if expr in binding: + return binding[expr] + return expr + + # Other literal values (int, float, bool, None, list) + if isinstance(expr, (int, float, bool, type(None), list)): + return expr + + return expr + + +def _eval_comparison(expr: ComparisonExpr, binding: dict[str, Any], plugin_registry: PluginFunctionRegistry) -> bool: + left = _resolve_expr(expr.left, binding, plugin_registry) + right = _resolve_expr(expr.right, binding, plugin_registry) + op = expr.operator + + if op == "=": + return left == right + if op == "!=": + return left != right + if op == "<>": + return left != right + if op == "<": + return left is not None and right is not None and left < right + if op == ">": + return left is not None and right is not None and left > right + if op == "<=": + return left is not None and right is not None and left <= right + if op == ">=": + return left is not None and right is not None and left >= right + if op == "CONTAINS": + return left is not None and right is not None and right in left + if op == "STARTS WITH": + return left is not None and right is not None and str(left).startswith(str(right)) + if op == "ENDS WITH": + return left is not None and right is not None and str(left).endswith(str(right)) + if op == "IN": + return left is not None and right is not None and left in right + if op == "IS NULL": + return left is None + if op == "IS NOT NULL": + return left is not None + + raise QueryValidationError(f"Unknown operator: {op}") + + +def _eval_boolean(expr: BooleanExpr, binding: dict[str, Any], plugin_registry: PluginFunctionRegistry) -> bool: + if expr.operator == "AND": + return all(_resolve_expr(op, binding, plugin_registry) for op in expr.operands) + if expr.operator == "OR": + return any(_resolve_expr(op, binding, plugin_registry) for op in expr.operands) + if expr.operator == "NOT": + return not _resolve_expr(expr.operands[0], binding, plugin_registry) + raise QueryValidationError(f"Unknown boolean operator: 
{expr.operator}") + + +def _eval_predicates(predicates: list[Any], binding: dict[str, Any], plugin_registry: PluginFunctionRegistry) -> bool: + """Evaluate a list of predicate expressions (conjuncts) against a binding.""" + for pred in predicates: + result = _resolve_expr(pred, binding, plugin_registry) + if not result: + return False + return True + + +# ─── executor ──────────────────────────────────────────────────────────────── + +Binding = dict[str, Any] + + +class CypherExecutor: + """Execute a QueryPlan against a VirtualGraph, producing a QueryResult. + + Args: + virtual_graph: The heterogeneous graph to query. + plan: The query plan from the planner. + session: Query session for named result sets. + plugin_registry: Registry for plugin scalar/aggregation functions. + limits: Resource limits for the query. + """ + + def __init__( + self, + *, + virtual_graph: VirtualGraph, + plan: QueryPlan, + session: QuerySession, + plugin_registry: PluginFunctionRegistry, + limits: QueryLimits, + ) -> None: + self._vg = virtual_graph + self._plan = plan + self._session = session + self._plugins = plugin_registry + self._limits = limits + + async def execute(self) -> QueryResult: + """Execute the plan and return a QueryResult.""" + start_time = time.monotonic() + bindings: list[Binding] = [{}] # start with one empty binding + + for step in self._plan.steps: + # Timeout check + elapsed = time.monotonic() - start_time + if elapsed > self._limits.timeout_seconds: + raise QueryResourceError( + f"Query timed out after {elapsed:.1f}s (limit: {self._limits.timeout_seconds}s)", + limit_name="timeout_seconds", + limit_value=self._limits.timeout_seconds, + ) + + if step.kind == "scan": + bindings = self._exec_scan(step, bindings) + elif step.kind == "expand": + bindings = self._exec_expand(step, bindings) + elif step.kind == "var_length_expand": + bindings = self._exec_var_length_expand(step, bindings) + elif step.kind == "filter": + bindings = self._exec_filter(step, bindings) + + 
# Intermediate binding cap + if len(bindings) > self._limits.intermediate_binding_cap: + raise QueryResourceError( + f"Intermediate bindings ({len(bindings)}) exceed cap ({self._limits.intermediate_binding_cap})", + limit_name="intermediate_binding_cap", + limit_value=self._limits.intermediate_binding_cap, + ) + + # Project RETURN + elapsed_ms = (time.monotonic() - start_time) * 1000 + return self._project_return(bindings, elapsed_ms) + + # ── step executors ──────────────────────────────────────────────────── + + def _exec_scan(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Scan: iterate all nodes matching a label, create or validate bindings.""" + var = step.target_var + + # If the variable is already bound in all existing bindings, this is a + # no-op (the expand step already bound it). + if bindings and all(var in b for b in bindings): + # Still apply predicates if any + if step.predicates: + return [b for b in bindings if _eval_predicates(step.predicates, b, self._plugins)] + return bindings + + new_bindings: list[Binding] = [] + g = self._vg.graph + node_indices = g.node_indices() + + for b in bindings: + for idx in node_indices: + # Label filter + if step.label is not None: + node_label = self._vg.node_labels.get(idx) + if node_label != step.label: + continue + + node_data = g.get_node_data(idx) + serialized = _serialize_node(node_data, idx) + new_b = {**b, var: serialized, f"_idx_{var}": idx} + + # Apply pushed-down predicates + if step.predicates and not _eval_predicates(step.predicates, new_b, self._plugins): + continue + + new_bindings.append(new_b) + + return new_bindings + + def _exec_expand(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Expand: follow edges from the last bound node. + + The expand step binds BOTH the edge variable AND the target node variable. + The planner produces: scan(a) -> expand(r) -> scan(b). + We look ahead to find which node variable should be bound by this expand. 
+ """ + edge_var = step.target_var + g = self._vg.graph + + # Determine the source node variable (most recently bound node var) + # and the target node variable (next scan step's target_var). + source_var = self._find_source_var(bindings) + target_var = self._find_expand_target_var(step) + + new_bindings: list[Binding] = [] + + for b in bindings: + src_idx = b.get(f"_idx_{source_var}") + if src_idx is None: + continue + + edges = self._get_directed_edges(src_idx, step.direction) + + for (edge_src, edge_tgt, edge_data) in edges: + if not _edge_matches_label(edge_data, step.label): + continue + + # Determine the "other" node index (the target of the traversal) + other_idx = edge_tgt if edge_src == src_idx else edge_src + + # If target_var is already bound, check it matches + if target_var and target_var in b: + existing_idx = b.get(f"_idx_{target_var}") + if existing_idx != other_idx: + continue + + # Check label of target node if the next scan step has a label + target_label = self._get_next_scan_label(target_var) + if target_label is not None: + actual_label = self._vg.node_labels.get(other_idx) + if actual_label != target_label: + continue + + serialized_edge = _serialize_edge(edge_data) + other_node_data = g.get_node_data(other_idx) + serialized_target = _serialize_node(other_node_data, other_idx) + + new_b = {**b, edge_var: serialized_edge} + if target_var: + new_b[target_var] = serialized_target + new_b[f"_idx_{target_var}"] = other_idx + + # Apply pushed-down predicates + if step.predicates and not _eval_predicates(step.predicates, new_b, self._plugins): + continue + + new_bindings.append(new_b) + + return new_bindings + + def _exec_var_length_expand(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Variable-length expand: bounded DFS producing path bindings.""" + edge_var = step.target_var + g = self._vg.graph + source_var = self._find_source_var(bindings) + target_var = self._find_expand_target_var(step) + + min_hops = step.min_hops or 1 
+ max_hops = min(step.max_hops or self._limits.max_var_length_hops, self._limits.max_var_length_hops) + + new_bindings: list[Binding] = [] + + for b in bindings: + src_idx = b.get(f"_idx_{source_var}") + if src_idx is None: + continue + + # BFS/DFS to find all paths of length [min_hops, max_hops] + paths = self._bounded_dfs(src_idx, step.label, step.direction, min_hops, max_hops) + + for path_nodes, path_edges in paths: + end_idx = path_nodes[-1] + + # Check target label + target_label = self._get_next_scan_label(target_var) + if target_label is not None: + actual_label = self._vg.node_labels.get(end_idx) + if actual_label != target_label: + continue + + serialized_path = { + "nodes": [_serialize_node(g.get_node_data(ni), ni) for ni in path_nodes], + "edges": [_serialize_edge(e) for e in path_edges], + } + + new_b = {**b, edge_var: serialized_path} + if target_var: + end_node_data = g.get_node_data(end_idx) + new_b[target_var] = _serialize_node(end_node_data, end_idx) + new_b[f"_idx_{target_var}"] = end_idx + + if step.predicates and not _eval_predicates(step.predicates, new_b, self._plugins): + continue + + new_bindings.append(new_b) + + return new_bindings + + def _exec_filter(self, step: PlanStep, bindings: list[Binding]) -> list[Binding]: + """Filter: apply remaining predicates.""" + return [b for b in bindings if _eval_predicates(step.predicates, b, self._plugins)] + + # ── helpers ─────────────────────────────────────────────────────────── + + def _find_source_var(self, bindings: list[Binding]) -> str: + """Find the most recently bound node variable (has _idx_ prefix).""" + if not bindings or not bindings[0]: + return "" + # Get the last node variable that was bound (has _idx_ key) + b = bindings[0] + node_vars = [k[5:] for k in b if k.startswith("_idx_")] + return node_vars[-1] if node_vars else "" + + def _find_expand_target_var(self, step: PlanStep) -> str | None: + """Find the target node variable for an expand step. 
        Look at the plan steps: the step after this expand should be a scan
        for the target node. Return that scan's target_var.
        """
        steps = self._plan.steps
        step_idx = None
        # Locate *this* step by identity (is), not structural equality.
        for i, s in enumerate(steps):
            if s is step:
                step_idx = i
                break
        if step_idx is None:
            return None

        # Look for the next scan step after this expand
        for i in range(step_idx + 1, len(steps)):
            if steps[i].kind == "scan":
                return steps[i].target_var
            if steps[i].kind in ("expand", "var_length_expand"):
                # Another expand before a scan — the intermediate node
                break
        return None

    def _get_next_scan_label(self, target_var: str | None) -> str | None:
        """Get the label from the next scan step for a given variable."""
        if target_var is None:
            return None
        for s in self._plan.steps:
            if s.kind == "scan" and s.target_var == target_var:
                return s.label
        return None

    def _get_directed_edges(self, src_idx: int, direction: str | None) -> list[tuple[int, int, Any]]:
        """Get edges from/to a node based on direction.

        Returns (src_idx, tgt_idx, payload) triples.

        NOTE(review): direction=None falls only into the outgoing branch
        (the incoming branch checks ("in", "both")) — confirm None is meant
        to behave like "out" rather than "both".
        """
        g = self._vg.graph
        results: list[tuple[int, int, Any]] = []

        if direction in ("out", None, "both"):
            # Outgoing edges
            try:
                out_edges = g.out_edges(src_idx)
                for src, tgt, data in out_edges:
                    results.append((src, tgt, data))
            except Exception:
                # Best-effort: treat an invalid node index as "no edges".
                pass

        if direction in ("in", "both"):
            # Incoming edges
            try:
                in_edges = g.in_edges(src_idx)
                for src, tgt, data in in_edges:
                    results.append((src, tgt, data))
            except Exception:
                pass

        return results

    def _bounded_dfs(
        self,
        start_idx: int,
        label: str | None,
        direction: str | None,
        min_hops: int,
        max_hops: int,
    ) -> list[tuple[list[int], list[Any]]]:
        """Bounded DFS returning all paths of length [min_hops, max_hops].

        Returns list of (node_indices, edge_payloads) tuples.
        """
        results: list[tuple[list[int], list[Any]]] = []
        # Each stack entry is a complete independent path-so-far.
        stack: list[tuple[list[int], list[Any]]] = [([start_idx], [])]

        while stack:
            path_nodes, path_edges = stack.pop()
            current = path_nodes[-1]
            depth = len(path_edges)  # hop count so far

            if depth >= min_hops:
                results.append((path_nodes, path_edges))

            if depth >= max_hops:
                continue

            edges = self._get_directed_edges(current, direction)
            for (edge_src, edge_tgt, edge_data) in edges:
                if not _edge_matches_label(edge_data, label):
                    continue
                next_idx = edge_tgt if edge_src == current else edge_src
                if next_idx in path_nodes:
                    continue  # avoid cycles
                # Copy-on-extend keeps each stack entry independent.
                stack.append((path_nodes + [next_idx], path_edges + [edge_data]))

        return results

    # ── RETURN projection ─────────────────────────────────────────────────

    def _project_return(self, bindings: list[Binding], elapsed_ms: float) -> QueryResult:
        """Project bindings through the RETURN clause to produce final output."""
        return_items = self._plan.return_spec.items
        columns = self._compute_columns(return_items)
        subgraph = SubgraphProjection()

        rows: list[dict[str, Any]] = []
        truncated = False
        truncation_reason: str | None = None

        for b in bindings:
            if len(rows) >= self._limits.max_rows:
                truncated = True
                truncation_reason = f"Row limit ({self._limits.max_rows}) reached"
                break

            row: dict[str, Any] = {}
            for item, col_name in zip(return_items, columns):
                value = self._resolve_return_item(item, b)
                # Strip internal keys from dict values
                if isinstance(value, dict):
                    # Track subgraph nodes/edges
                    idx = value.get("_idx")
                    if idx is not None:
                        subgraph.node_indices.add(idx)
                    value = _strip_internal_keys(value)
                row[col_name] = value

            # Track subgraph edges from bindings
            self._track_subgraph_edges(b, subgraph)
            rows.append(row)

        # Only return subgraph if there are edges tracked
        result_subgraph = subgraph if subgraph.node_indices else None

        return QueryResult(
            columns=columns,
            rows=rows,
            subgraph=result_subgraph,
            stats=QueryStats(
                duration_ms=elapsed_ms,
                bindings_explored=len(bindings),
                rows_returned=len(rows),
            ),
            truncated=truncated,
            truncation_reason=truncation_reason,
        )

    def _compute_columns(self, return_items: list[ReturnItem]) -> list[str]:
        """Compute column names from ReturnItems.

        Alias wins; otherwise derive from the expression shape.
        """
        columns: list[str] = []
        for item in return_items:
            if item.alias:
                columns.append(item.alias)
            elif isinstance(item.expression, PropertyAccessExpr):
                columns.append(f"{item.expression.variable}.{item.expression.property_name}")
            elif isinstance(item.expression, FunctionCallExpr):
                columns.append(item.expression.name)
            elif isinstance(item.expression, str):
                columns.append(item.expression)
            else:
                columns.append(str(item.expression))
        return columns

    def _resolve_return_item(self, item: ReturnItem, binding: Binding) -> Any:
        """Resolve a single ReturnItem against a binding."""
        return _resolve_expr(item.expression, binding, self._plugins)

    def _track_subgraph_edges(self, binding: Binding, subgraph: SubgraphProjection) -> None:
        """Track edge tuples in the subgraph projection from binding _idx_ keys.

        NOTE(review): keys are sorted alphabetically by variable name, so
        "consecutive pair" edges assume variable names sort in traversal
        order — confirm, otherwise spurious edge tuples may be recorded.
        """
        idx_keys = sorted([k for k in binding if k.startswith("_idx_")])
        node_indices = [binding[k] for k in idx_keys]
        for idx in node_indices:
            subgraph.node_indices.add(idx)
        # If there are at least 2 node indices, track the edge between consecutive pairs
        if len(node_indices) >= 2:
            for i in range(len(node_indices) - 1):
                subgraph.edge_tuples.add((node_indices[i], node_indices[i + 1]))
diff --git a/packages/cli/src/opentools/chain/cypher/grammar.lark b/packages/cli/src/opentools/chain/cypher/grammar.lark
new file mode 100644
index 0000000..bcf8900
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/grammar.lark
@@ -0,0 +1,119 @@
// Cypher-style query DSL grammar (LALR mode)

start: session_assignment | query

session_assignment: IDENT "=" query

query: match_clause where_clause?
return_clause

match_clause: KW_MATCH pattern_list

pattern_list: pattern ("," pattern)*

pattern: node_pattern (edge_pattern node_pattern)*

// Node patterns
node_pattern: "(" variable_def? label_def? ")" from_clause?

from_clause: KW_FROM IDENT

variable_def: IDENT

label_def: ":" NODE_LABEL

// Edge patterns
edge_pattern: outgoing_edge | incoming_edge

outgoing_edge: "-" edge_detail "->"
incoming_edge: "<-" edge_detail "-"

edge_detail: "[" edge_var? edge_label? var_length? "]"

edge_var: IDENT

edge_label: ":" EDGE_LABEL

var_length: "*" INT ".." INT

// WHERE clause
where_clause: KW_WHERE bool_expr

?bool_expr: or_expr

// BUG FIX: this rule previously read "and_expr (KW_AND and_expr)*", which
// duplicated and_expr and left the declared KW_OR terminal unused, so the
// OR keyword could never parse.  OR now combines AND-groups, giving OR the
// lower precedence as in standard Cypher.
?or_expr: and_expr (KW_OR and_expr)*

?and_expr: not_expr (KW_AND not_expr)* -> and_expr

?not_expr: KW_NOT not_expr -> not_expr
         | comparison

?comparison: prop_access KW_CONTAINS expr -> contains_expr
           | prop_access KW_STARTS_WITH expr -> starts_with_expr
           | prop_access KW_ENDS_WITH expr -> ends_with_expr
           | prop_access KW_IN expr -> in_expr
           | prop_access KW_IS_NOT_NULL -> is_not_null_expr
           | prop_access KW_IS_NULL -> is_null_expr
           | expr comp_op expr -> comparison_expr
           | function_call -> where_func
           | "(" bool_expr ")"

!comp_op: "=" | "<>" | "<=" | ">=" | "<" | ">"

// RETURN clause
return_clause: KW_RETURN return_list

return_list: return_item ("," return_item)*

return_item: expr alias?

alias: KW_AS IDENT

// Expressions
?expr: function_call
     | prop_access
     | atom

prop_access: IDENT "." IDENT

function_call: IDENT "(" arg_list?
")"

arg_list: expr ("," expr)*

?atom: STRING -> string_val
     | FLOAT -> float_val
     | INT -> int_val
     | IDENT -> var_ref

// -- Terminals --

// Case-insensitive keywords (high priority to beat IDENT)
KW_MATCH.10: /MATCH/i
KW_WHERE.10: /WHERE/i
KW_RETURN.10: /RETURN/i
KW_AND.10: /AND/i
KW_OR.10: /OR/i
KW_NOT.10: /NOT/i
KW_AS.10: /AS/i
KW_FROM.10: /FROM/i
KW_CONTAINS.10: /CONTAINS/i
KW_STARTS_WITH.10: /STARTS\s+WITH/i
KW_ENDS_WITH.10: /ENDS\s+WITH/i
KW_IS_NOT_NULL.10: /IS\s+NOT\s+NULL/i
KW_IS_NULL.10: /IS\s+NULL/i
KW_IN.10: /IN/i

// Node and edge labels (high priority)
NODE_LABEL.5: "Finding" | "Host" | "IP" | "CVE" | "Domain" | "Port" | "MitreAttack" | "Entity"
EDGE_LABEL.5: "LINKED" | "MENTIONED_IN"

// Literals
// NOTE(review): STRING has no escape handling — a literal containing an
// escaped quote (\") terminates early.  Confirm that is acceptable.
STRING: /\"[^\"]*\"/
FLOAT.2: /\d+\.\d+/
INT: /\d+/

// Identifiers (lowest priority)
IDENT: /[a-zA-Z_][a-zA-Z0-9_]*/

// Whitespace
%import common.WS
%ignore WS
diff --git a/packages/cli/src/opentools/chain/cypher/limits.py b/packages/cli/src/opentools/chain/cypher/limits.py
new file mode 100644
index 0000000..e9d4a9d
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/limits.py
@@ -0,0 +1,13 @@
"""Resource limits for query execution."""
from __future__ import annotations

from pydantic import BaseModel, ConfigDict


class QueryLimits(BaseModel):
    """Immutable bundle of execution limits shared by planner and executor."""

    # frozen=True: instances are hashable and safe to share across queries.
    model_config = ConfigDict(frozen=True)

    timeout_seconds: float = 30.0  # wall-clock budget for one query
    max_rows: int = 1000  # cap on rows returned to the caller
    intermediate_binding_cap: int = 10_000  # cap on in-flight bindings
    max_var_length_hops: int = 10  # hard ceiling for *min..max traversals
diff --git a/packages/cli/src/opentools/chain/cypher/parser.py b/packages/cli/src/opentools/chain/cypher/parser.py
new file mode 100644
index 0000000..7e00e37
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/parser.py
@@ -0,0 +1,353 @@
"""Lark-based parser and transformer for the Cypher-style query DSL."""
from __future__ import annotations

import re
from pathlib import Path
from typing import Any

from lark import Lark, Token, Transformer, Tree
from lark.exceptions
import LarkError

from .ast_nodes import (
    BooleanExpr,
    ComparisonExpr,
    CypherQuery,
    EdgePattern,
    FromClause,
    FunctionCallExpr,
    MatchClause,
    NodePattern,
    PropertyAccessExpr,
    ReturnClause,
    ReturnItem,
    VarLengthSpec,
    WhereClause,
)
from .errors import QueryParseError

# Mutation verbs are rejected up front — the DSL is read-only.
# NOTE(review): this regex also matches inside string literals
# (e.g. WHERE n.name = "SET"), producing a false rejection — confirm
# that trade-off is acceptable.
_MUTATION_VERBS = re.compile(
    r"\b(CREATE|DELETE|SET|MERGE|REMOVE|DETACH|DROP)\b", re.IGNORECASE
)

_MAX_VAR_LENGTH_HOPS = 10

_GRAMMAR_PATH = Path(__file__).parent / "grammar.lark"

# Module-level parser cache (built on first use).
_parser: Lark | None = None


def _get_parser() -> Lark:
    """Lazily load and cache the Lark parser.

    Not thread-safe: two threads may both build a parser; the worst case
    is one redundant build, since the assignment itself is atomic.
    """
    global _parser
    if _parser is None:
        _parser = Lark(
            _GRAMMAR_PATH.read_text(encoding="utf-8"),
            parser="lalr",
            start="start",
        )
    return _parser


# Helper sets for disambiguation
_NODE_LABELS = frozenset({"Finding", "Host", "IP", "CVE", "Domain", "Port", "MitreAttack", "Entity"})
_EDGE_LABELS = frozenset({"LINKED", "MENTIONED_IN"})

# Terminal types that are keyword tokens (should be filtered)
_KW_TYPES = frozenset({
    "KW_MATCH", "KW_WHERE", "KW_RETURN", "KW_AND", "KW_OR", "KW_NOT",
    "KW_AS", "KW_FROM", "KW_CONTAINS", "KW_STARTS_WITH", "KW_ENDS_WITH",
    "KW_IS_NOT_NULL", "KW_IS_NULL", "KW_IN",
})


def _filter_kw(items: list) -> list:
    """Filter out keyword tokens from transformer items."""
    return [i for i in items if not (isinstance(i, Token) and i.type in _KW_TYPES)]


class CypherTransformer(Transformer):
    """Transform Lark parse tree into typed AST nodes."""

    # -- Literals --

    def string_val(self, items: list) -> str:
        # Strip the surrounding double quotes; no escape processing.
        return str(items[0])[1:-1]

    def float_val(self, items: list) -> float:
        return float(items[0])

    def int_val(self, items: list) -> int:
        return int(items[0])

    def var_ref(self, items: list) -> str:
        # Bare identifiers become plain strings in the AST.
        return str(items[0])

    # -- Property access --

    def prop_access(self, items: list) -> PropertyAccessExpr:
        idents = [i for i in items if isinstance(i, Token) and i.type == "IDENT"]
        return PropertyAccessExpr(variable=str(idents[0]), property_name=str(idents[1]))

    # -- Function calls --

    def function_call(self, items: list) -> FunctionCallExpr:
        name = str(items[0])
        args_list: list = []
        # arg_list (if present) has already been transformed to a python list.
        for item in items[1:]:
            if isinstance(item, list):
                args_list = item
        return FunctionCallExpr(name=name, args=args_list)

    def arg_list(self, items: list) -> list:
        return list(items)

    # -- Comparison / boolean expressions --

    def comp_op(self, items: list) -> str:
        return str(items[0])

    def comparison_expr(self, items: list) -> ComparisonExpr:
        # items: [left_expr, comp_op_str, right_expr]
        left, op, right = items
        return ComparisonExpr(left=left, operator=op, right=right)

    def contains_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        return ComparisonExpr(left=filtered[0], operator="CONTAINS", right=filtered[1])

    def starts_with_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        return ComparisonExpr(left=filtered[0], operator="STARTS WITH", right=filtered[1])

    def ends_with_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        return ComparisonExpr(left=filtered[0], operator="ENDS WITH", right=filtered[1])

    def in_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        return ComparisonExpr(left=filtered[0], operator="IN", right=filtered[1])

    def is_null_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        # Unary predicate: right operand unused.
        return ComparisonExpr(left=filtered[0], operator="IS NULL", right=None)

    def is_not_null_expr(self, items: list) -> ComparisonExpr:
        filtered = _filter_kw(items)
        return ComparisonExpr(left=filtered[0], operator="IS NOT NULL", right=None)

    def where_func(self, items: list) -> Any:
        # A bare function call used as a predicate passes through unchanged.
        return items[0]

    def not_expr(self, items: list) -> BooleanExpr:
        filtered = _filter_kw(items)
        return BooleanExpr(operator="NOT", operands=[filtered[0]])

    def or_expr(self, items: list) -> BooleanExpr | Any:
        filtered = _filter_kw(items)
        # Single operand: collapse — no BooleanExpr wrapper needed.
        if len(filtered) == 1:
            return filtered[0]
        return BooleanExpr(operator="OR", operands=filtered)

    def and_expr(self, items: list) -> BooleanExpr | Any:
        filtered = _filter_kw(items)
        if len(filtered) == 1:
            return filtered[0]
        return BooleanExpr(operator="AND", operands=filtered)

    # -- Pattern elements --

    def variable_def(self, items: list) -> str:
        return str(items[0])

    def label_def(self, items: list) -> str:
        return str(items[0])

    def edge_var(self, items: list) -> str:
        return str(items[0])

    def edge_label(self, items: list) -> str:
        return str(items[0])

    def var_length(self, items: list) -> VarLengthSpec:
        ints = [i for i in items if isinstance(i, Token) and i.type == "INT"]
        min_hops = int(ints[0])
        max_hops = int(ints[1])
        # Enforce the hop ceiling at parse time, before planning/execution.
        if max_hops > _MAX_VAR_LENGTH_HOPS:
            raise QueryParseError(
                f"Variable-length max hops {max_hops} exceeds max of {_MAX_VAR_LENGTH_HOPS}"
            )
        return VarLengthSpec(min_hops=min_hops, max_hops=max_hops)

    def edge_detail(self, items: list) -> dict:
        # Variable vs label is decided purely by membership in _EDGE_LABELS.
        # NOTE(review): an edge variable literally named "LINKED" or
        # "MENTIONED_IN" would be misread as a label — confirm acceptable.
        variable = None
        label = None
        vl = None
        for item in items:
            if isinstance(item, VarLengthSpec):
                vl = item
            elif isinstance(item, str):
                if item in _EDGE_LABELS:
                    label = item
                else:
                    variable = item
        return {"variable": variable, "label": label, "var_length": vl}

    def outgoing_edge(self, items: list) -> EdgePattern:
        detail = items[0]
        return EdgePattern(
            variable=detail["variable"],
            label=detail["label"],
            direction="out",
            var_length=detail["var_length"],
        )

    def incoming_edge(self, items: list) -> EdgePattern:
        detail = items[0]
        return EdgePattern(
            variable=detail["variable"],
            label=detail["label"],
            direction="in",
            var_length=detail["var_length"],
        )

    def edge_pattern(self, items: list) -> EdgePattern:
        return items[0]

    def from_clause(self, items: list) -> FromClause:
        filtered = _filter_kw(items)
        return FromClause(session_variable=str(filtered[0]))

    def node_pattern(self, items: list) -> tuple:
        """Return (NodePattern, optional FromClause)."""
        # Same set-membership disambiguation as edge_detail, against
        # _NODE_LABELS this time.
        variable = None
        label = None
        fc = None
        for item in items:
            if isinstance(item, FromClause):
                fc = item
            elif isinstance(item, str):
                if item in _NODE_LABELS:
                    label = item
                else:
                    variable = item
        return (NodePattern(variable=variable, label=label), fc)

    def pattern(self, items: list) -> tuple:
        """Build a pattern tuple from alternating nodes and edges.

        Returns (pattern_elements_tuple, optional_from_clause).
        """
        elements = []
        from_clause = None
        for item in items:
            if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], NodePattern):
                node, fc = item
                elements.append(node)
                if fc is not None:
                    from_clause = fc
            elif isinstance(item, EdgePattern):
                elements.append(item)
        return (tuple(elements), from_clause)

    def pattern_list(self, items: list) -> list:
        return list(items)

    def match_clause(self, items: list) -> tuple:
        """Return (MatchClause, optional FromClause)."""
        filtered = _filter_kw(items)
        pattern_items = filtered[0]  # list of (pattern_tuple, from_clause) tuples
        patterns = []
        from_clause = None
        for pattern_tuple, fc in pattern_items:
            patterns.append(pattern_tuple)
            if fc is not None:
                # Last FROM wins when several patterns carry one.
                from_clause = fc
        return (MatchClause(patterns=patterns), from_clause)

    def where_clause(self, items: list) -> WhereClause:
        filtered = _filter_kw(items)
        return WhereClause(expression=filtered[0])

    # -- Return clause --

    def return_item(self, items: list) -> ReturnItem:
        expression = items[0]
        alias = items[1] if len(items) > 1 else None
        return ReturnItem(expression=expression, alias=alias)

    def alias(self, items: list) -> str:
        filtered = _filter_kw(items)
        return str(filtered[0])

    def return_list(self, items: list) -> list:
        return list(items)

    def return_clause(self, items: list) -> ReturnClause:
        filtered = _filter_kw(items)
        return ReturnClause(items=filtered[0])

    # --
Top-level -- + + def query(self, items: list) -> CypherQuery: + match_result = items[0] + match_clause, from_clause = match_result + + where_clause = None + return_clause = None + for item in items[1:]: + if isinstance(item, WhereClause): + where_clause = item + elif isinstance(item, ReturnClause): + return_clause = item + + return CypherQuery( + match_clause=match_clause, + where_clause=where_clause, + return_clause=return_clause, + from_clause=from_clause, + ) + + def session_assignment(self, items: list) -> CypherQuery: + var_name = str(items[0]) + query = items[1] + query.session_assignment = var_name + return query + + def start(self, items: list) -> CypherQuery: + return items[0] + + +def parse_cypher(query: str) -> CypherQuery: + """Parse a Cypher-style query string into a CypherQuery AST. + + Args: + query: The query string to parse. + + Returns: + A CypherQuery AST node. + + Raises: + QueryParseError: If the query is empty, contains mutation verbs, + or has syntax errors. + """ + if not query or not query.strip(): + raise QueryParseError("Empty query") + + # Reject mutation verbs before parsing + if _MUTATION_VERBS.search(query): + raise QueryParseError( + f"Mutation operations are not allowed: {_MUTATION_VERBS.search(query).group(1)}" + ) + + try: + parser = _get_parser() + tree = parser.parse(query) + transformer = CypherTransformer() + result = transformer.transform(tree) + return result + except QueryParseError: + raise + except LarkError as exc: + raise QueryParseError(f"Syntax error: {exc}") from exc + except Exception as exc: + raise QueryParseError(f"Parse error: {exc}") from exc diff --git a/packages/cli/src/opentools/chain/cypher/planner.py b/packages/cli/src/opentools/chain/cypher/planner.py new file mode 100644 index 0000000..f1adb8a --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/planner.py @@ -0,0 +1,121 @@ +"""Query planner: AST → QueryPlan with predicate pushdown.""" +from __future__ import annotations +from dataclasses import 
dataclass, field +from typing import Any, Literal +from opentools.chain.cypher.ast_nodes import ( + BooleanExpr, ComparisonExpr, CypherQuery, EdgePattern, + FunctionCallExpr, NodePattern, PropertyAccessExpr, +) +from opentools.chain.cypher.limits import QueryLimits + +@dataclass +class PlanStep: + kind: Literal["scan", "expand", "filter", "var_length_expand"] + target_var: str + label: str | None = None + direction: Literal["out", "in", "both"] | None = None + min_hops: int | None = None + max_hops: int | None = None + predicates: list[Any] = field(default_factory=list) + from_session: str | None = None + +@dataclass +class ReturnSpec: + items: list[Any] + +@dataclass +class QueryPlan: + steps: list[PlanStep] + return_spec: ReturnSpec + limits: QueryLimits + +def _extract_variables(expr: Any) -> set[str]: + """Extract variable names referenced in an expression.""" + if isinstance(expr, PropertyAccessExpr): + return {expr.variable} + if isinstance(expr, ComparisonExpr): + return _extract_variables(expr.left) | ( + _extract_variables(expr.right) if not isinstance(expr.right, (str, int, float, bool, type(None), list)) else set() + ) + if isinstance(expr, BooleanExpr): + result: set[str] = set() + for op in expr.operands: + result |= _extract_variables(op) + return result + if isinstance(expr, FunctionCallExpr): + result = set() + for arg in expr.args: + if isinstance(arg, str): + result.add(arg) + else: + result |= _extract_variables(arg) + return result + if isinstance(expr, str): + return {expr} + return set() + +def _flatten_and(expr: Any) -> list[Any]: + """Flatten AND expressions into a list of conjuncts.""" + if isinstance(expr, BooleanExpr) and expr.operator == "AND": + result = [] + for op in expr.operands: + result.extend(_flatten_and(op)) + return result + return [expr] + +def plan_query(query: CypherQuery, limits: QueryLimits) -> QueryPlan: + steps: list[PlanStep] = [] + pending_predicates: list[Any] = [] + if query.where_clause is not None: + 
pending_predicates = _flatten_and(query.where_clause.expression) + bound_vars: set[str] = set() + + for pattern_tuple in query.match_clause.patterns: + elements = list(pattern_tuple) + for i, element in enumerate(elements): + if isinstance(element, NodePattern): + var = element.variable + if var and var not in bound_vars: + from_session = None + if query.from_clause and i == 0: + from_session = query.from_clause.session_variable + step = PlanStep(kind="scan", target_var=var, label=element.label, from_session=from_session) + bound_vars.add(var) + remaining = [] + for pred in pending_predicates: + pred_vars = _extract_variables(pred) + if pred_vars <= bound_vars: + step.predicates.append(pred) + else: + remaining.append(pred) + pending_predicates = remaining + steps.append(step) + elif isinstance(element, EdgePattern): + var = element.variable + if element.var_length is not None: + step = PlanStep( + kind="var_length_expand", target_var=var or f"_anon_edge_{i}", + label=element.label, direction=element.direction, + min_hops=element.var_length.min_hops, max_hops=element.var_length.max_hops, + ) + else: + step = PlanStep( + kind="expand", target_var=var or f"_anon_edge_{i}", + label=element.label, direction=element.direction, + ) + if var: + bound_vars.add(var) + remaining = [] + for pred in pending_predicates: + pred_vars = _extract_variables(pred) + if pred_vars <= bound_vars: + step.predicates.append(pred) + else: + remaining.append(pred) + pending_predicates = remaining + steps.append(step) + + if pending_predicates: + steps.append(PlanStep(kind="filter", target_var="_post_filter", predicates=pending_predicates)) + + return QueryPlan(steps=steps, return_spec=ReturnSpec(items=query.return_clause.items), limits=limits) diff --git a/packages/cli/src/opentools/chain/cypher/plugins.py b/packages/cli/src/opentools/chain/cypher/plugins.py new file mode 100644 index 0000000..9c599d1 --- /dev/null +++ b/packages/cli/src/opentools/chain/cypher/plugins.py @@ -0,0 +1,38 @@ 
"""Plugin function registry for the Cypher DSL."""
from __future__ import annotations
from typing import Any, Callable

class PluginFunctionRegistry:
    """Registry of scalar and aggregation functions exposed to queries.

    Names must be dotted ("plugin.func") and are unique across both kinds.
    """

    def __init__(self) -> None:
        # name -> {"fn", "help", "arg_types", "return_type"}
        self._scalars: dict[str, dict] = {}
        # name -> {"fn", "help", "input_type", "return_type"}
        self._aggregations: dict[str, dict] = {}

    # NOTE(review): the keyword-only parameter `help` shadows the builtin;
    # renaming it would break keyword callers, so it is kept as-is.
    def register_function(self, name: str, fn: Callable, *, help: str = "", arg_types: list[str], return_type: str) -> None:
        """Register a scalar function; raises ValueError on bad/duplicate name."""
        if "." not in name:
            raise ValueError(f"plugin function names must be dotted (e.g., 'plugin.func'), got: {name!r}")
        if name in self._scalars or name in self._aggregations:
            raise ValueError(f"function {name!r} already registered")
        self._scalars[name] = {"fn": fn, "help": help, "arg_types": arg_types, "return_type": return_type}

    def register_aggregation(self, name: str, fn: Callable, *, help: str = "", input_type: str, return_type: str) -> None:
        """Register an aggregation; raises ValueError on bad/duplicate name."""
        if "." not in name:
            raise ValueError(f"plugin aggregation names must be dotted (e.g., 'plugin.agg'), got: {name!r}")
        if name in self._scalars or name in self._aggregations:
            raise ValueError(f"function {name!r} already registered")
        self._aggregations[name] = {"fn": fn, "help": help, "input_type": input_type, "return_type": return_type}

    def get_function(self, name: str) -> Callable | None:
        """Return the scalar callable for *name*, or None if unknown."""
        entry = self._scalars.get(name)
        return entry["fn"] if entry else None

    def get_aggregation(self, name: str) -> Callable | None:
        """Return the aggregation callable for *name*, or None if unknown."""
        entry = self._aggregations.get(name)
        return entry["fn"] if entry else None

    def list_all(self) -> dict[str, dict]:
        """Return metadata for every registered function, keyed by name."""
        result: dict[str, dict] = {}
        for name, info in self._scalars.items():
            result[name] = {"kind": "scalar", "help": info["help"], "arg_types": info["arg_types"], "return_type": info["return_type"]}
        for name, info in self._aggregations.items():
            result[name] = {"kind": "aggregation", "help": info["help"], "input_type": info["input_type"], "return_type": info["return_type"]}
        return result
diff --git a/packages/cli/src/opentools/chain/cypher/result.py b/packages/cli/src/opentools/chain/cypher/result.py
new file mode 100644
index 0000000..d2149af
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/result.py
@@ -0,0 +1,28 @@
"""Query result types."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class QueryStats:
    """Execution metrics attached to every QueryResult."""

    duration_ms: float = 0.0
    bindings_explored: int = 0
    rows_returned: int = 0


@dataclass
class SubgraphProjection:
    """Node indices and edge endpoint pairs referenced by the result rows."""

    # Indices are in VirtualGraph node-index space.
    node_indices: set[int] = field(default_factory=set)
    edge_tuples: set[tuple[int, int]] = field(default_factory=set)


@dataclass
class QueryResult:
    """Final projected output of a query, plus stats and truncation info."""

    columns: list[str] = field(default_factory=list)
    rows: list[dict[str, Any]] = field(default_factory=list)
    subgraph: SubgraphProjection | None = None
    stats: QueryStats = field(default_factory=QueryStats)
    truncated: bool = False
    truncation_reason: str | None = None
diff --git a/packages/cli/src/opentools/chain/cypher/session.py b/packages/cli/src/opentools/chain/cypher/session.py
new file mode 100644
index 0000000..6c25e3b
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/session.py
@@ -0,0 +1,21 @@
"""Query session: named result sets for the REPL."""
from __future__ import annotations

from opentools.chain.cypher.result import QueryResult


class QuerySession:
    """Mutable mapping of REPL variable names to stored QueryResults."""

    def __init__(self) -> None:
        self._variables: dict[str, QueryResult] = {}

    def store(self, name: str, result: QueryResult) -> None:
        # Overwrites any existing binding for *name*.
        self._variables[name] = result

    def get(self, name: str) -> QueryResult | None:
        return self._variables.get(name)

    def list_variables(self) -> list[str]:
        # Insertion order — i.e. order of first assignment.
        return list(self._variables.keys())

    def clear(self) -> None:
        self._variables.clear()
diff --git a/packages/cli/src/opentools/chain/cypher/virtual_graph.py b/packages/cli/src/opentools/chain/cypher/virtual_graph.py
new file mode 100644
index 0000000..ce574cb
--- /dev/null
+++ b/packages/cli/src/opentools/chain/cypher/virtual_graph.py
@@ -0,0 +1,289 @@
"""Virtual
heterogeneous graph builder and cache for the Cypher DSL executor.

The VirtualGraph overlays entity nodes on top of the MasterGraph (finding
nodes + LINKED edges) so that the executor can traverse across both node
types in a single graph walk.

Node labels:
    Finding — every FindingNode from the MasterGraph
    Host / IP / CVE / Domain / Port / MitreAttack / Entity
        — EntityNode instances, label derived from entity.type

Edge types:
    LINKED — copied from MasterGraph (EdgeData payload)
    MENTIONED_IN — Entity → Finding (MentionedInEdge payload)

The VirtualGraphCache is an async LRU with per-key build lock, mirroring
the design of GraphCache in opentools.chain.query.graph_cache.
"""
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID

import rustworkx as rx

from opentools.chain.models import Entity, EntityMention
from opentools.chain.query.graph_cache import MasterGraph

if TYPE_CHECKING:
    # Imported for annotations only (avoids import cycles at runtime).
    from opentools.chain.query.graph_cache import GraphCache
    from opentools.chain.store_protocol import ChainStoreProtocol


# ─── node / edge payload types ────────────────────────────────────────────────


@dataclass
class EntityNode:
    """Payload attached to entity nodes in the VirtualGraph."""

    entity_id: str
    entity_type: str
    canonical_value: str
    mention_count: int


@dataclass
class MentionedInEdge:
    """Payload for Entity → Finding edges."""

    mention_id: str
    field: str  # MentionField value
    confidence: float
    extractor: str


# ─── label mapping ────────────────────────────────────────────────────────────

# entity.type (lowercased) → VirtualGraph node label.
_TYPE_TO_LABEL: dict[str, str] = {
    "host": "Host",
    "ip": "IP",
    "cve": "CVE",
    "domain": "Domain",
    "port": "Port",
    "mitre_technique": "MitreAttack",
}

_FINDING_LABEL = "Finding"


def _entity_label(entity_type: str) -> str:
    # Unknown entity types fall back to the generic "Entity" label.
    return _TYPE_TO_LABEL.get(entity_type.lower(), "Entity")


# ─── VirtualGraph
─────────────────────────────────────────────


@dataclass
class VirtualGraph:
    """Heterogeneous graph combining findings and entities.

    Attributes:
        graph — rustworkx directed graph
        finding_map — finding_id → node index in *this* graph
        entity_map — entity_id → node index in *this* graph
        reverse_map — node index → id (finding_id or entity_id)
        node_labels — node index → label string ("Finding", "Host", …)
        generation — linker generation from the source MasterGraph
    """

    graph: rx.PyDiGraph
    finding_map: dict[str, int]
    entity_map: dict[str, int]
    reverse_map: dict[int, str]
    node_labels: dict[int, str]
    generation: int


# ─── VirtualGraphBuilder ──────────────────────────────────────────────────────


class VirtualGraphBuilder:
    """Builds a VirtualGraph from a MasterGraph plus entity/mention lists."""

    def build(
        self,
        master: MasterGraph,
        entities: list[Entity],
        mentions: list[EntityMention],
    ) -> VirtualGraph:
        """Assemble the combined graph; node payloads from master are shared,
        not copied, so they must be treated as read-only."""
        vg = rx.PyDiGraph()
        finding_map: dict[str, int] = {}
        entity_map: dict[str, int] = {}
        reverse_map: dict[int, str] = {}
        node_labels: dict[int, str] = {}

        # ── 1. Copy finding nodes from master ─────────────────────────────
        # master.node_map: finding_id -> master node index
        # We create new node indices in the virtual graph.
        master_to_virtual: dict[int, int] = {}
        for finding_id, master_idx in master.node_map.items():
            node_data = master.graph.get_node_data(master_idx)
            v_idx = vg.add_node(node_data)
            master_to_virtual[master_idx] = v_idx
            finding_map[finding_id] = v_idx
            reverse_map[v_idx] = finding_id
            node_labels[v_idx] = _FINDING_LABEL

        # ── 2. Copy LINKED edges from master ──────────────────────────────
        # edge_list() returns a list of (src_idx, tgt_idx) in master space.
        # edges() returns the corresponding payloads in the same order.
        master_endpoints = list(master.graph.edge_list())
        master_payloads = list(master.graph.edges())
        for (src_m, tgt_m), payload in zip(master_endpoints, master_payloads):
            src_v = master_to_virtual.get(src_m)
            tgt_v = master_to_virtual.get(tgt_m)
            # Skip edges whose endpoints are not in node_map.
            if src_v is not None and tgt_v is not None:
                vg.add_edge(src_v, tgt_v, payload)

        # ── 3. Add entity nodes ───────────────────────────────────────────
        for entity in entities:
            node_data = EntityNode(
                entity_id=entity.id,
                entity_type=entity.type,
                canonical_value=entity.canonical_value,
                mention_count=entity.mention_count,
            )
            v_idx = vg.add_node(node_data)
            entity_map[entity.id] = v_idx
            reverse_map[v_idx] = entity.id
            node_labels[v_idx] = _entity_label(entity.type)

        # ── 4. Add MENTIONED_IN edges: Entity → Finding ───────────────────
        for mention in mentions:
            ent_v = entity_map.get(mention.entity_id)
            fnd_v = finding_map.get(mention.finding_id)
            # Dangling mentions (unknown entity or finding) are dropped.
            if ent_v is None or fnd_v is None:
                continue
            edge_data = MentionedInEdge(
                mention_id=mention.id,
                # NOTE(review): if MentionField is an Enum, str() renders
                # "MentionField.X" rather than the raw value — confirm the
                # intended serialisation.
                field=str(mention.field),
                confidence=mention.confidence,
                extractor=mention.extractor,
            )
            vg.add_edge(ent_v, fnd_v, edge_data)

        return VirtualGraph(
            graph=vg,
            finding_map=finding_map,
            entity_map=entity_map,
            reverse_map=reverse_map,
            node_labels=node_labels,
            generation=master.generation,
        )


# ─── VirtualGraphCache ────────────────────────────────────────────────────────


class VirtualGraphCache:
    """Async LRU cache of VirtualGraphs with per-key build lock.

    Keyed by ``(user_id_str, generation, include_candidates, engagement_ids)``.
    Capacity bounded by ``maxsize``. Concurrent callers for the same key
    collapse to a single build — the first waiter builds; subsequent waiters
    re-check and return the cached instance.
+ + Args: + store: ChainStoreProtocol instance + graph_cache: GraphCache instance (provides get_master_graph) + maxsize: maximum number of cached VirtualGraphs (default 4) + """ + + def __init__( + self, + *, + store: "ChainStoreProtocol", + graph_cache: "GraphCache", + maxsize: int = 4, + ) -> None: + self.store = store + self.graph_cache = graph_cache + self.maxsize = maxsize + self._cache: dict[tuple, VirtualGraph] = {} + self._access_order: list[tuple] = [] + self._build_locks: dict[tuple, asyncio.Lock] = {} + self._builder = VirtualGraphBuilder() + + async def get( + self, + *, + user_id: UUID | None, + include_candidates: bool = False, + engagement_ids: tuple[str, ...] | list[str] | None = None, + ) -> VirtualGraph: + """Return a VirtualGraph for the given scope, building if necessary.""" + generation = await self.store.current_linker_generation(user_id=user_id) + # Normalise engagement_ids to a hashable form + eng_key = tuple(sorted(engagement_ids)) if engagement_ids else None + key = ( + str(user_id) if user_id else None, + generation, + include_candidates, + eng_key, + ) + + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + + lock = self._build_locks.setdefault(key, asyncio.Lock()) + async with lock: + # Another waiter may have populated the cache while we waited. 
+ if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + + vg = await self._build( + user_id=user_id, + include_candidates=include_candidates, + engagement_ids=engagement_ids, + ) + self._cache[key] = vg + self._access_order.append(key) + + while len(self._access_order) > self.maxsize: + oldest = self._access_order.pop(0) + self._cache.pop(oldest, None) + self._build_locks.pop(oldest, None) + + return vg + + def invalidate(self, *, user_id: UUID | None) -> None: + """Drop all cached graphs for a specific user.""" + user_key = str(user_id) if user_id else None + to_remove = [k for k in self._access_order if k[0] == user_key] + for k in to_remove: + self._access_order.remove(k) + self._cache.pop(k, None) + self._build_locks.pop(k, None) + + def clear(self) -> None: + self._cache.clear() + self._access_order.clear() + self._build_locks.clear() + + # ── internals ───────────────────────────────────────────────────────────── + + async def _build( + self, + *, + user_id: UUID | None, + include_candidates: bool, + engagement_ids: tuple[str, ...] | list[str] | None, + ) -> VirtualGraph: + master = await self.graph_cache.get_master_graph( + user_id=user_id, + include_candidates=include_candidates, + ) + entities = await self.store.list_entities(user_id=user_id, limit=10_000) + mentions = await self.store.fetch_all_mentions_in_scope( + user_id=user_id, + engagement_ids=list(engagement_ids) if engagement_ids else None, + ) + return self._builder.build(master, entities, mentions) diff --git a/packages/cli/src/opentools/chain/store_protocol.py b/packages/cli/src/opentools/chain/store_protocol.py index 77b32e5..581d95f 100644 --- a/packages/cli/src/opentools/chain/store_protocol.py +++ b/packages/cli/src/opentools/chain/store_protocol.py @@ -144,6 +144,20 @@ async def fetch_entity_mentions_for_engagement( """ ... 
+    async def fetch_all_mentions_in_scope(
+        self,
+        *,
+        user_id: UUID | None,
+        engagement_ids: list[str] | None = None,
+    ) -> list[EntityMention]:
+        """Return all entity mentions for the user scope.
+
+        Used by VirtualGraphBuilder to populate MENTIONED_IN edges.
+        When ``engagement_ids`` is provided, restricts to mentions whose
+        finding belongs to one of those engagements.
+        """
+        ...
+
     # --- Relation CRUD ---
 
     async def upsert_relations_bulk(
diff --git a/packages/cli/src/opentools/chain/stores/postgres_async.py b/packages/cli/src/opentools/chain/stores/postgres_async.py
index 206cac0..1890085 100644
--- a/packages/cli/src/opentools/chain/stores/postgres_async.py
+++ b/packages/cli/src/opentools/chain/stores/postgres_async.py
@@ -623,6 +623,19 @@ async def mentions_for_finding(
         result = await self._session.execute(stmt)
         return [_orm_to_mention(r) for r in result.scalars()]
 
+    @require_initialized
+    @require_user_scope
+    async def fetch_all_mentions_in_scope(
+        self, *, user_id: UUID, engagement_ids: list[str] | None = None
+    ) -> list[EntityMention]:
+        M = self._models
+        assert self._session is not None
+        # NOTE(review): ``engagement_ids`` is accepted but never applied, even
+        # though the protocol docstring promises engagement filtering — callers
+        # (e.g. VirtualGraphCache) currently receive the full user scope.
+        # TODO: join mentions to findings and filter by engagement — confirm.
+        stmt = select(M.ChainEntityMention).where(
+            M.ChainEntityMention.user_id == user_id,
+        )
+        result = await self._session.execute(stmt)
+        return [_orm_to_mention(r) for r in result.scalars()]
+
     @require_initialized
     @require_user_scope
     async def delete_mentions_for_finding(
diff --git a/packages/cli/src/opentools/chain/stores/sqlite_async.py b/packages/cli/src/opentools/chain/stores/sqlite_async.py
index 69cbb7e..7cb0534 100644
--- a/packages/cli/src/opentools/chain/stores/sqlite_async.py
+++ b/packages/cli/src/opentools/chain/stores/sqlite_async.py
@@ -453,6 +453,16 @@ async def mentions_for_finding(
         rows = await cur.fetchall()
         return [_row_to_mention(row) for row in rows]
 
+    @require_initialized
+    async def fetch_all_mentions_in_scope(
+        self, *, user_id, engagement_ids: list[str] | None = None
+    ) -> list[EntityMention]:
+        # NOTE(review): neither ``user_id`` nor ``engagement_ids`` is applied —
+        # this returns every row in entity_mention. The user_id omission looks
+        # consistent with this store's other queries, but the engagement
+        # filtering promised by the protocol is missing. TODO: confirm/implement.
+        async with self._conn.execute(
+            "SELECT * FROM entity_mention",
+        ) as cur:
+            rows = await cur.fetchall()
+            return [_row_to_mention(row) for row in rows]
+
     @require_initialized
     async def delete_mentions_for_finding(
         self, finding_id: str, *, user_id
diff --git a/packages/cli/tests/chain/cypher/__init__.py b/packages/cli/tests/chain/cypher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/packages/cli/tests/chain/cypher/test_ast_nodes.py b/packages/cli/tests/chain/cypher/test_ast_nodes.py
new file mode 100644
index 0000000..d4f83a1
--- /dev/null
+++ b/packages/cli/tests/chain/cypher/test_ast_nodes.py
@@ -0,0 +1,116 @@
+from opentools.chain.cypher.ast_nodes import (
+    BooleanExpr,
+    ComparisonExpr,
+    EdgePattern,
+    FunctionCallExpr,
+    MatchClause,
+    NodePattern,
+    PropertyAccessExpr,
+    ReturnClause,
+    ReturnItem,
+    SessionAssignment,
+    VarLengthSpec,
+    WhereClause,
+)
+
+
+def test_node_pattern():
+    n = NodePattern(variable="a", label="Finding")
+    assert n.variable == "a"
+    assert n.label == "Finding"
+
+
+def test_node_pattern_no_label():
+    n = NodePattern(variable="x", label=None)
+    assert n.label is None
+
+
+def test_edge_pattern_outgoing():
+    e = EdgePattern(variable="r", label="LINKED", direction="out", var_length=None)
+    assert e.direction == "out"
+    assert e.var_length is None
+
+
+def test_edge_pattern_with_var_length():
+    vl = VarLengthSpec(min_hops=1, max_hops=5)
+    e = EdgePattern(variable="r", label="LINKED", direction="out", var_length=vl)
+    assert e.var_length.min_hops == 1
+    assert e.var_length.max_hops == 5
+
+
+def test_var_length_spec_defaults():
+    vl = VarLengthSpec(min_hops=1, max_hops=3)
+    assert vl.min_hops == 1
+    assert vl.max_hops == 3
+
+
+def test_property_access_expr():
+    p = PropertyAccessExpr(variable="a", property_name="severity")
+    assert p.variable == "a"
+    assert p.property_name == "severity"
+
+
+def test_comparison_expr():
+    left = PropertyAccessExpr(variable="a", property_name="severity")
+    c = ComparisonExpr(left=left, operator="=", 
right="critical")
+    assert c.operator == "="
+    assert c.right == "critical"
+
+
+def test_boolean_expr():
+    left = ComparisonExpr(
+        left=PropertyAccessExpr(variable="a", property_name="severity"),
+        operator="=", right="critical",
+    )
+    right = ComparisonExpr(
+        left=PropertyAccessExpr(variable="a", property_name="tool"),
+        operator="=", right="nmap",
+    )
+    b = BooleanExpr(operator="AND", operands=[left, right])
+    assert b.operator == "AND"
+    assert len(b.operands) == 2
+
+
+def test_function_call_expr():
+    f = FunctionCallExpr(name="has_entity", args=["a", "host", "10.0.0.1"])
+    assert f.name == "has_entity"
+    assert len(f.args) == 3
+
+
+def test_function_call_plugin_namespaced():
+    f = FunctionCallExpr(name="my_plugin.risk_score", args=["a"])
+    assert "." in f.name
+
+
+def test_match_clause():
+    node_a = NodePattern(variable="a", label="Finding")
+    edge_r = EdgePattern(variable="r", label="LINKED", direction="out", var_length=None)
+    node_b = NodePattern(variable="b", label="Finding")
+    mc = MatchClause(patterns=[(node_a, edge_r, node_b)])
+    assert len(mc.patterns) == 1
+
+
+def test_where_clause():
+    pred = ComparisonExpr(
+        left=PropertyAccessExpr(variable="a", property_name="severity"),
+        operator="=", right="critical",
+    )
+    wc = WhereClause(expression=pred)
+    assert wc.expression is pred
+
+
+def test_return_clause():
+    items = [
+        ReturnItem(expression="a", alias=None),
+        ReturnItem(expression=PropertyAccessExpr(variable="a", property_name="title"), alias="name"),
+    ]
+    rc = ReturnClause(items=items)
+    assert len(rc.items) == 2
+    assert rc.items[1].alias == "name"
+
+
+def test_session_assignment():
+    rc = ReturnClause(items=[ReturnItem(expression="a", alias=None)])
+    mc = MatchClause(patterns=[])
+    sa = SessionAssignment(variable_name="results", match_clause=mc, where_clause=None, return_clause=rc)
+    assert sa.variable_name == "results"
diff --git a/packages/cli/tests/chain/cypher/test_builtins.py b/packages/cli/tests/chain/cypher/test_builtins.py
new 
file mode 100644
index 0000000..8ac00a6
--- /dev/null
+++ b/packages/cli/tests/chain/cypher/test_builtins.py
@@ -0,0 +1,53 @@
+# NOTE(review): the pytest import below appears unused in this module.
+import pytest
+from opentools.chain.cypher.builtins import (
+    builtin_collect, builtin_has_entity, builtin_has_mitre,
+    builtin_length, builtin_nodes, builtin_relationships,
+    get_builtin, list_builtins,
+)
+
+def test_builtin_length():
+    # length(path) counts edges, not nodes.
+    path = {"nodes": [1, 2, 3], "edges": [10, 20]}
+    assert builtin_length(path) == 2
+
+def test_builtin_length_empty_path():
+    path = {"nodes": [1], "edges": []}
+    assert builtin_length(path) == 0
+
+def test_builtin_nodes():
+    path = {"nodes": ["a", "b", "c"], "edges": [1, 2]}
+    assert builtin_nodes(path) == ["a", "b", "c"]
+
+def test_builtin_relationships():
+    path = {"nodes": ["a", "b"], "edges": ["r1"]}
+    assert builtin_relationships(path) == ["r1"]
+
+def test_builtin_has_entity():
+    node = {"entities": [{"type": "host", "canonical_value": "10.0.0.1"}, {"type": "cve", "canonical_value": "CVE-2024-1234"}]}
+    assert builtin_has_entity(node, "host", "10.0.0.1") is True
+    assert builtin_has_entity(node, "host", "10.0.0.2") is False
+    assert builtin_has_entity(node, "cve", "CVE-2024-1234") is True
+
+def test_builtin_has_entity_no_entities():
+    node = {"entities": []}
+    assert builtin_has_entity(node, "host", "anything") is False
+
+def test_builtin_has_mitre():
+    node = {"entities": [{"type": "mitre_technique", "canonical_value": "T1059"}]}
+    assert builtin_has_mitre(node, "T1059") is True
+    assert builtin_has_mitre(node, "T1078") is False
+
+def test_builtin_collect():
+    values = [1, 2, 3, 4]
+    assert builtin_collect(values) == [1, 2, 3, 4]
+
+def test_get_builtin():
+    fn = get_builtin("length")
+    assert fn is builtin_length
+    assert get_builtin("nonexistent") is None
+
+def test_list_builtins():
+    builtins = list_builtins()
+    assert "length" in builtins
+    assert "has_entity" in builtins
+    assert "collect" in builtins
+    assert len(builtins) >= 6
diff --git 
a/packages/cli/tests/chain/cypher/test_cli_query.py b/packages/cli/tests/chain/cypher/test_cli_query.py new file mode 100644 index 0000000..ff09cd6 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_cli_query.py @@ -0,0 +1,24 @@ +"""Tests for the CLI query subcommands.""" +from __future__ import annotations +from typer.testing import CliRunner +from opentools.chain.cli import app + +runner = CliRunner() + +def test_query_run_help(): + result = runner.invoke(app, ["query", "run", "--help"]) + assert result.exit_code == 0 + assert "Execute a Cypher query" in result.output or "cypher" in result.output.lower() + +def test_query_explain_help(): + result = runner.invoke(app, ["query", "explain", "--help"]) + assert result.exit_code == 0 + +def test_query_repl_help(): + result = runner.invoke(app, ["query", "repl", "--help"]) + assert result.exit_code == 0 + +def test_query_preset_help(): + result = runner.invoke(app, ["query", "preset", "--help"]) + assert result.exit_code == 0 + assert "preset" in result.output.lower() diff --git a/packages/cli/tests/chain/cypher/test_errors.py b/packages/cli/tests/chain/cypher/test_errors.py new file mode 100644 index 0000000..be69d3d --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_errors.py @@ -0,0 +1,26 @@ +from opentools.chain.cypher.errors import ( + QueryParseError, + QueryResourceError, + QueryValidationError, +) + + +def test_query_parse_error_is_exception(): + err = QueryParseError("unexpected token", line=3, column=12) + assert isinstance(err, Exception) + assert err.line == 3 + assert err.column == 12 + assert "unexpected token" in str(err) + + +def test_query_validation_error_is_exception(): + err = QueryValidationError("unknown function: foo.bar") + assert isinstance(err, Exception) + assert "foo.bar" in str(err) + + +def test_query_resource_error_is_exception(): + err = QueryResourceError("binding cap exceeded", limit_name="intermediate_binding_cap", limit_value=10_000) + assert isinstance(err, Exception) + 
assert err.limit_name == "intermediate_binding_cap" + assert err.limit_value == 10_000 diff --git a/packages/cli/tests/chain/cypher/test_executor.py b/packages/cli/tests/chain/cypher/test_executor.py new file mode 100644 index 0000000..e757827 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_executor.py @@ -0,0 +1,90 @@ +from __future__ import annotations +from datetime import datetime, timezone +import pytest +import rustworkx as rx +from opentools.chain.cypher.executor import CypherExecutor +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.parser import parse_cypher +from opentools.chain.cypher.planner import plan_query +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.result import QueryResult +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import EntityNode, MentionedInEdge, VirtualGraph +from opentools.chain.query.graph_cache import EdgeData, FindingNode + +def _build_test_vg() -> VirtualGraph: + """3 findings, 1 host entity, 2 LINKED edges, 2 MENTIONED_IN edges.""" + g = rx.PyDiGraph() + now = datetime.now(timezone.utc) + n0 = g.add_node(FindingNode(finding_id="fnd_1", severity="high", tool="nmap", title="Open SSH", created_at=now)) + n1 = g.add_node(FindingNode(finding_id="fnd_2", severity="critical", tool="nuclei", title="RCE vuln", created_at=now)) + n2 = g.add_node(FindingNode(finding_id="fnd_3", severity="medium", tool="burp", title="XSS", created_at=now)) + n3 = g.add_node(EntityNode(entity_id="ent_host1", entity_type="host", canonical_value="10.0.0.1", mention_count=2)) + g.add_edge(n0, n1, EdgeData(relation_id="rel_1", weight=2.0, cost=0.5, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + g.add_edge(n1, n2, EdgeData(relation_id="rel_2", weight=1.5, cost=0.7, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + 
g.add_edge(n3, n0, MentionedInEdge(mention_id="m1", field="description", confidence=1.0, extractor="ioc_finder")) + g.add_edge(n3, n1, MentionedInEdge(mention_id="m2", field="description", confidence=1.0, extractor="ioc_finder")) + return VirtualGraph( + graph=g, finding_map={"fnd_1": n0, "fnd_2": n1, "fnd_3": n2}, + entity_map={"ent_host1": n3}, reverse_map={n0: "fnd_1", n1: "fnd_2", n2: "fnd_3", n3: "ent_host1"}, + node_labels={n0: "Finding", n1: "Finding", n2: "Finding", n3: "Host"}, generation=1, + ) + +def _execute_sync(query_str: str, vg: VirtualGraph | None = None, limits: QueryLimits | None = None) -> QueryResult: + import asyncio + if vg is None: + vg = _build_test_vg() + if limits is None: + limits = QueryLimits() + ast = parse_cypher(query_str) + plan = plan_query(ast, limits) + executor = CypherExecutor( + virtual_graph=vg, plan=plan, session=QuerySession(), + plugin_registry=PluginFunctionRegistry(), limits=limits, + ) + return asyncio.run(executor.execute()) + +def test_scan_all_findings(): + result = _execute_sync("MATCH (a:Finding) RETURN a") + assert len(result.rows) == 3 + assert "a" in result.columns + +def test_scan_entity_label(): + result = _execute_sync("MATCH (h:Host) RETURN h") + assert len(result.rows) == 1 + +def test_expand_linked(): + result = _execute_sync("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b") + assert len(result.rows) == 2 # fnd_1->fnd_2, fnd_2->fnd_3 + +def test_expand_mentioned_in(): + result = _execute_sync("MATCH (h:Host)-[r:MENTIONED_IN]->(f:Finding) RETURN h, f") + assert len(result.rows) == 2 # host->fnd_1, host->fnd_2 + +def test_where_filter(): + result = _execute_sync('MATCH (a:Finding) WHERE a.severity = "critical" RETURN a') + assert len(result.rows) == 1 + assert result.rows[0]["a"]["severity"] == "critical" + +def test_where_numeric_comparison(): + result = _execute_sync("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.weight > 1.8 RETURN a, b") + assert len(result.rows) == 1 # only rel_1 has 
weight=2.0 + +def test_return_property(): + result = _execute_sync("MATCH (a:Finding) RETURN a.title, a.severity") + assert len(result.rows) == 3 + +def test_subgraph_projection(): + result = _execute_sync("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b") + assert result.subgraph is not None + assert len(result.subgraph.node_indices) >= 2 + +def test_resource_limit_max_rows(): + result = _execute_sync("MATCH (a:Finding) RETURN a", limits=QueryLimits(max_rows=1)) + assert len(result.rows) == 1 + assert result.truncated is True + +def test_empty_result(): + result = _execute_sync('MATCH (a:Finding) WHERE a.severity = "nonexistent" RETURN a') + assert len(result.rows) == 0 + assert result.truncated is False diff --git a/packages/cli/tests/chain/cypher/test_integration.py b/packages/cli/tests/chain/cypher/test_integration.py new file mode 100644 index 0000000..12307f7 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_integration.py @@ -0,0 +1,109 @@ +"""End-to-end integration: parse → plan → build virtual graph → execute.""" +from __future__ import annotations +from datetime import datetime, timezone +from unittest.mock import AsyncMock +import pytest +import rustworkx as rx +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.parser import parse_cypher +from opentools.chain.cypher.planner import plan_query +from opentools.chain.cypher.executor import CypherExecutor +from opentools.chain.cypher.plugins import PluginFunctionRegistry +from opentools.chain.cypher.session import QuerySession +from opentools.chain.cypher.virtual_graph import VirtualGraphBuilder, VirtualGraphCache, EntityNode, MentionedInEdge +from opentools.chain.models import Entity, EntityMention +from opentools.chain.query.graph_cache import EdgeData, FindingNode, MasterGraph, GraphCache +from opentools.chain.types import MentionField + + +def _make_master_graph() -> MasterGraph: + g = rx.PyDiGraph() + now = datetime.now(timezone.utc) + n0 = 
g.add_node(FindingNode(finding_id="fnd_1", severity="high", tool="nmap", title="Open SSH", created_at=now)) + n1 = g.add_node(FindingNode(finding_id="fnd_2", severity="critical", tool="nuclei", title="RCE vuln", created_at=now)) + n2 = g.add_node(FindingNode(finding_id="fnd_3", severity="medium", tool="burp", title="XSS", created_at=now)) + g.add_edge(n0, n1, EdgeData(relation_id="rel_1", weight=2.0, cost=0.5, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + g.add_edge(n1, n2, EdgeData(relation_id="rel_2", weight=1.5, cost=0.7, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + return MasterGraph(graph=g, node_map={"fnd_1": n0, "fnd_2": n1, "fnd_3": n2}, reverse_map={n0: "fnd_1", n1: "fnd_2", n2: "fnd_3"}, generation=1, max_weight=2.0) + + +def _make_entities() -> list[Entity]: + now = datetime.now(timezone.utc) + return [ + Entity(id="ent_host1", type="host", canonical_value="10.0.0.1", first_seen_at=now, last_seen_at=now, mention_count=2), + Entity(id="ent_cve1", type="cve", canonical_value="CVE-2024-1234", first_seen_at=now, last_seen_at=now, mention_count=1), + ] + + +def _make_mentions() -> list[EntityMention]: + now = datetime.now(timezone.utc) + return [ + EntityMention(id="m1", entity_id="ent_host1", finding_id="fnd_1", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m2", entity_id="ent_host1", finding_id="fnd_2", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m3", entity_id="ent_cve1", finding_id="fnd_2", field=MentionField.TITLE, raw_value="CVE-2024-1234", extractor="security_regex", confidence=0.95, created_at=now), + ] + + +@pytest.mark.asyncio +async def test_full_pipeline_via_virtual_graph_cache(): + """Build virtual graph through cache, execute query, get results.""" + master = 
_make_master_graph() + entities = _make_entities() + mentions = _make_mentions() + + store = AsyncMock() + store.current_linker_generation = AsyncMock(return_value=1) + store.list_entities = AsyncMock(return_value=entities) + store.fetch_all_mentions_in_scope = AsyncMock(return_value=mentions) + + graph_cache = AsyncMock() + graph_cache.get_master_graph = AsyncMock(return_value=master) + + vg_cache = VirtualGraphCache(store=store, graph_cache=graph_cache, maxsize=4) + + # Build virtual graph through cache + vg = await vg_cache.get(user_id=None, include_candidates=False, engagement_ids=None) + assert vg.graph.num_nodes() == 5 # 3 findings + 2 entities + + # Execute query + limits = QueryLimits() + ast = parse_cypher("MATCH (a:Finding) RETURN a") + plan = plan_query(ast, limits) + executor = CypherExecutor( + virtual_graph=vg, plan=plan, session=QuerySession(), + plugin_registry=PluginFunctionRegistry(), limits=limits, + ) + result = await executor.execute() + + assert len(result.rows) == 3 + assert "a" in result.columns + + +@pytest.mark.asyncio +async def test_full_pipeline_entity_traversal(): + """Query that traverses through entity nodes.""" + master = _make_master_graph() + entities = _make_entities() + mentions = _make_mentions() + + store = AsyncMock() + store.current_linker_generation = AsyncMock(return_value=1) + store.list_entities = AsyncMock(return_value=entities) + store.fetch_all_mentions_in_scope = AsyncMock(return_value=mentions) + + graph_cache = AsyncMock() + graph_cache.get_master_graph = AsyncMock(return_value=master) + + vg_cache = VirtualGraphCache(store=store, graph_cache=graph_cache, maxsize=4) + vg = await vg_cache.get(user_id=None, include_candidates=False, engagement_ids=None) + + limits = QueryLimits() + ast = parse_cypher("MATCH (h:Host)-[r:MENTIONED_IN]->(f:Finding) RETURN h, f") + plan = plan_query(ast, limits) + executor = CypherExecutor( + virtual_graph=vg, plan=plan, session=QuerySession(), + plugin_registry=PluginFunctionRegistry(), 
limits=limits, + ) + result = await executor.execute() + + assert len(result.rows) == 2 # host mentions fnd_1 and fnd_2 diff --git a/packages/cli/tests/chain/cypher/test_limits.py b/packages/cli/tests/chain/cypher/test_limits.py new file mode 100644 index 0000000..f055541 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_limits.py @@ -0,0 +1,25 @@ +from opentools.chain.cypher.limits import QueryLimits + + +def test_query_limits_defaults(): + limits = QueryLimits() + assert limits.timeout_seconds == 30.0 + assert limits.max_rows == 1000 + assert limits.intermediate_binding_cap == 10_000 + assert limits.max_var_length_hops == 10 + + +def test_query_limits_custom(): + limits = QueryLimits(timeout_seconds=60.0, max_rows=500) + assert limits.timeout_seconds == 60.0 + assert limits.max_rows == 500 + assert limits.intermediate_binding_cap == 10_000 # unchanged default + + +def test_query_limits_frozen(): + limits = QueryLimits() + try: + limits.timeout_seconds = 99.0 + assert False, "should be frozen" + except Exception: + pass diff --git a/packages/cli/tests/chain/cypher/test_parser.py b/packages/cli/tests/chain/cypher/test_parser.py new file mode 100644 index 0000000..2c04240 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_parser.py @@ -0,0 +1,145 @@ +"""Tests for the Cypher-style query parser.""" +import pytest +from opentools.chain.cypher.ast_nodes import ( + ComparisonExpr, CypherQuery, EdgePattern, FunctionCallExpr, + NodePattern, PropertyAccessExpr, SessionAssignment, +) +from opentools.chain.cypher.errors import QueryParseError +from opentools.chain.cypher.parser import parse_cypher + + +def test_parse_simple_match_return(): + q = parse_cypher("MATCH (a:Finding) RETURN a") + assert isinstance(q, CypherQuery) + assert len(q.match_clause.patterns) == 1 + assert len(q.return_clause.items) == 1 + + +def test_parse_two_node_pattern(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b") + assert len(q.match_clause.patterns) == 1 + 
pattern = q.match_clause.patterns[0] + assert isinstance(pattern[0], NodePattern) + assert pattern[0].label == "Finding" + assert isinstance(pattern[1], EdgePattern) + assert pattern[1].label == "LINKED" + assert pattern[1].direction == "out" + assert isinstance(pattern[2], NodePattern) + assert pattern[2].label == "Finding" + + +def test_parse_incoming_edge(): + q = parse_cypher("MATCH (a:Finding)<-[r:MENTIONED_IN]-(e:Host) RETURN a, e") + pattern = q.match_clause.patterns[0] + assert pattern[1].direction == "in" + assert pattern[1].label == "MENTIONED_IN" + + +def test_parse_entity_node_labels(): + for label in ["Host", "IP", "CVE", "Domain", "Port", "MitreAttack", "Entity"]: + q = parse_cypher(f"MATCH (e:{label}) RETURN e") + assert q.match_clause.patterns[0][0].label == label + + +def test_parse_var_length_path(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED*1..5]->(b:Finding) RETURN a, b") + edge = q.match_clause.patterns[0][1] + assert edge.var_length is not None + assert edge.var_length.min_hops == 1 + assert edge.var_length.max_hops == 5 + + +def test_parse_var_length_exceeds_max_hops(): + with pytest.raises(QueryParseError, match="max.*10"): + parse_cypher("MATCH (a:Finding)-[r:LINKED*1..15]->(b:Finding) RETURN a, b") + + +def test_parse_where_comparison(): + q = parse_cypher('MATCH (a:Finding) WHERE a.severity = "critical" RETURN a') + assert q.where_clause is not None + expr = q.where_clause.expression + assert isinstance(expr, ComparisonExpr) + assert isinstance(expr.left, PropertyAccessExpr) + assert expr.left.variable == "a" + assert expr.left.property_name == "severity" + assert expr.operator == "=" + assert expr.right == "critical" + + +def test_parse_where_numeric_comparison(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.weight > 2.0 RETURN a, b") + expr = q.where_clause.expression + assert expr.operator == ">" + assert expr.right == 2.0 + + +def test_parse_where_and(): + q = parse_cypher('MATCH (a:Finding) WHERE 
a.severity = "critical" AND a.tool = "nmap" RETURN a') + expr = q.where_clause.expression + assert expr.operator == "AND" if hasattr(expr, "operator") else True + + +def test_parse_where_function_call(): + q = parse_cypher('MATCH (a:Finding) WHERE has_entity(a, "host", "10.0.0.1") RETURN a') + assert q.where_clause is not None + + +def test_parse_where_contains(): + q = parse_cypher('MATCH (a:Finding) WHERE a.title CONTAINS "ssh" RETURN a') + expr = q.where_clause.expression + assert expr.operator == "CONTAINS" + + +def test_parse_where_is_null(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) WHERE r.llm_rationale IS NOT NULL RETURN a, b") + assert q.where_clause is not None + + +def test_parse_return_property(): + q = parse_cypher("MATCH (a:Finding) RETURN a.title, a.severity") + assert len(q.return_clause.items) == 2 + assert isinstance(q.return_clause.items[0].expression, PropertyAccessExpr) + + +def test_parse_return_collect(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN collect(a)") + item = q.return_clause.items[0] + assert isinstance(item.expression, FunctionCallExpr) + assert item.expression.name == "collect" + + +def test_parse_session_assignment(): + q = parse_cypher("results = MATCH (a:Finding) RETURN a") + assert q.session_assignment == "results" + + +def test_parse_from_clause(): + q = parse_cypher("MATCH (a) FROM prev_results -[r:LINKED]->(b:Finding) RETURN a, b") + assert q.from_clause is not None + assert q.from_clause.session_variable == "prev_results" + + +@pytest.mark.parametrize("verb", ["CREATE", "DELETE", "SET", "MERGE", "REMOVE", "DETACH", "DROP"]) +def test_parse_rejects_mutation_verbs(verb): + with pytest.raises(QueryParseError): + parse_cypher(f"{verb} (a:Finding)") + + +def test_parse_empty_string(): + with pytest.raises(QueryParseError): + parse_cypher("") + + +def test_parse_garbage(): + with pytest.raises(QueryParseError): + parse_cypher("not a query at all 123 !!!") + + +def 
test_parse_case_insensitive_keywords(): + q = parse_cypher('match (a:Finding) where a.severity = "critical" return a') + assert q is not None + + +def test_parse_multiple_patterns(): + q = parse_cypher("MATCH (a:Finding)-[r:LINKED]->(b:Finding), (b)-[:MENTIONED_IN]->(e:Host) RETURN a, e") + assert len(q.match_clause.patterns) == 2 diff --git a/packages/cli/tests/chain/cypher/test_planner.py b/packages/cli/tests/chain/cypher/test_planner.py new file mode 100644 index 0000000..9652361 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_planner.py @@ -0,0 +1,96 @@ +# packages/cli/tests/chain/cypher/test_planner.py +import pytest +from opentools.chain.cypher.ast_nodes import ( + ComparisonExpr, CypherQuery, EdgePattern, MatchClause, + NodePattern, PropertyAccessExpr, ReturnClause, ReturnItem, + VarLengthSpec, WhereClause, +) +from opentools.chain.cypher.limits import QueryLimits +from opentools.chain.cypher.planner import plan_query, PlanStep, QueryPlan + +def _simple_query() -> CypherQuery: + """MATCH (a:Finding) RETURN a""" + return CypherQuery( + match_clause=MatchClause(patterns=[(NodePattern(variable="a", label="Finding"),)]), + where_clause=None, + return_clause=ReturnClause(items=[ReturnItem(expression="a", alias=None)]), + ) + +def _two_node_query() -> CypherQuery: + """MATCH (a:Finding)-[r:LINKED]->(b:Finding) RETURN a, b""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + (NodePattern(variable="a", label="Finding"), + EdgePattern(variable="r", label="LINKED", direction="out", var_length=None), + NodePattern(variable="b", label="Finding")), + ]), + where_clause=None, + return_clause=ReturnClause(items=[ + ReturnItem(expression="a", alias=None), + ReturnItem(expression="b", alias=None), + ]), + ) + +def _filtered_query() -> CypherQuery: + """MATCH (a:Finding) WHERE a.severity = "critical" RETURN a""" + return CypherQuery( + match_clause=MatchClause(patterns=[(NodePattern(variable="a", label="Finding"),)]), + 
where_clause=WhereClause(expression=ComparisonExpr( + left=PropertyAccessExpr(variable="a", property_name="severity"), + operator="=", right="critical", + )), + return_clause=ReturnClause(items=[ReturnItem(expression="a", alias=None)]), + ) + +def _var_length_query() -> CypherQuery: + """MATCH (a:Finding)-[r:LINKED*1..5]->(b:Finding) RETURN a, b""" + return CypherQuery( + match_clause=MatchClause(patterns=[ + (NodePattern(variable="a", label="Finding"), + EdgePattern(variable="r", label="LINKED", direction="out", var_length=VarLengthSpec(min_hops=1, max_hops=5)), + NodePattern(variable="b", label="Finding")), + ]), + where_clause=None, + return_clause=ReturnClause(items=[ + ReturnItem(expression="a", alias=None), + ReturnItem(expression="b", alias=None), + ]), + ) + +def test_plan_simple_scan(): + plan = plan_query(_simple_query(), QueryLimits()) + assert len(plan.steps) == 1 + assert plan.steps[0].kind == "scan" + assert plan.steps[0].label == "Finding" + assert plan.steps[0].target_var == "a" + +def test_plan_two_node_has_scan_then_expand(): + plan = plan_query(_two_node_query(), QueryLimits()) + assert plan.steps[0].kind == "scan" + assert plan.steps[0].target_var == "a" + assert plan.steps[1].kind == "expand" + assert plan.steps[1].target_var == "r" + assert plan.steps[1].label == "LINKED" + +def test_plan_predicate_pushdown(): + plan = plan_query(_filtered_query(), QueryLimits()) + scan_step = plan.steps[0] + assert scan_step.kind == "scan" + assert scan_step.target_var == "a" + assert len(scan_step.predicates) == 1 + assert isinstance(scan_step.predicates[0], ComparisonExpr) + +def test_plan_var_length_expand(): + plan = plan_query(_var_length_query(), QueryLimits()) + var_length_steps = [s for s in plan.steps if s.kind == "var_length_expand"] + assert len(var_length_steps) == 1 + vl = var_length_steps[0] + assert vl.min_hops == 1 + assert vl.max_hops == 5 + assert vl.label == "LINKED" + +def test_plan_preserves_limits(): + limits = 
QueryLimits(timeout_seconds=60.0, max_rows=500) + plan = plan_query(_simple_query(), limits) + assert plan.limits.timeout_seconds == 60.0 + assert plan.limits.max_rows == 500 diff --git a/packages/cli/tests/chain/cypher/test_plugins.py b/packages/cli/tests/chain/cypher/test_plugins.py new file mode 100644 index 0000000..b55784f --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_plugins.py @@ -0,0 +1,36 @@ +import pytest +from opentools.chain.cypher.plugins import PluginFunctionRegistry + +@pytest.fixture +def registry(): + return PluginFunctionRegistry() + +def test_register_scalar_function(registry): + registry.register_function("my_plugin.risk_score", fn=lambda node: 0.9, help="Risk score", arg_types=["node"], return_type="float") + assert registry.get_function("my_plugin.risk_score") is not None + +def test_register_aggregation(registry): + registry.register_aggregation("my_plugin.combined_risk", fn=lambda values: max(values), help="Max risk", input_type="float", return_type="float") + assert registry.get_aggregation("my_plugin.combined_risk") is not None + +def test_reject_undotted_plugin_name(registry): + with pytest.raises(ValueError, match="dotted"): + registry.register_function("no_dot", fn=lambda x: x, help="bad", arg_types=["node"], return_type="float") + +def test_reject_duplicate_name(registry): + registry.register_function("my_plugin.f", fn=lambda x: x, help="first", arg_types=["node"], return_type="float") + with pytest.raises(ValueError, match="already registered"): + registry.register_function("my_plugin.f", fn=lambda x: x, help="second", arg_types=["node"], return_type="float") + +def test_list_all_functions(registry): + registry.register_function("a.one", fn=lambda x: x, help="h1", arg_types=["node"], return_type="float") + registry.register_aggregation("a.two", fn=lambda v: sum(v), help="h2", input_type="float", return_type="float") + all_fns = registry.list_all() + assert "a.one" in all_fns + assert "a.two" in all_fns + assert 
all_fns["a.one"]["kind"] == "scalar" + assert all_fns["a.two"]["kind"] == "aggregation" + +def test_resolve_returns_none_for_unknown(registry): + assert registry.get_function("nonexistent.fn") is None + assert registry.get_aggregation("nonexistent.fn") is None diff --git a/packages/cli/tests/chain/cypher/test_session.py b/packages/cli/tests/chain/cypher/test_session.py new file mode 100644 index 0000000..426fe16 --- /dev/null +++ b/packages/cli/tests/chain/cypher/test_session.py @@ -0,0 +1,41 @@ +from opentools.chain.cypher.result import QueryResult, QueryStats +from opentools.chain.cypher.session import QuerySession + + +def test_session_store_and_get(): + session = QuerySession() + result = QueryResult(columns=["a"], rows=[{"a": 1}, {"a": 2}], stats=QueryStats()) + session.store("my_results", result) + retrieved = session.get("my_results") + assert retrieved is result + + +def test_session_get_unknown(): + session = QuerySession() + assert session.get("nonexistent") is None + + +def test_session_list_variables(): + session = QuerySession() + r1 = QueryResult(columns=["a"], rows=[], stats=QueryStats()) + r2 = QueryResult(columns=["b"], rows=[], stats=QueryStats()) + session.store("first", r1) + session.store("second", r2) + assert set(session.list_variables()) == {"first", "second"} + + +def test_session_clear(): + session = QuerySession() + session.store("x", QueryResult(columns=[], rows=[], stats=QueryStats())) + session.clear() + assert session.get("x") is None + assert session.list_variables() == [] + + +def test_session_overwrite(): + session = QuerySession() + r1 = QueryResult(columns=["a"], rows=[{"a": 1}], stats=QueryStats()) + r2 = QueryResult(columns=["a"], rows=[{"a": 2}], stats=QueryStats()) + session.store("x", r1) + session.store("x", r2) + assert session.get("x") is r2 diff --git a/packages/cli/tests/chain/cypher/test_virtual_graph.py b/packages/cli/tests/chain/cypher/test_virtual_graph.py new file mode 100644 index 0000000..7c9376b --- /dev/null 
+++ b/packages/cli/tests/chain/cypher/test_virtual_graph.py @@ -0,0 +1,106 @@ +from __future__ import annotations +import asyncio +from datetime import datetime, timezone +from unittest.mock import AsyncMock +import pytest +import rustworkx as rx +from opentools.chain.cypher.virtual_graph import EntityNode, VirtualGraph, VirtualGraphBuilder, VirtualGraphCache +from opentools.chain.models import Entity, EntityMention +from opentools.chain.query.graph_cache import EdgeData, FindingNode, MasterGraph +from opentools.chain.types import MentionField + +def _make_master_graph() -> MasterGraph: + g = rx.PyDiGraph() + now = datetime.now(timezone.utc) + n0 = g.add_node(FindingNode(finding_id="fnd_1", severity="high", tool="nmap", title="Open SSH", created_at=now)) + n1 = g.add_node(FindingNode(finding_id="fnd_2", severity="critical", tool="nuclei", title="RCE vuln", created_at=now)) + n2 = g.add_node(FindingNode(finding_id="fnd_3", severity="medium", tool="burp", title="XSS", created_at=now)) + g.add_edge(n0, n1, EdgeData(relation_id="rel_1", weight=2.0, cost=0.5, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + g.add_edge(n1, n2, EdgeData(relation_id="rel_2", weight=1.5, cost=0.7, status="auto_confirmed", symmetric=False, reasons=[], llm_rationale=None, llm_relation_type=None)) + return MasterGraph(graph=g, node_map={"fnd_1": n0, "fnd_2": n1, "fnd_3": n2}, reverse_map={n0: "fnd_1", n1: "fnd_2", n2: "fnd_3"}, generation=1, max_weight=2.0) + +def _make_entities() -> list[Entity]: + now = datetime.now(timezone.utc) + return [ + Entity(id="ent_host1", type="host", canonical_value="10.0.0.1", first_seen_at=now, last_seen_at=now, mention_count=2), + Entity(id="ent_cve1", type="cve", canonical_value="CVE-2024-1234", first_seen_at=now, last_seen_at=now, mention_count=1), + ] + +def _make_mentions() -> list[EntityMention]: + now = datetime.now(timezone.utc) + return [ + EntityMention(id="m1", entity_id="ent_host1", 
finding_id="fnd_1", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m2", entity_id="ent_host1", finding_id="fnd_2", field=MentionField.DESCRIPTION, raw_value="10.0.0.1", extractor="ioc_finder", confidence=1.0, created_at=now), + EntityMention(id="m3", entity_id="ent_cve1", finding_id="fnd_2", field=MentionField.TITLE, raw_value="CVE-2024-1234", extractor="security_regex", confidence=0.95, created_at=now), + ] + +def test_build_virtual_graph_node_counts(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + assert vg.graph.num_nodes() == 5 + assert len(vg.finding_map) == 3 + assert len(vg.entity_map) == 2 + +def test_build_virtual_graph_edge_counts(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + assert vg.graph.num_edges() == 5 # 2 LINKED + 3 MENTIONED_IN + +def test_build_virtual_graph_node_labels(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + finding_labels = [vg.node_labels[idx] for idx in vg.finding_map.values()] + assert all(l == "Finding" for l in finding_labels) + host_idx = vg.entity_map["ent_host1"] + assert vg.node_labels[host_idx] == "Host" + cve_idx = vg.entity_map["ent_cve1"] + assert vg.node_labels[cve_idx] == "CVE" + +def test_mentioned_in_direction(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + host_idx = vg.entity_map["ent_host1"] + successors = list(vg.graph.successor_indices(host_idx)) + assert len(successors) == 2 + successor_ids = {vg.reverse_map[s] for s in successors} + assert successor_ids == {"fnd_1", "fnd_2"} + +def test_linked_edges_preserved(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = 
builder.build(master, _make_entities(), _make_mentions()) + fnd1_idx = vg.finding_map["fnd_1"] + fnd2_idx = vg.finding_map["fnd_2"] + edge_data = vg.graph.get_edge_data(fnd1_idx, fnd2_idx) + assert edge_data is not None + +def test_entity_node_properties(): + master = _make_master_graph() + builder = VirtualGraphBuilder() + vg = builder.build(master, _make_entities(), _make_mentions()) + host_idx = vg.entity_map["ent_host1"] + node_data = vg.graph.get_node_data(host_idx) + assert isinstance(node_data, EntityNode) + assert node_data.entity_id == "ent_host1" + assert node_data.canonical_value == "10.0.0.1" + assert node_data.entity_type == "host" + +@pytest.mark.asyncio +async def test_virtual_graph_cache_reuse(): + master = _make_master_graph() + entities = _make_entities() + mentions = _make_mentions() + store = AsyncMock() + store.current_linker_generation = AsyncMock(return_value=1) + store.list_entities = AsyncMock(return_value=entities) + store.fetch_all_mentions_in_scope = AsyncMock(return_value=mentions) + graph_cache = AsyncMock() + graph_cache.get_master_graph = AsyncMock(return_value=master) + cache = VirtualGraphCache(store=store, graph_cache=graph_cache, maxsize=4) + vg1 = await cache.get(user_id=None, include_candidates=False, engagement_ids=None) + vg2 = await cache.get(user_id=None, include_candidates=False, engagement_ids=None) + assert vg1 is vg2 diff --git a/packages/cli/tests/chain/test_cli_commands.py b/packages/cli/tests/chain/test_cli_commands.py index 427f4cd..8bee513 100644 --- a/packages/cli/tests/chain/test_cli_commands.py +++ b/packages/cli/tests/chain/test_cli_commands.py @@ -95,12 +95,12 @@ def test_cli_path_runs(cli_runner, populated_db): def test_cli_query_mitre_coverage_runs(cli_runner, populated_db): - result = cli_runner.invoke(app, ["query", "mitre-coverage", "--engagement", "eng_cli"]) + result = cli_runner.invoke(app, ["query", "preset", "mitre-coverage", "--engagement", "eng_cli"]) assert result.exit_code == 0 def 
test_cli_query_unknown_preset_fails(cli_runner, populated_db): - result = cli_runner.invoke(app, ["query", "not-a-real-preset", "--engagement", "eng_cli"]) + result = cli_runner.invoke(app, ["query", "preset", "not-a-real-preset", "--engagement", "eng_cli"]) assert result.exit_code != 0 diff --git a/packages/web/backend/app/main.py b/packages/web/backend/app/main.py index 2aa4eeb..9c76426 100644 --- a/packages/web/backend/app/main.py +++ b/packages/web/backend/app/main.py @@ -21,6 +21,7 @@ correlation, chain, scans, + chain_query, ) @@ -71,3 +72,4 @@ async def lifespan(app: FastAPI): app.include_router(correlation.router) app.include_router(chain.router) app.include_router(scans.router) +app.include_router(chain_query.router) diff --git a/packages/web/backend/app/routes/chain_query.py b/packages/web/backend/app/routes/chain_query.py new file mode 100644 index 0000000..6340073 --- /dev/null +++ b/packages/web/backend/app/routes/chain_query.py @@ -0,0 +1,105 @@ +"""Cypher query DSL web API endpoints.""" +from __future__ import annotations + +from typing import Any, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from app.dependencies import get_current_user, get_db +from app.models import User + +router = APIRouter(prefix="/api/chain/query", tags=["chain-query"]) + + +class QueryRequest(BaseModel): + query: str + engagement_id: Optional[str] = None + include_candidates: bool = False + timeout: float = 30.0 + max_rows: int = 1000 + + +class QueryResponse(BaseModel): + columns: list[str] + rows: list[dict[str, Any]] + subgraph: Optional[dict] = None + stats: dict + truncated: bool + + +@router.post("", response_model=QueryResponse) +async def execute_query( + request: QueryRequest, + current_user: User = Depends(get_current_user), + db=Depends(get_db), +): + """Execute a Cypher query against the attack chain knowledge graph.""" + from opentools.chain.config import get_chain_config + from opentools.chain.cypher import 
parse_and_execute + from opentools.chain.cypher.errors import QueryParseError, QueryResourceError, QueryValidationError + from opentools.chain.cypher.limits import QueryLimits + from opentools.chain.cypher.virtual_graph import VirtualGraphCache + from opentools.chain.query.graph_cache import GraphCache + from app.services.chain_store_factory import chain_store_from_session + + try: + cfg = get_chain_config() + store = chain_store_from_session(db) + await store.initialize() + + graph_cache = GraphCache(store=store, maxsize=cfg.query.graph_cache_size) + vg_cache = VirtualGraphCache(store=store, graph_cache=graph_cache, maxsize=cfg.cypher.virtual_graph_cache_size) + + engagement_ids = frozenset([request.engagement_id]) if request.engagement_id else None + limits = QueryLimits(timeout_seconds=request.timeout, max_rows=request.max_rows) + + result = await parse_and_execute( + request.query, + store=store, + graph_cache=graph_cache, + vg_cache=vg_cache, + user_id=current_user.id, + include_candidates=request.include_candidates, + engagement_ids=engagement_ids, + limits=limits, + ) + + subgraph_data = None + if result.subgraph: + subgraph_data = { + "nodes": [{"index": idx} for idx in result.subgraph.node_indices], + "edges": [{"source": s, "target": t} for s, t in result.subgraph.edge_tuples], + } + + return QueryResponse( + columns=result.columns, + rows=result.rows, + subgraph=subgraph_data, + stats={ + "duration_ms": result.stats.duration_ms, + "bindings_explored": result.stats.bindings_explored, + "rows_returned": result.stats.rows_returned, + }, + truncated=result.truncated, + ) + + except QueryParseError as e: + raise HTTPException(status_code=400, detail=f"Parse error: {e}") + except QueryValidationError as e: + raise HTTPException(status_code=400, detail=f"Validation error: {e}") + except QueryResourceError as e: + raise HTTPException(status_code=400, detail=f"Resource limit: {e}") + + +@router.get("/functions") +async def list_functions( + current_user: User = 
Depends(get_current_user), +): + """List all available query functions (built-in and plugin).""" + from opentools.chain.cypher.builtins import list_builtins + + result = [] + for name, info in list_builtins().items(): + result.append({"name": name, "kind": "builtin", "help": info.get("help", "")}) + return result diff --git a/packages/web/frontend/package-lock.json b/packages/web/frontend/package-lock.json index 4dde176..d468539 100644 --- a/packages/web/frontend/package-lock.json +++ b/packages/web/frontend/package-lock.json @@ -8,10 +8,17 @@ "name": "opentools-frontend", "version": "0.1.0", "dependencies": { + "@codemirror/autocomplete": "^6.20.1", + "@codemirror/commands": "^6.10.3", + "@codemirror/language": "^6.12.3", + "@codemirror/search": "^6.6.0", + "@codemirror/state": "^6.6.0", + "@codemirror/view": "^6.41.0", "@primevue/themes": "^4.0", "@tanstack/vue-query": "^5.0", "@vueuse/core": "^12.0", "chart.js": "^4.5.1", + "codemirror": "^6.0.2", "force-graph": "^1.51.2", "pinia": "^3.0", "primeicons": "^7.0", @@ -73,6 +80,87 @@ "node": ">=6.9.0" } }, + "node_modules/@codemirror/autocomplete": { + "version": "6.20.1", + "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.1.tgz", + "integrity": "sha512-1cvg3Vz1dSSToCNlJfRA2WSI4ht3K+WplO0UMOgmUYPivCyy2oueZY6Lx7M9wThm7SDUBViRmuT+OG/i8+ON9A==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@codemirror/commands": { + "version": "6.10.3", + "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.3.tgz", + "integrity": "sha512-JFRiqhKu+bvSkDLI+rUhJwSxQxYb759W5GBezE8Uc8mHLqC9aV/9aTC7yJSqCtB3F00pylrLCwnyS91Ap5ej4Q==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.6.0", + "@codemirror/view": "^6.27.0", + "@lezer/common": "^1.1.0" + } + }, + 
"node_modules/@codemirror/language": { + "version": "6.12.3", + "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.12.3.tgz", + "integrity": "sha512-QwCZW6Tt1siP37Jet9Tb02Zs81TQt6qQrZR2H+eGMcFsL1zMrk2/b9CLC7/9ieP1fjIUMgviLWMmgiHoJrj+ZA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.23.0", + "@lezer/common": "^1.5.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0", + "style-mod": "^4.0.0" + } + }, + "node_modules/@codemirror/lint": { + "version": "6.9.5", + "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.5.tgz", + "integrity": "sha512-GElsbU9G7QT9xXhpUg1zWGmftA/7jamh+7+ydKRuT0ORpWS3wOSP0yT1FOlIZa7mIJjpVPipErsyvVqB9cfTFA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.35.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/search": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.6.0.tgz", + "integrity": "sha512-koFuNXcDvyyotWcgOnZGmY7LZqEOXZaaxD/j6n18TCLx2/9HieZJ5H6hs1g8FiRxBD0DNfs0nXn17g872RmYdw==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.37.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/state": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.6.0.tgz", + "integrity": "sha512-4nbvra5R5EtiCzr9BTHiTLc+MLXK2QGiAVYMyi8PkQd3SR+6ixar/Q/01Fa21TBIDOZXgeWV4WppsQolSreAPQ==", + "license": "MIT", + "dependencies": { + "@marijn/find-cluster-break": "^1.0.0" + } + }, + "node_modules/@codemirror/view": { + "version": "6.41.0", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.41.0.tgz", + "integrity": "sha512-6H/qadXsVuDY219Yljhohglve8xf4B8xJkVOEWfA5uiYKiTFppjqsvsfR5iPA0RbvRBoOyTZpbLIxe9+0UR8xA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.6.0", + "crelt": "^1.0.6", + "style-mod": "^4.1.0", + "w3c-keyname": 
"^2.2.4" + } + }, "node_modules/@esbuild/aix-ppc64": { "version": "0.25.12", "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz", @@ -527,6 +615,36 @@ "integrity": "sha512-M5UknZPHRu3DEDWoipU6sE8PdkZ6Z/S+v4dD+Ke8IaNlpdSQah50lz1KtcFBa2vsdOnwbbnxJwVM4wty6udA5w==", "license": "MIT" }, + "node_modules/@lezer/common": { + "version": "1.5.2", + "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.5.2.tgz", + "integrity": "sha512-sxQE460fPZyU3sdc8lafxiPwJHBzZRy/udNFynGQky1SePYBdhkBl1kOagA9uT3pxR8K09bOrmTUqA9wb/PjSQ==", + "license": "MIT" + }, + "node_modules/@lezer/highlight": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz", + "integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.3.0" + } + }, + "node_modules/@lezer/lr": { + "version": "1.4.8", + "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.8.tgz", + "integrity": "sha512-bPWa0Pgx69ylNlMlPvBPryqeLYQjyJjqPx+Aupm5zydLIF3NE+6MMLT8Yi23Bd9cif9VS00aUebn+6fDIGBcDA==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@marijn/find-cluster-break": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", + "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==", + "license": "MIT" + }, "node_modules/@primeuix/styled": { "version": "0.7.4", "resolved": "https://registry.npmjs.org/@primeuix/styled/-/styled-0.7.4.tgz", @@ -1377,6 +1495,21 @@ "pnpm": ">=8" } }, + "node_modules/codemirror": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz", + "integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==", + "license": "MIT", + "dependencies": { 
+ "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, "node_modules/copy-anything": { "version": "4.0.5", "resolved": "https://registry.npmjs.org/copy-anything/-/copy-anything-4.0.5.tgz", @@ -1392,6 +1525,12 @@ "url": "https://github.com/sponsors/mesqueeb" } }, + "node_modules/crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==", + "license": "MIT" + }, "node_modules/csstype": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", @@ -2073,6 +2212,12 @@ "node": ">=0.10.0" } }, + "node_modules/style-mod": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz", + "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==", + "license": "MIT" + }, "node_modules/superjson": { "version": "2.2.6", "resolved": "https://registry.npmjs.org/superjson/-/superjson-2.2.6.tgz", @@ -2269,6 +2414,12 @@ "peerDependencies": { "typescript": ">=5.0.0" } + }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", + "license": "MIT" } } } diff --git a/packages/web/frontend/package.json b/packages/web/frontend/package.json index e3b4a4b..f328c57 100644 --- a/packages/web/frontend/package.json +++ b/packages/web/frontend/package.json @@ -9,10 +9,17 @@ "preview": "vite preview" }, "dependencies": { + "@codemirror/autocomplete": "^6.20.1", + "@codemirror/commands": "^6.10.3", + "@codemirror/language": "^6.12.3", + 
"@codemirror/search": "^6.6.0", + "@codemirror/state": "^6.6.0", + "@codemirror/view": "^6.41.0", "@primevue/themes": "^4.0", "@tanstack/vue-query": "^5.0", "@vueuse/core": "^12.0", "chart.js": "^4.5.1", + "codemirror": "^6.0.2", "force-graph": "^1.51.2", "pinia": "^3.0", "primeicons": "^7.0", diff --git a/packages/web/frontend/src/components/CypherEditor.vue b/packages/web/frontend/src/components/CypherEditor.vue new file mode 100644 index 0000000..e4b0977 --- /dev/null +++ b/packages/web/frontend/src/components/CypherEditor.vue @@ -0,0 +1,165 @@ + + + + + + Run (Ctrl+Enter) + + + + + + + + diff --git a/packages/web/frontend/src/components/ForceGraphCanvas.vue b/packages/web/frontend/src/components/ForceGraphCanvas.vue index f20c510..152d71f 100644 --- a/packages/web/frontend/src/components/ForceGraphCanvas.vue +++ b/packages/web/frontend/src/components/ForceGraphCanvas.vue @@ -39,11 +39,13 @@ const props = withDefaults(defineProps<{ data: GraphData selectedNodeId: string | null selectedLinkId: string | null + highlightedNodeIds?: string[] timeRange?: { start: Date; end: Date } | null layoutMode?: 'force' | 'killchain' colorMode?: 'severity' | 'engagement' engagementColors?: Record }>(), { + highlightedNodeIds: () => [], timeRange: null, layoutMode: 'force', colorMode: 'severity', @@ -150,6 +152,19 @@ function initGraph() { ? 
(props.engagementColors[n.engagement_id] || '#95a5a6') : (SEVERITY_COLORS[n.severity] || '#95a5a6') const isSelected = n.id === props.selectedNodeId + const isHighlighted = props.highlightedNodeIds.includes(n.id) + + // Highlight glow (from query results) + if (isHighlighted) { + ctx.save() + ctx.shadowColor = '#FFD700' + ctx.shadowBlur = 8 / globalScale + ctx.beginPath() + ctx.arc(node.x, node.y, radius + 2 / globalScale, 0, 2 * Math.PI) + ctx.fillStyle = 'rgba(255, 215, 0, 0.3)' + ctx.fill() + ctx.restore() + } // Pivotality glow if (n.pivotality && n.pivotality > 0.1) { @@ -432,6 +447,12 @@ watch(() => props.data, (newData) => { } }, { deep: true }) +watch(() => props.highlightedNodeIds, () => { + if (graph) { + graph.refresh() + } +}) + watch(() => props.layoutMode, (mode) => { if (mode === 'killchain') { applyKillChainLayout() diff --git a/packages/web/frontend/src/components/InlineQueryPanel.vue b/packages/web/frontend/src/components/InlineQueryPanel.vue new file mode 100644 index 0000000..ed7960f --- /dev/null +++ b/packages/web/frontend/src/components/InlineQueryPanel.vue @@ -0,0 +1,148 @@ + + + + + {{ expanded ? 'Hide Query' : 'Query' }} + + + + + + {{ loading ? 'Running...' : 'Run (Ctrl+Enter)' }} + + + {{ error }} + + + {{ result.stats.rows_returned }} rows, {{ result.stats.duration_ms.toFixed(1) }}ms + + + + + {{ col }} + + + + + {{ formatCell(row[col]) }} + + + + + ... {{ result.rows.length - 20 }} more rows + + No results + + + + + + + + diff --git a/packages/web/frontend/src/components/QueryResultsPane.vue b/packages/web/frontend/src/components/QueryResultsPane.vue new file mode 100644 index 0000000..8709956 --- /dev/null +++ b/packages/web/frontend/src/components/QueryResultsPane.vue @@ -0,0 +1,57 @@ + + + + Running query... 
+ {{ error }} + + + {{ result.stats.rows_returned }} rows, {{ result.stats.duration_ms.toFixed(1) }}ms + (truncated) + + + + + {{ col }} + + + + + {{ formatCell(row[col]) }} + + + + No results + + + + + + + diff --git a/packages/web/frontend/src/router/index.ts b/packages/web/frontend/src/router/index.ts index b0042de..60d0ab7 100644 --- a/packages/web/frontend/src/router/index.ts +++ b/packages/web/frontend/src/router/index.ts @@ -12,6 +12,7 @@ const router = createRouter({ { path: '/engagements/:id', name: 'engagement-detail', component: () => import('@/views/EngagementDetailView.vue') }, { path: '/findings/:id', name: 'finding-detail', component: () => import('@/views/FindingDetailView.vue') }, { path: '/engagements/:id/chain', name: 'engagement-chain', component: () => import('@/views/ChainGraphView.vue') }, + { path: '/chain/query', name: 'chain-query', component: () => import('@/views/ChainQueryView.vue') }, { path: '/chain/global', name: 'chain-global', component: () => import('@/views/GlobalChainView.vue') }, { path: '/recipes', name: 'recipes', component: () => import('@/views/RecipeListView.vue') }, { path: '/recipes/:id/run', name: 'recipe-run', component: () => import('@/views/RecipeRunnerView.vue') }, diff --git a/packages/web/frontend/src/views/ChainGraphView.vue b/packages/web/frontend/src/views/ChainGraphView.vue index 13c9913..8c5cdbe 100644 --- a/packages/web/frontend/src/views/ChainGraphView.vue +++ b/packages/web/frontend/src/views/ChainGraphView.vue @@ -11,6 +11,7 @@ import ChainDetailPanel from '@/components/ChainDetailPanel.vue' import ChainFilterToolbar from '@/components/ChainFilterToolbar.vue' import ChainLegend from '@/components/ChainLegend.vue' import ChainEmptyState from '@/components/ChainEmptyState.vue' +import InlineQueryPanel from '@/components/InlineQueryPanel.vue' import ChainTimelineScrubber from '@/components/ChainTimelineScrubber.vue' const route = useRoute() @@ -209,6 +210,12 @@ const { data: engagement } = useQuery({ 
queryFn: () => fetch(`/api/v1/engagements/${engId}`, { credentials: 'include' }).then(r => r.json()), }) + +const highlightedNodeIds = ref([]) + +function onQueryHighlight(nodeIds: string[]) { + highlightedNodeIds.value = nodeIds +} @@ -238,11 +245,12 @@ const { data: engagement } = useQuery({ - + + diff --git a/packages/web/frontend/src/views/ChainQueryView.vue b/packages/web/frontend/src/views/ChainQueryView.vue new file mode 100644 index 0000000..50039c9 --- /dev/null +++ b/packages/web/frontend/src/views/ChainQueryView.vue @@ -0,0 +1,64 @@ + + + + Chain Query + + + + + + + + + + + +
Run chain analysis on individual engagements first.