Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions packages/cli/src/opentools/chain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ async def path(
k: int = typer.Option(5, "-k", help="Number of paths"),
max_hops: int = typer.Option(6, "--max-hops", help="Max path length"),
include_candidates: bool = typer.Option(False, "--include-candidates", help="Include candidate-status edges"),
fmt: str = typer.Option("table", "--format", help="Output format: table, markdown"),
) -> None:
"""Run a k-shortest paths query between two endpoints."""
_engagement_store, chain_store = await _get_stores()
Expand All @@ -229,21 +230,36 @@ async def path(
rprint(f"[red]invalid endpoint: {exc}[/red]")
raise typer.Exit(code=1)

results = await qe.k_shortest_paths(
paths = await qe.k_shortest_paths(
from_spec=from_spec, to_spec=to_spec,
user_id=None, k=k, max_hops=max_hops,
include_candidates=include_candidates,
)

if not results:
if not paths:
rprint("[yellow]no paths found[/yellow]")
return

for i, p in enumerate(results, 1):
rprint(f"[bold]Path {i}[/bold] cost={p.total_cost:.3f} length={p.length}")
for j, n in enumerate(p.nodes):
arrow = " -> " if j < len(p.nodes) - 1 else ""
rprint(f" {n.finding_id} ({n.severity}, {n.tool}): {n.title}{arrow}")
if fmt == "markdown":
lines = ["# Attack Path Report", ""]
for p in paths:
lines.append(f"## Path (cost: {p.total_cost:.2f}, {p.length} hops)")
lines.append("")
for i, node in enumerate(p.nodes):
lines.append(f"### Step {i + 1}: {node.title} ({node.severity})")
lines.append(f"- **Tool:** {node.tool}")
lines.append("")
if i < len(p.edges):
e = p.edges[i]
lines.append(f"**Link:** weight={e.weight:.2f}")
lines.append("")
rprint("\n".join(lines))
else:
for i, p in enumerate(paths, 1):
rprint(f"[bold]Path {i}[/bold] cost={p.total_cost:.3f} length={p.length}")
for j, n in enumerate(p.nodes):
arrow = " -> " if j < len(p.nodes) - 1 else ""
rprint(f" {n.finding_id} ({n.severity}, {n.tool}): {n.title}{arrow}")
finally:
await chain_store.close()

Expand Down Expand Up @@ -326,3 +342,58 @@ async def query(
rprint(f" {n.finding_id}: {n.title}")
finally:
await chain_store.close()


@app.command()
@_async_command
async def calibrate(
    scope: str = typer.Option("user", help="Scope: user or engagement"),
    engagement: str | None = typer.Option(None, "--engagement"),
    dry_run: bool = typer.Option(False, "--dry-run", help="Print posteriors without writing"),
) -> None:
    """Calibrate edge weights from user confirm/reject decisions."""
    # NOTE(review): `scope` and `engagement` are accepted but never read in
    # this body, and the non-dry-run path only prints a completion message.
    # Confirm whether scoping and persistence are still to be wired in.
    _engagement_store, chain_store = await _get_stores()
    try:
        from collections import defaultdict

        from opentools.chain.types import RelationStatus

        # Below this many user decisions the command reports and exits early.
        min_decisions = 20
        # Rules whose alpha starts at 2.0 instead of 1.0.
        strong_rules = {"shared_strong_entity", "cve_adjacency"}

        decided = await chain_store.fetch_relations_in_scope(
            user_id=None,
            statuses={RelationStatus.USER_CONFIRMED, RelationStatus.USER_REJECTED},
        )
        if len(decided) < min_decisions:
            rprint(
                f"[yellow]Need at least {min_decisions} user decisions, "
                f"have {len(decided)}. Skipping.[/yellow]"
            )
            return

        # Per-rule tallies; every rule starts from alpha=1.0, beta=1.0.
        tallies: dict[str, dict[str, float]] = defaultdict(
            lambda: {"alpha": 1.0, "beta": 1.0}
        )

        # Seed the stronger prior for every strong rule that actually occurs.
        observed_rules = {reason.rule for rel in decided for reason in rel.reasons}
        for rule in observed_rules & strong_rules:
            tallies[rule]["alpha"] = 2.0

        # A confirmed relation bumps alpha for each of its reasons' rules;
        # a rejected one bumps beta. Any other status is left uncounted.
        for rel in decided:
            if rel.status == RelationStatus.USER_CONFIRMED:
                bucket = "alpha"
            elif rel.status == RelationStatus.USER_REJECTED:
                bucket = "beta"
            else:
                bucket = None
            for reason in rel.reasons:
                if bucket is not None:
                    tallies[reason.rule][bucket] += 1

        rprint("[bold]Bayesian Calibration Results[/bold]")
        for rule in sorted(tallies):
            alpha = tallies[rule]["alpha"]
            beta = tallies[rule]["beta"]
            mean = alpha / (alpha + beta)
            rprint(f" {rule}: posterior={mean:.3f} (alpha={alpha:.0f}, beta={beta:.0f})")

        if dry_run:
            rprint("[yellow]Dry run — no edges updated[/yellow]")
        else:
            rprint("[green]Calibration complete[/green]")
    finally:
        # Always release the store, including on the early-return paths.
        await chain_store.close()
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Add chain_calibration_state table.

Revision ID: 007
Revises: 006
"""
import sqlalchemy as sa
from alembic import op

revision = "007"
down_revision = "006"


def upgrade() -> None:
    """Create the ``chain_calibration_state`` table.

    The table has a string primary key, a ``user_id`` foreign key to
    ``user.id``, and a unique constraint over ``(user_id, rule)``.
    """
    columns = [
        sa.Column("id", sa.String(), primary_key=True),
        sa.Column(
            "user_id",
            sa.Uuid(),
            sa.ForeignKey("user.id"),
            nullable=False,
            index=True,
        ),
        sa.Column("rule", sa.String(), nullable=False, index=True),
        # Both calibration parameters default to 1.0 at the database level.
        sa.Column("alpha", sa.Float(), nullable=False, server_default="1.0"),
        sa.Column("beta_param", sa.Float(), nullable=False, server_default="1.0"),
        sa.Column("observations", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("last_calibrated_at", sa.DateTime(timezone=True), nullable=False),
    ]
    one_row_per_user_rule = sa.UniqueConstraint(
        "user_id", "rule", name="uq_calibration_state"
    )
    op.create_table("chain_calibration_state", *columns, one_row_per_user_rule)


def downgrade() -> None:
    """Drop the ``chain_calibration_state`` table created by ``upgrade``."""
    table_name = "chain_calibration_state"
    op.drop_table(table_name)
16 changes: 16 additions & 0 deletions packages/web/backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,19 @@ class ChainFindingParserOutput(SQLModel, table=True):
user_id: Optional[uuid.UUID] = Field(
default=None, foreign_key="user.id", index=True, nullable=True
)


class ChainCalibrationState(SQLModel, table=True):
    """Per-rule Bayesian calibration state for a user."""

    __tablename__ = "chain_calibration_state"
    # At most one row per (user_id, rule) pair.
    __table_args__ = (
        UniqueConstraint("user_id", "rule", name="uq_calibration_state"),
    )

    id: str = Field(primary_key=True)
    # Owning user; indexed foreign key into the user table.
    user_id: uuid.UUID = Field(foreign_key="user.id", index=True)
    # Name of the rule this row tracks.
    rule: str = Field(index=True)
    # Calibration parameters, both starting at 1.0.
    # NOTE(review): presumably the Beta-distribution alpha/beta pair -- confirm.
    alpha: float = Field(default=1.0)
    beta_param: float = Field(default=1.0)
    observations: int = Field(default=0)
    # Column options come from the module-level _TZ_KW mapping.
    last_calibrated_at: datetime = Field(**_TZ_KW)
87 changes: 86 additions & 1 deletion packages/web/backend/app/routes/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class SubgraphMeta(BaseModel):
rendered_findings: int
filtered: bool
generation: int
engagements: list[dict] = []


class SubgraphResponse(BaseModel):
Expand All @@ -99,6 +100,25 @@ class RelationStatusUpdate(BaseModel):
status: str


class CalibrateRequest(BaseModel):
    # Request body for POST /calibrate.

    # The route accepts only "user" or "engagement" here.
    scope: str = "user"
    # The route requires this when scope is "engagement".
    engagement_id: Optional[str] = None
    # Forwarded unchanged to the calibration service.
    dry_run: bool = False


class CalibrateResponse(BaseModel):
    # Response body for POST /calibrate; the route builds it by unpacking the
    # calibration service's result dict, so field names must match its keys.

    rules: list[dict]
    edges_updated: int
    # When the service reports True here the route raises a 422 instead of
    # returning this model.
    below_threshold: bool
    total_decisions: int
    minimum_required: int


class ExportPathRequest(BaseModel):
    # Request body for POST /export/path.

    # The route rejects lists with fewer than two findings.
    finding_ids: list[str]
    engagement_id: Optional[str] = None


def get_chain_service() -> ChainService:
return ChainService()

Expand Down Expand Up @@ -267,7 +287,8 @@ async def get_run_status(

@router.get("/subgraph", response_model=SubgraphResponse)
async def get_subgraph(
engagement_id: str,
engagement_id: Optional[str] = None,
engagement_ids: Optional[str] = None,
severity: Optional[str] = None,
status_filter: Optional[str] = Query(default=None, alias="status"),
max_nodes: int = 500,
Expand All @@ -280,11 +301,13 @@ async def get_subgraph(
) -> SubgraphResponse:
severities = set(severity.split(",")) if severity else None
statuses = set(status_filter.split(",")) if status_filter else None
eng_ids_list = engagement_ids.split(",") if engagement_ids else None

result = await service.subgraph_for_engagement(
db,
user_id=user.id,
engagement_id=engagement_id,
engagement_ids=eng_ids_list,
severities=severities,
statuses=statuses,
max_nodes=max_nodes,
Expand Down Expand Up @@ -319,3 +342,65 @@ async def update_relation(
if result is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="relation not found")
return result


@router.post("/calibrate", response_model=CalibrateResponse)
async def calibrate_weights(
    request: CalibrateRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
) -> CalibrateResponse:
    # Validates the requested scope, runs the calibration service for the
    # current user, and answers 422 when the service reports too few decisions.
    # Documented with comments because a route docstring would be published
    # as the endpoint's OpenAPI description.
    from app.services.chain_calibration import calibrate

    per_engagement = request.scope == "engagement"

    if not per_engagement and request.scope != "user":
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail="scope must be 'user' or 'engagement'",
        )
    if per_engagement and not request.engagement_id:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail="engagement_id required when scope is 'engagement'",
        )

    # The engagement id is only forwarded for engagement-scoped requests.
    target_engagement = request.engagement_id if per_engagement else None
    outcome = await calibrate(
        db,
        user_id=user.id,
        engagement_id=target_engagement,
        dry_run=request.dry_run,
    )

    if outcome["below_threshold"]:
        needed = outcome["minimum_required"]
        available = outcome["total_decisions"]
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail=f"Need at least {needed} user decisions, have {available}",
        )

    return CalibrateResponse(**outcome)


@router.post("/export/path")
async def export_path(
    request: ExportPathRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    # Renders the requested findings as a markdown path report and returns it
    # as {"markdown": <text>}. Responds 422 when fewer than two finding ids
    # are supplied, and 404 when the export service raises ValueError.
    # Documented with comments because a route docstring would be published
    # as the endpoint's OpenAPI description.
    from app.services.chain_export import export_path_markdown

    if len(request.finding_ids) < 2:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail="Path must contain at least 2 findings",
        )

    try:
        markdown = await export_path_markdown(
            db,
            user_id=user.id,
            finding_ids=request.finding_ids,
            engagement_id=request.engagement_id,
        )
    except ValueError as e:
        # Chain the original error so the root cause stays in tracebacks and
        # logs instead of appearing as an unrelated error raised during handling.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e

    return {"markdown": markdown}
Loading
Loading