Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions benchmarks/ablation_eu_ai_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""4-way ablation on the eu-ai-act-prohibited pack.

Compares macro F1 / R / P / FP rate at the pack's default threshold (1.5)
across four configurations:
A = baseline (no lexical_groups, no policy_overrides)
B = +lexical only (lexical_groups on, policy_overrides off)
C = +policy only (lexical_groups off, policy_overrides on)
D = +both (lexical_groups on, policy_overrides on)

For each config we mutate a temp copy of the pack's _ns.json, load it as
a microresolve namespace, run the 100 prohibited + 80 benign corpus
through it, and compute per-intent + macro metrics.

Pass criteria (locked before measurement):
- D ≥ A + 1pp F1
- D ≥ B + 0.5pp F1 (proves policy adds value over morph)
- No pack regresses >1pp F1 from B → D
- Benign FP rate increases ≤2pp from B → D anywhere
"""
import json
import shutil
import sys
from pathlib import Path
from collections import defaultdict

import microresolve

# Name of the pack under test; also used as the namespace name and as the
# staged directory name.
PACK_NAME = "eu-ai-act-prohibited"
# Source directory of the pack, relative to the current working directory.
PACK_SRC = Path("packs") / PACK_NAME
# Evaluation corpus; run_config reads its "prohibited" and "benign" lists.
CORPUS = Path("_internal/EU_AI_ACT_EVAL_CORPUS.json")
# Minimum score a resolved intent must reach to be counted as a hit.
TARGET_THRESHOLD = 1.5
# NOTE(review): GAP is not referenced anywhere in the visible part of this
# file — confirm whether it is still needed.
GAP = 1.5


def stage_pack(config: str, root: Path) -> Path:
    """Stage a per-config copy of the pack and strip the ablated sections.

    Copies ``PACK_SRC`` to ``<root>/<config>/<PACK_NAME>`` (replacing any
    previous staging directory for ``config``) and rewrites the copy's
    ``_ns.json``:

    - ``"baseline"`` and ``"policy_only"`` drop the ``lexical_groups`` key;
    - ``"baseline"`` and ``"lex_only"`` drop the ``policy_overrides`` key;
    - any other config name (e.g. ``"both"``) keeps both keys.

    Args:
        config: Ablation configuration name.
        root: Directory under which the per-config staging tree is created.

    Returns:
        ``<root>/<config>``, the directory that contains the staged pack.
    """
    cfg_root = root / config
    if cfg_root.exists():
        shutil.rmtree(cfg_root)
    cfg_root.mkdir(parents=True)
    dest = cfg_root / PACK_NAME
    shutil.copytree(PACK_SRC, dest)

    ns_path = dest / "_ns.json"
    # Context managers guarantee the handle is closed (and the rewrite is
    # flushed to disk) before the caller loads the staged pack.
    with open(ns_path) as fh:
        ns = json.load(fh)
    if config in ("baseline", "policy_only"):
        ns.pop("lexical_groups", None)
    if config in ("baseline", "lex_only"):
        ns.pop("policy_overrides", None)
    with open(ns_path, "w") as fh:
        json.dump(ns, fh, indent=2)
    return cfg_root


def _first_above_threshold(intents):
    """Return the first intent whose score is >= TARGET_THRESHOLD, else None."""
    return next((i for i in intents if i.score >= TARGET_THRESHOLD), None)


def run_config(config: str, root: Path, corpus: dict) -> dict:
    """Evaluate one ablation configuration against the corpus.

    Stages the pack for ``config``, loads it through ``microresolve``, resolves
    every ``corpus["prohibited"]`` and ``corpus["benign"]`` query, and derives
    per-intent and macro precision / recall / F1 plus the benign
    false-positive rate.

    A query "hits" the first entry of ``result.intents`` whose score reaches
    ``TARGET_THRESHOLD``.
    NOTE(review): this presumes ``result.intents`` is ordered best-first —
    confirm against the microresolve API.

    Args:
        config: Ablation configuration name, forwarded to ``stage_pack``.
        root: Staging root directory, forwarded to ``stage_pack``.
        corpus: Mapping with ``"prohibited"`` entries (each having ``"text"``
            and ``"expected_intent"``) and ``"benign"`` entries (each having
            ``"text"``).

    Returns:
        Dict with keys ``config``, ``macro_p``, ``macro_r``, ``macro_f1``,
        ``benign_fp_rate``, ``benign_hits``, ``benign_total`` and
        ``per_intent``. Macro values are 0.0 when the namespace exposes no
        intent other than ``legitimate_use``.
    """
    cfg_root = stage_pack(config, root)
    engine = microresolve.MicroResolve(data_dir=str(cfg_root))
    ns = engine.namespace(PACK_NAME)

    # Resolve every query, score it against ground truth
    per_intent = defaultdict(lambda: {"tp": 0, "fn": 0, "fp": 0, "tn": 0})
    intent_ids = ns.intent_ids()
    benign_hits = 0
    benign_total = 0

    for entry in corpus["prohibited"]:
        gt = entry["expected_intent"]
        top = _first_above_threshold(ns.resolve(entry["text"]).intents)
        # One-vs-rest bookkeeping: every intent gets exactly one of
        # tp / fn / fp / tn for this query.
        for iid in intent_ids:
            is_gt = iid == gt
            is_hit = top is not None and top.id == iid
            if is_gt and is_hit:
                per_intent[iid]["tp"] += 1
            elif is_gt:
                per_intent[iid]["fn"] += 1
            elif is_hit:
                per_intent[iid]["fp"] += 1
            else:
                per_intent[iid]["tn"] += 1

    for entry in corpus["benign"]:
        top = _first_above_threshold(ns.resolve(entry["text"]).intents)
        benign_total += 1
        # Any above-threshold hit other than "legitimate_use" on a benign
        # query is a false positive for the intent that fired.
        if top is not None and top.id != "legitimate_use":
            benign_hits += 1
            if top.id in intent_ids:
                per_intent[top.id]["fp"] += 1

    # Compute macros ("legitimate_use" is excluded from the averages)
    metrics = {}
    p_sum = r_sum = f_sum = 0.0
    n = 0
    for iid in intent_ids:
        if iid == "legitimate_use":
            continue
        m = per_intent[iid]
        tp, fn, fp = m["tp"], m["fn"], m["fp"]
        p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        metrics[iid] = {"p": p, "r": r, "f1": f, "tp": tp, "fn": fn, "fp": fp}
        p_sum += p
        r_sum += r
        f_sum += f
        n += 1

    benign_fp_rate = benign_hits / benign_total if benign_total > 0 else 0.0

    return {
        "config": config,
        # Guard against a namespace with no scored intents instead of
        # raising ZeroDivisionError.
        "macro_p": p_sum / n if n > 0 else 0.0,
        "macro_r": r_sum / n if n > 0 else 0.0,
        "macro_f1": f_sum / n if n > 0 else 0.0,
        "benign_fp_rate": benign_fp_rate,
        "benign_hits": benign_hits,
        "benign_total": benign_total,
        "per_intent": metrics,
    }


def main() -> int:
    """Run the four ablation configs, persist results, and check pass criteria.

    Loads the corpus from ``CORPUS``, recreates the ``/tmp/ablation`` staging
    root, evaluates ``baseline``, ``lex_only``, ``policy_only`` and ``both``
    via ``run_config``, writes all results to
    ``benchmarks/results/ablation_eu_ai_act.json``, and prints a summary table
    followed by the four pass/fail criteria.

    Returns:
        0 when every criterion passes, 1 otherwise (used as the exit status).
    """
    with open(CORPUS) as fh:
        corpus = json.load(fh)
    root = Path("/tmp/ablation")
    if root.exists():
        shutil.rmtree(root)
    root.mkdir(parents=True)

    configs = ["baseline", "lex_only", "policy_only", "both"]
    results = {}
    for c in configs:
        print(f"--- {c} ---", flush=True)
        results[c] = run_config(c, root, corpus)
        r = results[c]
        print(
            f" macro: P={r['macro_p']:.3f} R={r['macro_r']:.3f} F1={r['macro_f1']:.3f}"
            f" benign-FP={r['benign_fp_rate']:.3f} ({r['benign_hits']}/{r['benign_total']})"
        )

    out = Path("benchmarks/results/ablation_eu_ai_act.json")
    # parents=True so a missing benchmarks/ directory does not abort the run
    # after all measurements have already been taken.
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "w") as fh:
        json.dump(results, fh, indent=2)

    print()
    print("=" * 72)
    print("Summary config F1 ΔF1 R ΔR P benign-FP")
    print("=" * 72)
    base = results["baseline"]
    for c in configs:
        r = results[c]
        # Deltas versus baseline, in percentage points.
        d_f1 = (r["macro_f1"] - base["macro_f1"]) * 100
        d_r = (r["macro_r"] - base["macro_r"]) * 100
        print(
            f" {c:13s} {r['macro_f1']:.3f} {d_f1:+5.1f}pp {r['macro_r']:.3f} {d_r:+5.1f}pp {r['macro_p']:.3f} {r['benign_fp_rate']:.3f}"
        )

    print()
    print("Pass criteria check:")
    lex = results["lex_only"]
    both = results["both"]
    gain_vs_base = both["macro_f1"] - base["macro_f1"]
    gain_vs_lex = both["macro_f1"] - lex["macro_f1"]
    fp_increase = both["benign_fp_rate"] - lex["benign_fp_rate"]
    # The comparisons are inclusive (>=), although the printed labels read ">".
    crit1 = gain_vs_base >= 0.01
    crit2 = gain_vs_lex >= 0.005
    crit3 = (lex["macro_f1"] - both["macro_f1"]) <= 0.01
    crit4 = fp_increase <= 0.02
    print(f" D > A + 1pp F1? {crit1} ({gain_vs_base*100:+.2f}pp)")
    print(f" D > B + 0.5pp F1? {crit2} ({gain_vs_lex*100:+.2f}pp)")
    print(f" D - B regression ≤ 1pp F1? {crit3}")
    print(f" D - B benign-FP ≤ 2pp? {crit4} ({fp_increase*100:+.2f}pp)")
    all_pass = crit1 and crit2 and crit3 and crit4
    print(f"\n OVERALL: {'PASS — ship combined' if all_pass else 'KILL — strip policy_overrides'}")

    print()
    print(f"Full results written to {out}")
    return 0 if all_pass else 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
2 changes: 2 additions & 0 deletions src/bin/server/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod routes_intents;
mod routes_lexical;
mod routes_logs;
mod routes_phrases;
mod routes_policy_overrides;
mod routes_projects;
mod routes_review;
mod routes_settings;
Expand Down Expand Up @@ -378,6 +379,7 @@ async fn main() {
// here.
let protected_api = axum::Router::new()
.merge(routes_core::routes())
.merge(routes_policy_overrides::routes())
.merge(routes_intents::routes())
.merge(routes_lexical::routes())
.merge(routes_logs::routes())
Expand Down
100 changes: 96 additions & 4 deletions src/bin/server/routes_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,17 @@ pub async fn resolve(

// Audit: record the no-match decision too (compliance buyers
// need to see "the system saw this query and declined to fire").
audit_resolve(&state, &kid, &app_id, &req.query, &[], 0.0, latency_us);
let no_match_trace = opt_trace.as_ref().map(build_compact_audit_trace);
audit_resolve(
&state,
&kid,
&app_id,
&req.query,
&[],
0.0,
latency_us,
no_match_trace,
);

let trace_val = opt_trace.map(|t| build_trace_json(&t));
let mut resp = serde_json::json!({
Expand Down Expand Up @@ -202,8 +212,20 @@ pub async fn resolve(
}

// ── Audit log: tamper-evident decision record ────────────────────
// When the caller asked for a trace, embed a compact summary in the
// audit payload too — this is what makes Art. 13 interpretive
// transparency real (you can defend not just "we routed" but "we
// routed because tokens X, Y, Z").
let compact_trace = opt_trace.as_ref().map(build_compact_audit_trace);
audit_resolve(
&state, &kid, &app_id, &req.query, &intents, threshold, latency_us,
&state,
&kid,
&app_id,
&req.query,
&intents,
threshold,
latency_us,
compact_trace,
);

let mut resp = serde_json::json!({
Expand All @@ -224,7 +246,11 @@ pub async fn resolve(
/// shapes the payload and serializes the chain write. The query is
/// stored as a SHA-256 hash (PII-friendly) — auditors can verify
/// "decision X happened for query Y" by hashing Y and looking it up,
/// without the operator retaining raw queries.
/// without the operator retaining raw queries. When `compact_trace`
/// is supplied, it lands inside the payload — surfaces *why* a routing
/// happened, not just *that* it happened (Art. 13 interpretive
/// transparency in the audit chain).
#[allow(clippy::too_many_arguments)]
fn audit_resolve(
state: &AppState,
kid: &str,
Expand All @@ -233,17 +259,21 @@ fn audit_resolve(
intents: &[serde_json::Value],
threshold: f32,
latency_us: u64,
compact_trace: Option<serde_json::Value>,
) {
if !state.audit_log.mode().enabled() {
return;
}
let payload = serde_json::json!({
let mut payload = serde_json::json!({
"ns": app_id,
"query_hash": hash_query(query),
"intents": intents,
"threshold_applied": threshold,
"latency_us": latency_us,
});
if let Some(t) = compact_trace {
payload["trace"] = t;
}
state.audit_log.record(kid, app_id, "resolve", payload);
}

Expand All @@ -259,5 +289,67 @@ fn build_trace_json(t: &microresolve::ResolveTrace) -> serde_json::Value {
},
"negated": t.negated,
"threshold_applied": t.threshold_applied,
"per_token": t.per_token.iter().map(|c| serde_json::json!({
"token": c.token,
"intent": c.intent,
"weight": (c.weight * 1000.0).round() / 1000.0,
"idf": (c.idf * 1000.0).round() / 1000.0,
"delta": (c.delta * 1000.0).round() / 1000.0,
"negated": c.negated,
})).collect::<Vec<_>>(),
"per_intent": t.per_intent.iter().map(|s| serde_json::json!({
"intent": s.intent,
"raw_score": (s.raw_score * 100.0).round() / 100.0,
"voting_tokens": s.voting_tokens,
"voting_multiplier": (s.voting_multiplier * 100.0).round() / 100.0,
"policy_overrides_bonus": (s.policy_overrides_bonus * 100.0).round() / 100.0,
"policy_overrides_fired": s.policy_overrides_fired,
})).collect::<Vec<_>>(),
"explanation": t.explanation,
})
}

/// Builds the compact trace summary embedded in resolve audit-log entries.
///
/// The returned JSON object has three keys:
/// - `top_intents`: the first three entries of `t.per_intent` (intent,
///   `raw_score` rounded to 2 decimals, voting tokens, and the policy
///   overrides that fired). NOTE(review): this presumes `per_intent` is
///   already ordered best-first — confirm against the scorer.
/// - `top_tokens`: the five token contributions with the largest absolute
///   `delta` (token, intent, `delta` rounded to 3 decimals).
/// - `explanation`: `t.explanation`, unchanged.
///
/// Designed to be small enough to live inside every resolve audit event
/// without bloating the chain. The full trace stays in the API response only
/// when requested.
fn build_compact_audit_trace(t: &microresolve::ResolveTrace) -> serde_json::Value {
    const MAX_INTENTS: usize = 3;
    const MAX_TOKENS: usize = 5;

    let top_intents: Vec<serde_json::Value> = t
        .per_intent
        .iter()
        .take(MAX_INTENTS)
        .map(|s| {
            serde_json::json!({
                "intent": s.intent,
                "raw_score": (s.raw_score * 100.0).round() / 100.0,
                "voting_tokens": s.voting_tokens,
                "policy_overrides_fired": s.policy_overrides_fired,
            })
        })
        .collect();

    // Highest-impact contributions first. `total_cmp` is a total order even
    // when a delta is NaN; a `partial_cmp(..).unwrap_or(Equal)` comparator is
    // not, and `sort_by` may panic when given a non-total order. The sort is
    // stable, so contributions with equal |delta| keep their original order.
    let mut by_impact: Vec<&microresolve::scoring::TokenContribution> =
        t.per_token.iter().collect();
    by_impact.sort_by(|a, b| b.delta.abs().total_cmp(&a.delta.abs()));

    let top_tokens: Vec<serde_json::Value> = by_impact
        .iter()
        .take(MAX_TOKENS)
        .map(|c| {
            serde_json::json!({
                "token": c.token,
                "intent": c.intent,
                "delta": (c.delta * 1000.0).round() / 1000.0,
            })
        })
        .collect();

    serde_json::json!({
        "top_intents": top_intents,
        "top_tokens": top_tokens,
        "explanation": t.explanation,
    })
}
Loading
Loading