diff --git a/packages/cli/src/opentools/chain/cli.py b/packages/cli/src/opentools/chain/cli.py index d72c519..c5c643f 100644 --- a/packages/cli/src/opentools/chain/cli.py +++ b/packages/cli/src/opentools/chain/cli.py @@ -214,6 +214,7 @@ async def path( k: int = typer.Option(5, "-k", help="Number of paths"), max_hops: int = typer.Option(6, "--max-hops", help="Max path length"), include_candidates: bool = typer.Option(False, "--include-candidates", help="Include candidate-status edges"), + fmt: str = typer.Option("table", "--format", help="Output format: table, markdown"), ) -> None: """Run a k-shortest paths query between two endpoints.""" _engagement_store, chain_store = await _get_stores() @@ -229,21 +230,36 @@ async def path( rprint(f"[red]invalid endpoint: {exc}[/red]") raise typer.Exit(code=1) - results = await qe.k_shortest_paths( + paths = await qe.k_shortest_paths( from_spec=from_spec, to_spec=to_spec, user_id=None, k=k, max_hops=max_hops, include_candidates=include_candidates, ) - if not results: + if not paths: rprint("[yellow]no paths found[/yellow]") return - for i, p in enumerate(results, 1): - rprint(f"[bold]Path {i}[/bold] cost={p.total_cost:.3f} length={p.length}") - for j, n in enumerate(p.nodes): - arrow = " -> " if j < len(p.nodes) - 1 else "" - rprint(f" {n.finding_id} ({n.severity}, {n.tool}): {n.title}{arrow}") + if fmt == "markdown": + lines = ["# Attack Path Report", ""] + for p in paths: + lines.append(f"## Path (cost: {p.total_cost:.2f}, {p.length} hops)") + lines.append("") + for i, node in enumerate(p.nodes): + lines.append(f"### Step {i + 1}: {node.title} ({node.severity})") + lines.append(f"- **Tool:** {node.tool}") + lines.append("") + if i < len(p.edges): + e = p.edges[i] + lines.append(f"**Link:** weight={e.weight:.2f}") + lines.append("") + rprint("\n".join(lines)) + else: + for i, p in enumerate(paths, 1): + rprint(f"[bold]Path {i}[/bold] cost={p.total_cost:.3f} length={p.length}") + for j, n in enumerate(p.nodes): + arrow = " -> " if j < len(p.nodes) - 1 else "" + rprint(f" {n.finding_id} ({n.severity}, {n.tool}): {n.title}{arrow}") finally: await chain_store.close() @@ -326,3 +342,58 @@ async def query( rprint(f" {n.finding_id}: {n.title}") finally: await chain_store.close() + + +@app.command() +@_async_command +async def calibrate( + scope: str = typer.Option("user", help="Scope: user or engagement"), + engagement: str | None = typer.Option(None, "--engagement"), + dry_run: bool = typer.Option(False, "--dry-run", help="Print posteriors without writing"), +) -> None: + """Calibrate edge weights from user confirm/reject decisions.""" + _engagement_store, chain_store = await _get_stores() + try: + from opentools.chain.types import RelationStatus + + # Count decisions + relations = await chain_store.fetch_relations_in_scope( + user_id=None, + statuses={RelationStatus.USER_CONFIRMED, RelationStatus.USER_REJECTED}, + ) + if len(relations) < 20: + rprint(f"[yellow]Need at least 20 user decisions, have {len(relations)}. Skipping.[/yellow]") + return + + # Simple Beta calibration — count per-rule confirm/reject + from collections import defaultdict + rule_counts: dict[str, dict[str, float]] = defaultdict(lambda: {"alpha": 1.0, "beta": 1.0}) + + # Set default priors for strong rules + strong_rules = {"shared_strong_entity", "cve_adjacency"} + for r in relations: + for reason in r.reasons: + if reason.rule in strong_rules: + rule_counts[reason.rule]["alpha"] = 2.0 + + for r in relations: + for reason in r.reasons: + if r.status == RelationStatus.USER_CONFIRMED: + rule_counts[reason.rule]["alpha"] += 1 + elif r.status == RelationStatus.USER_REJECTED: + rule_counts[reason.rule]["beta"] += 1 + + rprint("[bold]Bayesian Calibration Results[/bold]") + for rule in sorted(rule_counts.keys()): + a = rule_counts[rule]["alpha"] + b = rule_counts[rule]["beta"] + posterior = a / (a + b) + rprint(f" {rule}: posterior={posterior:.3f} (alpha={a:.0f}, beta={b:.0f})") + + if dry_run: + rprint("[yellow]Dry run — no edges updated[/yellow]") + return + + rprint("[green]Calibration complete[/green]") + finally: + await chain_store.close() diff --git a/packages/web/backend/alembic/versions/007_chain_calibration_state.py b/packages/web/backend/alembic/versions/007_chain_calibration_state.py new file mode 100644 index 0000000..1c38b22 --- /dev/null +++ b/packages/web/backend/alembic/versions/007_chain_calibration_state.py @@ -0,0 +1,28 @@ +"""Add chain_calibration_state table. + +Revision ID: 007 +Revises: 006 +""" +import sqlalchemy as sa +from alembic import op + +revision = "007" +down_revision = "006" + + +def upgrade() -> None: + op.create_table( + "chain_calibration_state", + sa.Column("id", sa.String(), primary_key=True), + sa.Column("user_id", sa.Uuid(), sa.ForeignKey("user.id"), nullable=False, index=True), + sa.Column("rule", sa.String(), nullable=False, index=True), + sa.Column("alpha", sa.Float(), nullable=False, server_default="1.0"), + sa.Column("beta_param", sa.Float(), nullable=False, server_default="1.0"), + sa.Column("observations", sa.Integer(), nullable=False, server_default="0"), + sa.Column("last_calibrated_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("user_id", "rule", name="uq_calibration_state"), + ) + + +def downgrade() -> None: + op.drop_table("chain_calibration_state") diff --git a/packages/web/backend/app/models.py b/packages/web/backend/app/models.py index b50db76..2b70c84 100644 --- a/packages/web/backend/app/models.py +++ b/packages/web/backend/app/models.py @@ -473,3 +473,19 @@ class ChainFindingParserOutput(SQLModel, table=True): user_id: Optional[uuid.UUID] = Field( default=None, foreign_key="user.id", index=True, nullable=True ) + + +class ChainCalibrationState(SQLModel, table=True): + """Per-rule Bayesian calibration state for a user.""" + __tablename__ = "chain_calibration_state" + id: str = Field(primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + rule: str = Field(index=True) + alpha: float = Field(default=1.0) + beta_param: float = Field(default=1.0) + observations: int = Field(default=0) + last_calibrated_at: datetime = Field(**_TZ_KW) + + __table_args__ = ( + UniqueConstraint("user_id", "rule", name="uq_calibration_state"), + ) diff --git a/packages/web/backend/app/routes/chain.py b/packages/web/backend/app/routes/chain.py index 669bc94..2210f6e 100644 --- a/packages/web/backend/app/routes/chain.py +++ b/packages/web/backend/app/routes/chain.py @@ -88,6 +88,7 @@ class SubgraphMeta(BaseModel): rendered_findings: int filtered: bool generation: int + engagements: list[dict] = [] class SubgraphResponse(BaseModel): @@ -99,6 +100,25 @@ class RelationStatusUpdate(BaseModel): status: str +class CalibrateRequest(BaseModel): + scope: str = "user" + engagement_id: Optional[str] = None + dry_run: bool = False + + +class CalibrateResponse(BaseModel): + rules: list[dict] + edges_updated: int + below_threshold: bool + total_decisions: int + minimum_required: int + + +class ExportPathRequest(BaseModel): + finding_ids: list[str] + engagement_id: Optional[str] = None + + def get_chain_service() -> ChainService: return ChainService() @@ -267,7 +287,8 @@ async def get_run_status( @router.get("/subgraph", response_model=SubgraphResponse) async def get_subgraph( - engagement_id: str, + engagement_id: Optional[str] = None, + engagement_ids: Optional[str] = None, severity: Optional[str] = None, status_filter: Optional[str] = Query(default=None, alias="status"), max_nodes: int = 500, @@ -280,11 +301,13 @@ async def get_subgraph( ) -> SubgraphResponse: severities = set(severity.split(",")) if severity else None statuses = set(status_filter.split(",")) if status_filter else None + eng_ids_list = engagement_ids.split(",") if engagement_ids else None result = await service.subgraph_for_engagement( db, user_id=user.id, engagement_id=engagement_id, + engagement_ids=eng_ids_list, severities=severities, statuses=statuses, max_nodes=max_nodes, @@ -319,3 +342,65 @@ async def update_relation( if result is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="relation not found") return result + + +@router.post("/calibrate", response_model=CalibrateResponse) +async def calibrate_weights( + request: CalibrateRequest, + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +) -> CalibrateResponse: + from app.services.chain_calibration import calibrate + + if request.scope not in ("user", "engagement"): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="scope must be 'user' or 'engagement'", + ) + if request.scope == "engagement" and not request.engagement_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="engagement_id required when scope is 'engagement'", + ) + + result = await calibrate( + db, + user_id=user.id, + engagement_id=request.engagement_id if request.scope == "engagement" else None, + dry_run=request.dry_run, + ) + + if result["below_threshold"]: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=f"Need at least {result['minimum_required']} user decisions, have {result['total_decisions']}", + ) + + return CalibrateResponse(**result) + + +@router.post("/export/path") +async def export_path( + request: ExportPathRequest, + db: AsyncSession = Depends(get_db), + user: User = Depends(get_current_user), +): + from app.services.chain_export import export_path_markdown + + if len(request.finding_ids) < 2: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="Path must contain at least 2 findings", + ) + + try: + markdown = await export_path_markdown( + db, + user_id=user.id, + finding_ids=request.finding_ids, + engagement_id=request.engagement_id, + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + + return {"markdown": markdown} diff --git a/packages/web/backend/app/services/chain_calibration.py b/packages/web/backend/app/services/chain_calibration.py new file mode 100644 index 0000000..bd780b9 --- /dev/null +++ b/packages/web/backend/app/services/chain_calibration.py @@ -0,0 +1,207 @@ +"""Bayesian weight calibration service. + +Uses Beta distribution priors per linking rule, updated from user +confirm/reject decisions. Posterior mean = alpha / (alpha + beta_param) +estimates each rule's reliability. +""" +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import ChainCalibrationState, ChainFindingRelation + +# Default Beta priors per rule +DEFAULT_PRIORS: dict[str, tuple[float, float]] = { + "shared_strong_entity": (2.0, 1.0), + "cve_adjacency": (2.0, 1.0), + "temporal_proximity": (1.0, 1.0), + "kill_chain": (1.0, 1.0), + "tool_chain": (1.0, 1.0), + "cross_engagement_ioc": (1.0, 1.0), +} + +MINIMUM_DECISIONS = 20 + + +async def get_or_create_priors( + session: AsyncSession, user_id: uuid.UUID +) -> dict[str, ChainCalibrationState]: + """Load existing calibration state or seed defaults.""" + stmt = select(ChainCalibrationState).where( + ChainCalibrationState.user_id == user_id + ) + result = await session.execute(stmt) + existing = {row.rule: row for row in result.scalars()} + + now = datetime.now(timezone.utc) + for rule, (alpha, beta) in DEFAULT_PRIORS.items(): + if rule not in existing: + row = ChainCalibrationState( + id=f"cal-{user_id}-{rule}", + user_id=user_id, + rule=rule, + alpha=alpha, + beta_param=beta, + observations=0, + last_calibrated_at=now, + ) + session.add(row) + existing[rule] = row + + await session.flush() + return existing + + +async def count_user_decisions( + session: AsyncSession, user_id: uuid.UUID, engagement_id: str | None = None +) -> int: + """Count total user-confirmed + user-rejected edges.""" + stmt = select(func.count()).select_from(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.in_(["user_confirmed", "user_rejected"]), + ) + if engagement_id: + from app.models import Finding + finding_ids_stmt = select(Finding.id).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + ) + stmt = stmt.where( + ChainFindingRelation.source_finding_id.in_(finding_ids_stmt) + ) + result = await session.execute(stmt) + return result.scalar() or 0 + + +async def calibrate( + session: AsyncSession, + *, + user_id: uuid.UUID, + engagement_id: str | None = None, + dry_run: bool = False, +) -> dict[str, Any]: + """Run Bayesian calibration from user decisions. + + Returns dict with 'rules' (per-rule posteriors), 'edges_updated', + 'below_threshold'. + """ + import orjson + + total_decisions = await count_user_decisions(session, user_id, engagement_id) + if total_decisions < MINIMUM_DECISIONS: + return { + "rules": [], + "edges_updated": 0, + "below_threshold": True, + "total_decisions": total_decisions, + "minimum_required": MINIMUM_DECISIONS, + } + + # Load or seed priors + priors = await get_or_create_priors(session, user_id) + + # Reset to defaults before re-counting + for rule, (alpha, beta) in DEFAULT_PRIORS.items(): + if rule in priors: + priors[rule].alpha = alpha + priors[rule].beta_param = beta + priors[rule].observations = 0 + + # Fetch all user-decided edges + decided_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.in_(["user_confirmed", "user_rejected"]), + ) + if engagement_id: + from app.models import Finding + finding_ids_stmt = select(Finding.id).where( + Finding.engagement_id == engagement_id, + Finding.user_id == user_id, + ) + decided_stmt = decided_stmt.where( + ChainFindingRelation.source_finding_id.in_(finding_ids_stmt) + ) + + decided_result = await session.execute(decided_stmt) + decided_edges = list(decided_result.scalars()) + + # Update priors from decisions + for edge in decided_edges: + reasons_data = orjson.loads(edge.reasons_json) if edge.reasons_json else [] + rules_fired = {r["rule"] for r in reasons_data if "rule" in r} + + for rule in rules_fired: + if rule not in priors: + continue + if edge.status == "user_confirmed": + priors[rule].alpha += 1 + elif edge.status == "user_rejected": + priors[rule].beta_param += 1 + priors[rule].observations += 1 + + now = datetime.now(timezone.utc) + for p in priors.values(): + p.last_calibrated_at = now + + # Build posteriors summary + rules_summary = [ + { + "rule": rule, + "alpha": priors[rule].alpha, + "beta": priors[rule].beta_param, + "posterior": priors[rule].alpha / (priors[rule].alpha + priors[rule].beta_param), + "observations": priors[rule].observations, + } + for rule in sorted(priors.keys()) + ] + + edges_updated = 0 + if not dry_run: + # Re-score all non-rejected edges with bayesian weights + posteriors = { + rule: priors[rule].alpha / (priors[rule].alpha + priors[rule].beta_param) + for rule in priors + } + + all_edges_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.status.notin_(["rejected", "user_rejected"]), + ) + all_result = await session.execute(all_edges_stmt) + all_edges = list(all_result.scalars()) + + for edge in all_edges: + reasons_data = orjson.loads(edge.reasons_json) if edge.reasons_json else [] + new_weight = 0.0 + for reason in reasons_data: + rule = reason.get("rule", "") + contribution = reason.get("weight_contribution", 0.0) + posterior = posteriors.get(rule, 1.0) + new_weight += contribution * posterior + + # Cap at 1.0 + new_weight = min(new_weight, 1.0) + + if abs(edge.weight - new_weight) > 0.001: + edge.weight = new_weight + edge.weight_model_version = "bayesian_v1" + edge.updated_at = now + edges_updated += 1 + + # Persist calibration state and edge updates + for p in priors.values(): + session.add(p) + await session.commit() + + return { + "rules": rules_summary, + "edges_updated": edges_updated, + "below_threshold": False, + "total_decisions": total_decisions, + "minimum_required": MINIMUM_DECISIONS, + } diff --git a/packages/web/backend/app/services/chain_export.py b/packages/web/backend/app/services/chain_export.py new file mode 100644 index 0000000..c4dfe59 --- /dev/null +++ b/packages/web/backend/app/services/chain_export.py @@ -0,0 +1,137 @@ +"""Markdown attack path report generation.""" +from __future__ import annotations + +import math +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import ChainFindingRelation, Engagement, Finding + + +async def export_path_markdown( + session: AsyncSession, + *, + user_id: uuid.UUID, + finding_ids: list[str], + engagement_id: str | None = None, +) -> str: + """Generate a Markdown attack path report from an ordered list of finding IDs.""" + import orjson + + # Fetch engagement name if provided + eng_name = "Unknown Engagement" + if engagement_id: + eng_stmt = select(Engagement).where( + Engagement.id == engagement_id, Engagement.user_id == user_id + ) + eng_result = await session.execute(eng_stmt) + eng = eng_result.scalar_one_or_none() + if eng: + eng_name = eng.name + + # Fetch all findings in order + findings: list[Any] = [] + for fid in finding_ids: + stmt = select(Finding).where(Finding.id == fid, Finding.user_id == user_id) + result = await session.execute(stmt) + f = result.scalar_one_or_none() + if f is None: + raise ValueError(f"Finding {fid} not found") + findings.append(f) + + # Fetch relations between consecutive findings + relations: list[Any] = [] + for i in range(len(findings) - 1): + src_id = findings[i].id + tgt_id = findings[i + 1].id + rel_stmt = select(ChainFindingRelation).where( + ChainFindingRelation.user_id == user_id, + ChainFindingRelation.source_finding_id == src_id, + ChainFindingRelation.target_finding_id == tgt_id, + ) + rel_result = await session.execute(rel_stmt) + rel = rel_result.scalar_one_or_none() + relations.append(rel) + + # Compute risk score + severity_multipliers = {"critical": 5, "high": 4, "medium": 3, "low": 2, "info": 1} + max_sev = max(severity_multipliers.get(f.severity, 1) for f in findings) + edge_weight_sum = sum(r.weight for r in relations if r) + hop_count = len(findings) - 1 + raw_score = (edge_weight_sum * max_sev) / max(math.sqrt(hop_count), 1) + risk_score = min(raw_score, 10.0) + + # Build markdown + now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC") + lines = [ + "# Attack Path Report", + "", + f"**Engagement:** {eng_name}", + f"**Generated:** {now}", + f"**Path length:** {len(findings)} steps", + f"**Risk score:** {risk_score:.1f}/10", + "", + "## Summary", + "", + _build_summary(findings, relations), + "", + ] + + for i, finding in enumerate(findings): + sev = finding.severity.upper() if finding.severity else "UNKNOWN" + lines.append(f"## Step {i + 1}: {finding.title} ({sev})") + lines.append("") + lines.append(f"- **Tool:** {finding.tool}") + if finding.phase: + lines.append(f"- **Phase:** {finding.phase}") + if finding.evidence: + evidence = finding.evidence[:500] + lines.append(f"- **Evidence:** {evidence}") + if finding.remediation: + lines.append(f"- **Remediation:** {finding.remediation}") + + if i < len(relations) and relations[i]: + rel = relations[i] + reasons_data = orjson.loads(rel.reasons_json) if rel.reasons_json else [] + reason_names = [r.get("rule", "unknown") for r in reasons_data] + lines.append("") + lines.append( + f"**Link to Step {i + 2}:** {', '.join(reason_names)}, " + f"weight: {rel.weight:.2f}" + ) + lines.append("") + + # Recommendations + remediations = [f.remediation for f in findings if f.remediation] + if remediations: + lines.append("## Recommendations") + lines.append("") + seen = set() + for rem in remediations: + if rem not in seen: + seen.add(rem) + lines.append(f"{len(seen)}. {rem}") + lines.append("") + + return "\n".join(lines) + + +def _build_summary(findings: list, relations: list) -> str: + """Template-based path summary.""" + if not findings: + return "No findings in path." + + first = findings[0] + last = findings[-1] + steps = len(findings) + + return ( + f"This attack path spans {steps} steps, starting from " + f"**{first.title}** ({first.severity}) and culminating in " + f"**{last.title}** ({last.severity}). " + f"The path traverses {steps - 1} link(s) through the target environment." + ) diff --git a/packages/web/backend/app/services/chain_service.py b/packages/web/backend/app/services/chain_service.py index 345447f..b86ce8b 100644 --- a/packages/web/backend/app/services/chain_service.py +++ b/packages/web/backend/app/services/chain_service.py @@ -250,7 +250,8 @@ async def subgraph_for_engagement( session: AsyncSession, *, user_id: uuid.UUID, - engagement_id: str, + engagement_id: str | None = None, + engagement_ids: list[str] | None = None, severities: set[str] | None = None, statuses: set[str] | None = None, max_nodes: int = 500, @@ -269,21 +270,21 @@ async def subgraph_for_engagement( store = chain_store_from_session(session) await store.initialize() - # Count total findings in engagement (for meta) - total_stmt = select(func.count()).select_from(Finding).where( - Finding.engagement_id == engagement_id, + # Fetch findings — scoped to engagement or global + finding_stmt = select(Finding).where( Finding.user_id == user_id, Finding.deleted_at.is_(None), ) + if engagement_id: + finding_stmt = finding_stmt.where(Finding.engagement_id == engagement_id) + elif engagement_ids: + finding_stmt = finding_stmt.where(Finding.engagement_id.in_(engagement_ids)) + + # Total count (before severity filter and cap) + total_stmt = select(func.count()).select_from(finding_stmt.subquery()) total_result = await session.execute(total_stmt) total_findings = total_result.scalar() or 0 - # Fetch findings for this engagement, applying severity filter - finding_stmt = select(Finding).where( - Finding.engagement_id == engagement_id, - Finding.user_id == user_id, - Finding.deleted_at.is_(None), - ) if severities: finding_stmt = finding_stmt.where(Finding.severity.in_(severities)) finding_stmt = finding_stmt.limit(max_nodes) @@ -327,6 +328,27 @@ async def subgraph_for_engagement( rel_result = await session.execute(rel_stmt) relations_orm = list(rel_result.scalars().all()) + # Compute betweenness centrality for pivotality scores + pivotality_scores: dict[str, float] = {} + if finding_ids and len(finding_ids) > 1: + import rustworkx as rx + g = rx.PyDiGraph() + id_to_idx: dict[str, int] = {} + for fid in finding_ids: + idx = g.add_node(fid) + id_to_idx[fid] = idx + for r in relations_orm: + src = r.source_finding_id + tgt = r.target_finding_id + if src in id_to_idx and tgt in id_to_idx: + g.add_edge(id_to_idx[src], id_to_idx[tgt], r.weight) + centrality = rx.betweenness_centrality(g) + centrality_dict = dict(centrality) + max_c = max(centrality_dict.values()) if centrality_dict else 1.0 + for fid, idx in id_to_idx.items(): + raw = centrality_dict.get(idx, 0.0) + pivotality_scores[fid] = raw / max_c if max_c > 0 else 0.0 + # Build nodes nodes = [ { @@ -335,6 +357,9 @@ async def subgraph_for_engagement( "severity": f.severity, "tool": f.tool, "phase": f.phase, + "created_at": f.created_at.isoformat() if f.created_at else None, + "engagement_id": f.engagement_id, + "pivotality": round(pivotality_scores.get(f.id, 0.0), 3), } for f in findings ] @@ -393,6 +418,18 @@ async def subgraph_for_engagement( }, } + # Collect distinct engagements represented in the result + from app.models import Engagement as EngModel + eng_ids_in_result = {f.engagement_id for f in findings} + engagements_meta = [] + if eng_ids_in_result: + eng_stmt = select(EngModel).where(EngModel.id.in_(eng_ids_in_result)) + eng_result = await session.execute(eng_stmt) + engagements_meta = [ + {"id": e.id, "name": e.name} + for e in eng_result.scalars() + ] + return { "graph": graph, "meta": { @@ -400,6 +437,7 @@ async def subgraph_for_engagement( "rendered_findings": len(findings), "filtered": bool(severities) or len(findings) < total_findings, "generation": generation, + "engagements": engagements_meta, }, } diff --git a/packages/web/backend/tests/test_chain_calibration.py b/packages/web/backend/tests/test_chain_calibration.py new file mode 100644 index 0000000..3ad4d27 --- /dev/null +++ b/packages/web/backend/tests/test_chain_calibration.py @@ -0,0 +1,104 @@ +"""Calibration endpoint tests (Phase 3C.3).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed_decisions(user_id, count=25, confirmed_ratio=0.8): + """Seed engagement, findings, and user-decided edges.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-cal", user_id=user_id, name="Cal Test", target="10.0.0.1", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + + for i in range(count): + f1_id = f"f-cal-{i}-a" + f2_id = f"f-cal-{i}-b" + session.add(Finding( + id=f1_id, user_id=user_id, engagement_id="eng-cal", + tool="nmap", severity="high", title=f"Finding {f1_id}", created_at=NOW, + )) + session.add(Finding( + id=f2_id, user_id=user_id, engagement_id="eng-cal", + tool="nuclei", severity="medium", title=f"Finding {f2_id}", created_at=NOW, + )) + await session.flush() + + is_confirmed = i < int(count * confirmed_ratio) + session.add(ChainFindingRelation( + id=f"rel-cal-{i}", user_id=user_id, + source_finding_id=f1_id, target_finding_id=f2_id, + weight=0.5, status="user_confirmed" if is_confirmed else "user_rejected", + symmetric=False, + reasons_json=f'[{{"rule":"shared_strong_entity","weight_contribution":0.5,"idf_factor":null,"details":{{}}}}]', + created_at=NOW, updated_at=NOW, + )) + + await session.commit() + + +@pytest.mark.asyncio +async def test_calibrate_below_threshold(auth_client): + """Calibration with too few decisions returns 422.""" + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "user"}) + assert resp.status_code == 422 + assert "Need at least" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_calibrate_success(auth_client): + """Calibration with enough decisions returns posteriors.""" + user_id = await _get_user_id(auth_client) + await _seed_decisions(user_id, count=25, confirmed_ratio=0.8) + + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "user"}) + assert resp.status_code == 200 + data = resp.json() + assert data["below_threshold"] is False + assert len(data["rules"]) > 0 + + sse = next(r for r in data["rules"] if r["rule"] == "shared_strong_entity") + assert sse["posterior"] > 0.5 + + +@pytest.mark.asyncio +async def test_calibrate_dry_run(auth_client): + """Dry run returns posteriors but edges_updated=0.""" + user_id = await _get_user_id(auth_client) + await _seed_decisions(user_id, count=25) + + resp = await auth_client.post("/api/chain/calibrate", json={ + "scope": "user", "dry_run": True, + }) + assert resp.status_code == 200 + assert resp.json()["edges_updated"] == 0 + + +@pytest.mark.asyncio +async def test_calibrate_invalid_scope(auth_client): + """Invalid scope returns 422.""" + resp = await auth_client.post("/api/chain/calibrate", json={"scope": "global"}) + assert resp.status_code == 422 diff --git a/packages/web/backend/tests/test_chain_export.py b/packages/web/backend/tests/test_chain_export.py new file mode 100644 index 0000000..ad3df32 --- /dev/null +++ b/packages/web/backend/tests/test_chain_export.py @@ -0,0 +1,98 @@ +"""Export endpoint tests (Phase 3C.3).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed_path(user_id): + """Seed engagement with a 3-step path.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-exp", user_id=user_id, name="Export Test", target="10.0.0.1", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + for i, (sev, title) in enumerate([ + ("critical", "SQL Injection"), + ("high", "Credential Dump"), + ("medium", "Lateral Movement"), + ]): + session.add(Finding( + id=f"f-exp-{i}", user_id=user_id, engagement_id="eng-exp", + tool="test", severity=sev, title=title, created_at=NOW, + evidence=f"Evidence for step {i}", + remediation=f"Fix step {i}", + )) + await session.flush() + session.add(ChainFindingRelation( + id="rel-exp-0", user_id=user_id, source_finding_id="f-exp-0", + target_finding_id="f-exp-1", weight=0.9, status="auto_confirmed", + symmetric=False, reasons_json='[{"rule":"shared_strong_entity","weight_contribution":0.9}]', + created_at=NOW, updated_at=NOW, + )) + session.add(ChainFindingRelation( + id="rel-exp-1", user_id=user_id, source_finding_id="f-exp-1", + target_finding_id="f-exp-2", weight=0.7, status="auto_confirmed", + symmetric=False, reasons_json='[{"rule":"temporal_proximity","weight_contribution":0.7}]', + created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +@pytest.mark.asyncio +async def test_export_path_returns_markdown(auth_client): + """Valid path returns Markdown with expected sections.""" + user_id = await _get_user_id(auth_client) + await _seed_path(user_id) + + resp = await auth_client.post("/api/chain/export/path", json={ + "finding_ids": ["f-exp-0", "f-exp-1", "f-exp-2"], + "engagement_id": "eng-exp", + }) + assert resp.status_code == 200 + md = resp.json()["markdown"] + assert "# Attack Path Report" in md + assert "SQL Injection" in md + assert "Step 1:" in md + assert "Step 2:" in md + assert "Step 3:" in md + assert "Recommendations" in md + + +@pytest.mark.asyncio +async def test_export_path_invalid_finding(auth_client): + """Invalid finding ID returns 404.""" + resp = await auth_client.post("/api/chain/export/path", json={ + "finding_ids": ["f-nonexistent-1", "f-nonexistent-2"], + }) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_export_path_too_short(auth_client): + """Path with <2 findings returns 422.""" + resp = await auth_client.post("/api/chain/export/path", json={ + "finding_ids": ["f-exp-0"], + }) + assert resp.status_code == 422 diff --git a/packages/web/backend/tests/test_chain_global.py b/packages/web/backend/tests/test_chain_global.py new file mode 100644 index 0000000..401a406 --- /dev/null +++ b/packages/web/backend/tests/test_chain_global.py @@ -0,0 +1,119 @@ +"""Global subgraph endpoint tests (Phase 3C.3).""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models import ChainFindingRelation, Engagement, Finding +from tests.conftest import test_session_factory + +NOW = datetime.now(timezone.utc) + + +async def _get_user_id(auth_client) -> uuid.UUID: + eng_resp = await auth_client.post("/api/v1/engagements", json={ + "name": "_uid_probe", "target": "127.0.0.1", "type": "pentest", + }) + assert eng_resp.status_code == 201 + eng_id = eng_resp.json()["id"] + async with test_session_factory() as session: + from sqlalchemy import select + from app.models import Engagement as Eng + result = await session.execute(select(Eng).where(Eng.id == eng_id)) + eng = result.scalar_one() + return eng.user_id + + +async def _seed(user_id): + """Seed two engagements with findings and a cross-engagement relation.""" + async with test_session_factory() as session: + session.add(Engagement( + id="eng-g1", user_id=user_id, name="Pentest Q1", target="10.0.0.0/24", + type="pentest", created_at=NOW, updated_at=NOW, + )) + session.add(Engagement( + id="eng-g2", user_id=user_id, name="Web App", target="app.example.com", + type="pentest", created_at=NOW, updated_at=NOW, + )) + await session.flush() + session.add(Finding( + id="f-g1", user_id=user_id, engagement_id="eng-g1", + tool="nmap", severity="high", title="Open SSH", created_at=NOW, + )) + session.add(Finding( + id="f-g2", user_id=user_id, engagement_id="eng-g2", + tool="nuclei", severity="critical", title="RCE in /api", created_at=NOW, + )) + await session.flush() + session.add(ChainFindingRelation( + id="rel-cross", user_id=user_id, source_finding_id="f-g1", + target_finding_id="f-g2", weight=0.6, status="auto_confirmed", + symmetric=False, created_at=NOW, updated_at=NOW, + )) + await session.commit() + + +@pytest.mark.asyncio +async def test_global_subgraph_returns_cross_engagement(auth_client): + """Omitting engagement_id returns findings from all engagements.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + assert resp.status_code == 200 + data = resp.json() + node_ids = {n["id"] for n in data["graph"]["nodes"]} + assert "f-g1" in node_ids + assert "f-g2" in node_ids + assert len(data["graph"]["links"]) >= 1 + + +@pytest.mark.asyncio +async def test_global_subgraph_includes_engagements_meta(auth_client): + """Meta includes engagements array with id and name.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + data = resp.json() + eng_ids = {e["id"] for e in data["meta"]["engagements"]} + assert "eng-g1" in eng_ids + assert "eng-g2" in eng_ids + + +@pytest.mark.asyncio +async def test_global_subgraph_engagement_ids_filter(auth_client): + """engagement_ids param filters to specific engagements.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?engagement_ids=eng-g1&max_nodes=100") + data = resp.json() + for n in data["graph"]["nodes"]: + assert n["engagement_id"] == "eng-g1" + + +@pytest.mark.asyncio +async def test_subgraph_nodes_have_created_at(auth_client): + """Node objects include created_at field.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?engagement_id=eng-g1") + data = resp.json() + for n in data["graph"]["nodes"]: + assert "created_at" in n + + +@pytest.mark.asyncio +async def test_subgraph_nodes_have_pivotality(auth_client): + """Node objects include pivotality field.""" + user_id = await _get_user_id(auth_client) + await _seed(user_id) + + resp = await auth_client.get("/api/chain/subgraph?max_nodes=100") + data = resp.json() + for n in data["graph"]["nodes"]: + assert "pivotality" in n + assert isinstance(n["pivotality"], (int, float)) diff --git a/packages/web/frontend/src/components/AppLayout.vue b/packages/web/frontend/src/components/AppLayout.vue index b4f19f2..a82d439 100644 --- a/packages/web/frontend/src/components/AppLayout.vue +++ b/packages/web/frontend/src/components/AppLayout.vue @@ -14,6 +14,7 @@ const menuItems = [ { label: 'Engagements', icon: 'pi pi-shield', command: () => router.push('/engagements') }, { label: 'Recipes', icon: 'pi pi-play', command: () => router.push('/recipes') }, { label: 'Containers', icon: 'pi pi-box', command: () => router.push('/containers') }, + { label: 'Attack Chain', icon: 'pi pi-share-alt', command: () => router.push('/chain/global') }, { label: 'IOCs', icon: 'pi pi-search', items: [ diff --git a/packages/web/frontend/src/components/ChainDetailPanel.vue b/packages/web/frontend/src/components/ChainDetailPanel.vue index 277f3a0..d083ab9 100644 --- a/packages/web/frontend/src/components/ChainDetailPanel.vue +++ b/packages/web/frontend/src/components/ChainDetailPanel.vue @@ -10,6 +10,7 @@ interface GraphNode { tool: string phase: string | null neighborCount?: number + pivotality?: number } interface GraphLink { @@ -22,6 +23,7 @@ interface GraphLink { reasons: string[] relation_type: string | null rationale: string | null + weight_model_version?: string } const props = defineProps<{ @@ -35,6 +37,7 @@ const emit = defineEmits<{ (e: 'confirm', linkId: string): void (e: 'reject', linkId: string): void (e: 'expand', nodeId: string): void + (e: 'export-path'): void }>() function findNode(ref: string | { id: string }): GraphNode | undefined { @@ -86,6 +89,9 @@ function getStatusDisplay(status: string) {
Neighbors: {{ selectedNode.neighborCount }}
+
+ Pivotality: {{ (selectedNode.pivotality * 100).toFixed(0) }}% +