From eb4aa06f25212aab35e90af62c8051c4b624bb45 Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 14:13:22 +0530 Subject: [PATCH 1/5] Add LoCoMo and BEAM benchmark harnesses --- .gitignore | 18 +++ benchmarks/README.md | 6 +- benchmarks/beam/README.md | 53 ++++++++ benchmarks/beam/__init__.py | 1 + benchmarks/beam/config.py | 63 ++++++++++ benchmarks/beam/dataset.py | 223 ++++++++++++++++++++++++++++++++++ benchmarks/beam/run.py | 111 +++++++++++++++++ benchmarks/beam/runner.py | 175 ++++++++++++++++++++++++++ benchmarks/common/__init__.py | 1 + benchmarks/common/io.py | 86 +++++++++++++ benchmarks/common/metrics.py | 86 +++++++++++++ benchmarks/common/xmem.py | 126 +++++++++++++++++++ benchmarks/locomo/README.md | 36 ++++++ benchmarks/locomo/__init__.py | 1 + benchmarks/locomo/config.py | 51 ++++++++ benchmarks/locomo/dataset.py | 199 ++++++++++++++++++++++++++++++ benchmarks/locomo/run.py | 100 +++++++++++++++ benchmarks/locomo/runner.py | 172 ++++++++++++++++++++++++++ 18 files changed, 1506 insertions(+), 2 deletions(-) create mode 100644 benchmarks/beam/README.md create mode 100644 benchmarks/beam/__init__.py create mode 100644 benchmarks/beam/config.py create mode 100644 benchmarks/beam/dataset.py create mode 100644 benchmarks/beam/run.py create mode 100644 benchmarks/beam/runner.py create mode 100644 benchmarks/common/__init__.py create mode 100644 benchmarks/common/io.py create mode 100644 benchmarks/common/metrics.py create mode 100644 benchmarks/common/xmem.py create mode 100644 benchmarks/locomo/README.md create mode 100644 benchmarks/locomo/__init__.py create mode 100644 benchmarks/locomo/config.py create mode 100644 benchmarks/locomo/dataset.py create mode 100644 benchmarks/locomo/run.py create mode 100644 benchmarks/locomo/runner.py diff --git a/.gitignore b/.gitignore index 30ef4e6..cdebef9 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,24 @@ benchmarks/longmemeval/**/*.pyc benchmarks/longmemeval/data/ benchmarks/longmemeval/results/ benchmarks/longmemeval/outputs/ +!benchmarks/common/ +!benchmarks/common/** +benchmarks/common/**/__pycache__/ +benchmarks/common/**/*.pyc +!benchmarks/locomo/ +!benchmarks/locomo/** +benchmarks/locomo/**/__pycache__/ +benchmarks/locomo/**/*.pyc +benchmarks/locomo/data/ +benchmarks/locomo/results/ +benchmarks/locomo/outputs/ +!benchmarks/beam/ +!benchmarks/beam/** +benchmarks/beam/**/__pycache__/ +benchmarks/beam/**/*.pyc +benchmarks/beam/data/ +benchmarks/beam/results/ +benchmarks/beam/outputs/ backboard/ rust/ diff --git a/benchmarks/README.md b/benchmarks/README.md index 4408368..6c3b3f1 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -3,7 +3,9 @@ This directory contains benchmark harnesses for XMem. - `longmemeval/`: Python-only LongMemEval benchmark runner targeting the XMem HTTP API. +- `locomo/`: Python-only LoCoMo benchmark runner targeting the XMem HTTP API. +- `beam/`: Python-only BEAM runner, defaulting to the Hugging Face BEAM 1M split. Benchmark runs can create large dataset and result artifacts. Keep those files under -`benchmarks/longmemeval/data`, `benchmarks/longmemeval/results`, or -`benchmarks/longmemeval/outputs`; those paths are intentionally ignored by git. +each benchmark's `data`, `results`, or `outputs` directory; those paths are +intentionally ignored by git. diff --git a/benchmarks/beam/README.md b/benchmarks/beam/README.md new file mode 100644 index 0000000..ac7d4d1 --- /dev/null +++ b/benchmarks/beam/README.md @@ -0,0 +1,53 @@ +# BEAM 1M Benchmark for XMem Python + +This harness benchmarks the Python XMem API on the BEAM dataset, defaulting to +the `1M` split from `Mohammadta/BEAM` on Hugging Face. It does not run or +compare the Go implementation. + +BEAM rows contain long `chat` histories and stringified `probing_questions`. +The dataset card lists ten memory ability types: abstention, contradiction +resolution, event ordering, information extraction, instruction following, +knowledge update, multi-session reasoning, preference following, summarization, +and temporal reasoning. + +## Dependencies + +BEAM is distributed as parquet. Install `pyarrow` before reading the downloaded +dataset: + +```bash +pip install pyarrow +``` + +## Smoke Check + +```bash +python -m benchmarks.beam.run \ + --split 1M \ + --download \ + --dry-run \ + --limit 1 +``` + +This downloads the BEAM 1M parquet file, validates parsing, counts ingest items, +and does not call XMem. + +## Run Against XMem + +```bash +export XMEM_API_KEY="..." + +python -m benchmarks.beam.run \ + --split 1M \ + --dataset-path benchmarks/beam/data/1M-00000-of-00001.parquet \ + --api-base-url https://api.xmem.in \ + --output-dir benchmarks/beam/results/beam-1m +``` + +Outputs: + +- `results.jsonl`: full per-question records with local proxy metrics +- `predictions.jsonl`: `question_id` and `hypothesis` +- `summary.json`: local exact/contains/token-F1 grouped by BEAM question type + +Use BEAM's official/equivalent evaluator for publication-quality accuracy. diff --git a/benchmarks/beam/__init__.py b/benchmarks/beam/__init__.py new file mode 100644 index 0000000..9cf012e --- /dev/null +++ b/benchmarks/beam/__init__.py @@ -0,0 +1 @@ +"""BEAM benchmark harness for the Python XMem API.""" diff --git a/benchmarks/beam/config.py b/benchmarks/beam/config.py new file mode 100644 index 0000000..41f20f9 --- /dev/null +++ b/benchmarks/beam/config.py @@ -0,0 +1,63 @@ +"""Configuration for the BEAM benchmark harness.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + + +DEFAULT_API_BASE_URL = "https://api.xmem.in" +DEFAULT_API_KEY_ENV = "XMEM_API_KEY" +DEFAULT_SPLIT = "1M" +DEFAULT_DATASET_URLS = { + "100K": ( + "https://huggingface.co/datasets/Mohammadta/BEAM/resolve/main/" + "data/100K-00000-of-00001.parquet" + ), + "500K": ( + "https://huggingface.co/datasets/Mohammadta/BEAM/resolve/main/" + "data/500K-00000-of-00001.parquet" + ), + "1M": ( + "https://huggingface.co/datasets/Mohammadta/BEAM/resolve/main/" + "data/1M-00000-of-00001.parquet" + ), +} + + +@dataclass(frozen=True) +class BenchmarkConfig: + dataset_path: Path + output_dir: Path + api_base_url: str = DEFAULT_API_BASE_URL + api_key_env: str = DEFAULT_API_KEY_ENV + api_timeout_seconds: float = 120.0 + max_retries: int = 3 + retry_backoff_seconds: float = 2.0 + batch_size: int = 25 + ingest_api_version: str = "v2" + poll_interval_seconds: float = 2.0 + poll_timeout_seconds: float = 1800.0 + top_k: int = 10 + effort_level: str = "low" + user_prefix: str = "beam" + limit: int | None = None + offset: int = 0 + question_type: str | None = None + split: str = DEFAULT_SPLIT + skip_ingest: bool = False + resume: bool = True + dry_run: bool = False + + @property + def api_key(self) -> str: + return os.getenv(self.api_key_env, "").strip() + + def require_api_key(self) -> str: + api_key = self.api_key + if not api_key: + raise RuntimeError( + f"Missing API key. Set {self.api_key_env} before running BEAM." + ) + return api_key diff --git a/benchmarks/beam/dataset.py b/benchmarks/beam/dataset.py new file mode 100644 index 0000000..2a94bba --- /dev/null +++ b/benchmarks/beam/dataset.py @@ -0,0 +1,223 @@ +"""Dataset loading and normalization for BEAM.""" + +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterable + +from benchmarks.common.io import download_file, read_records + +from .config import DEFAULT_DATASET_URLS + + +@dataclass(frozen=True) +class BeamTurn: + role: str + content: str + time_anchor: str = "" + message_id: str = "" + + +@dataclass(frozen=True) +class BeamConversation: + conversation_id: str + chat_sessions: list[list[BeamTurn]] = field(default_factory=list) + + +@dataclass(frozen=True) +class BeamExample: + question_id: str + conversation_id: str + question: str + answer: str + question_type: str = "" + split: str = "1M" + chat_sessions: list[list[BeamTurn]] = field(default_factory=list) + + @property + def user_id_suffix(self) -> str: + safe = "".join( + ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in self.question_id + ) + return safe.strip("_") or "example" + + +@dataclass(frozen=True) +class IngestItem: + user_query: str + agent_response: str + user_id: str + session_datetime: str = "" + effort_level: str = "low" + + +def download_dataset(split: str, destination: Path) -> Path: + if split not in DEFAULT_DATASET_URLS: + known = ", ".join(sorted(DEFAULT_DATASET_URLS)) + raise ValueError(f"Unknown BEAM split '{split}'. Known splits: {known}") + return download_file( + DEFAULT_DATASET_URLS[split], + destination, + timeout_seconds=300.0, + ) + + +def load_examples(path: Path, *, split: str = "1M") -> list[BeamExample]: + records = read_records(path) + examples: list[BeamExample] = [] + for conv_index, record in enumerate(records): + conversation_id = str( + record.get("conversation_id") or f"conversation-{conv_index}" + ) + chat_sessions = _parse_chat(record.get("chat") or []) + questions = _parse_probing_questions(record.get("probing_questions") or []) + for q_index, question_record in enumerate(questions): + question = _first_text( + question_record, + ("question", "query", "prompt", "user_question"), + ) + if not question: + continue + answer = _first_text( + question_record, + ("answer", "gold_answer", "reference"), + ) + question_type = _first_text( + question_record, + ("question_type", "type", "ability", "category"), + ) + examples.append( + BeamExample( + question_id=str( + question_record.get("question_id") + or question_record.get("id") + or f"{conversation_id}-q-{q_index}" + ), + conversation_id=conversation_id, + question=question, + answer=answer, + question_type=question_type or "unknown", + split=split, + chat_sessions=chat_sessions, + ) + ) + return examples + + +def select_examples( + examples: Iterable[BeamExample], + *, + offset: int = 0, + limit: int | None = None, + question_type: str | None = None, +) -> list[BeamExample]: + selected = list(examples) + if question_type: + selected = [ + example + for example in selected + if example.question_type.lower() == question_type.lower() + ] + if offset: + selected = selected[offset:] + if limit is not None: + selected = selected[:limit] + return selected + + +def build_ingest_items( + example: BeamExample, + *, + user_id: str, + effort_level: str = "low", +) -> list[IngestItem]: + items: list[IngestItem] = [] + for session in example.chat_sessions: + for index, turn in enumerate(session): + next_turn = session[index + 1] if index + 1 < len(session) else None + items.append( + IngestItem( + user_query=_format_turn(turn), + agent_response=_format_turn(next_turn) if next_turn else "", + user_id=user_id, + session_datetime=turn.time_anchor, + effort_level=effort_level, + ) + ) + return items + + +def _parse_chat(raw_chat: Any) -> list[list[BeamTurn]]: + raw_chat = _coerce_literal(raw_chat) + if not isinstance(raw_chat, list): + return [] + if raw_chat and isinstance(raw_chat[0], dict): + raw_chat = [raw_chat] + + sessions: list[list[BeamTurn]] = [] + for raw_session in raw_chat: + if not isinstance(raw_session, list): + continue + turns = [_parse_turn(item) for item in raw_session] + turns = [turn for turn in turns if turn and turn.content] + if turns: + sessions.append(turns) + return sessions + + +def _parse_turn(raw_turn: Any) -> BeamTurn | None: + if not isinstance(raw_turn, dict): + return None + return BeamTurn( + role=str(raw_turn.get("role") or "message"), + content=str(raw_turn.get("content") or raw_turn.get("text") or "").strip(), + time_anchor=str(raw_turn.get("time_anchor") or ""), + message_id=str(raw_turn.get("id") or raw_turn.get("index") or ""), + ) + + +def _parse_probing_questions(raw_questions: Any) -> list[dict[str, Any]]: + raw_questions = _coerce_literal(raw_questions) + if isinstance(raw_questions, dict): + values = raw_questions.values() + return [item for item in values if isinstance(item, dict)] + if isinstance(raw_questions, list): + return [item for item in raw_questions if isinstance(item, dict)] + return [] + + +def _coerce_literal(value: Any) -> Any: + if not isinstance(value, str): + return value + text = value.strip() + if not text: + return [] + try: + return ast.literal_eval(text) + except (SyntaxError, ValueError): + try: + return json.loads(text) + except json.JSONDecodeError: + return value + + +def _first_text(record: dict[str, Any], keys: tuple[str, ...]) -> str: + for key in keys: + value = record.get(key) + if value is not None: + return str(value).strip() + return "" + + +def _format_turn(turn: BeamTurn | None) -> str: + if turn is None: + return "" + prefix = turn.role + if turn.message_id: + prefix = f"{prefix} ({turn.message_id})" + if turn.time_anchor: + prefix = f"{prefix} [{turn.time_anchor}]" + return f"{prefix}: {turn.content}" diff --git a/benchmarks/beam/run.py b/benchmarks/beam/run.py new file mode 100644 index 0000000..a752e63 --- /dev/null +++ b/benchmarks/beam/run.py @@ -0,0 +1,111 @@ +"""Command line entrypoint for the BEAM benchmark.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +from pathlib import Path + +from .config import ( + DEFAULT_API_BASE_URL, + DEFAULT_API_KEY_ENV, + DEFAULT_SPLIT, + BenchmarkConfig, +) +from .dataset import download_dataset +from .runner import BeamRunner + + +def main() -> None: + try: + args = parse_args() + dataset_path = prepare_dataset(args) + config = BenchmarkConfig( + dataset_path=dataset_path, + output_dir=args.output_dir, + api_base_url=args.api_base_url, + api_key_env=args.api_key_env, + api_timeout_seconds=args.api_timeout_seconds, + max_retries=args.max_retries, + retry_backoff_seconds=args.retry_backoff_seconds, + batch_size=args.batch_size, + ingest_api_version=args.ingest_api_version, + poll_interval_seconds=args.poll_interval_seconds, + poll_timeout_seconds=args.poll_timeout_seconds, + top_k=args.top_k, + effort_level=args.effort_level, + user_prefix=args.user_prefix, + limit=args.limit, + offset=args.offset, + question_type=args.question_type, + split=args.split, + skip_ingest=args.skip_ingest, + resume=not args.no_resume, + dry_run=args.dry_run, + ) + print(json.dumps(asyncio.run(BeamRunner(config).run()), indent=2)) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + raise SystemExit(1) from exc + + +def prepare_dataset(args: argparse.Namespace) -> Path: + if args.download: + try: + return download_dataset(args.split, args.dataset_path) + except Exception as exc: + raise RuntimeError( + "Failed to download BEAM. Check network access, or pass " + f"--dataset-path to a local file. Details: {exc}" + ) from exc + if not args.dataset_path.exists(): + raise FileNotFoundError( + f"Dataset file not found: {args.dataset_path}. " + "Run with --download, or pass --dataset-path." + ) + return args.dataset_path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run BEAM against XMem.") + parser.add_argument( + "--split", + choices=("100K", "500K", "1M"), + default=DEFAULT_SPLIT, + ) + parser.add_argument( + "--dataset-path", + type=Path, + default=Path("benchmarks/beam/data/1M-00000-of-00001.parquet"), + ) + parser.add_argument("--download", action="store_true") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("benchmarks/beam/results/latest"), + ) + parser.add_argument("--api-base-url", default=DEFAULT_API_BASE_URL) + parser.add_argument("--api-key-env", default=DEFAULT_API_KEY_ENV) + parser.add_argument("--api-timeout-seconds", type=float, default=120.0) + parser.add_argument("--max-retries", type=int, default=3) + parser.add_argument("--retry-backoff-seconds", type=float, default=2.0) + parser.add_argument("--batch-size", type=int, default=25) + parser.add_argument("--ingest-api-version", choices=("v1", "v2"), default="v2") + parser.add_argument("--poll-interval-seconds", type=float, default=2.0) + parser.add_argument("--poll-timeout-seconds", type=float, default=1800.0) + parser.add_argument("--top-k", type=int, default=10) + parser.add_argument("--effort-level", choices=("low", "high"), default="low") + parser.add_argument("--user-prefix", default="beam") + parser.add_argument("--limit", type=int) + parser.add_argument("--offset", type=int, default=0) + parser.add_argument("--question-type") + parser.add_argument("--skip-ingest", action="store_true") + parser.add_argument("--no-resume", action="store_true") + parser.add_argument("--dry-run", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/beam/runner.py b/benchmarks/beam/runner.py new file mode 100644 index 0000000..e9416e8 --- /dev/null +++ b/benchmarks/beam/runner.py @@ -0,0 +1,175 @@ +"""BEAM benchmark orchestration against the Python XMem API.""" + +from __future__ import annotations + +import time +from typing import Any + +from benchmarks.common.io import append_jsonl, read_jsonl, write_json +from benchmarks.common.metrics import score_answer, summarize_results +from benchmarks.common.xmem import XMemApiClient + +from .config import BenchmarkConfig +from .dataset import BeamExample, build_ingest_items, load_examples, select_examples + + +class BeamRunner: + def __init__(self, config: BenchmarkConfig) -> None: + self.config = config + self.results_path = config.output_dir / "results.jsonl" + self.predictions_path = config.output_dir / "predictions.jsonl" + self.summary_path = config.output_dir / "summary.json" + + async def run(self) -> dict[str, Any]: + examples = select_examples( + load_examples(self.config.dataset_path, split=self.config.split), + offset=self.config.offset, + limit=self.config.limit, + question_type=self.config.question_type, + ) + if self.config.dry_run: + return self._dry_run_summary(examples) + + completed_ids = self._completed_question_ids() if self.config.resume else set() + run_started = time.time() + async with XMemApiClient( + base_url=self.config.api_base_url, + api_key=self.config.require_api_key(), + timeout_seconds=self.config.api_timeout_seconds, + max_retries=self.config.max_retries, + retry_backoff_seconds=self.config.retry_backoff_seconds, + ) as client: + for index, example in enumerate(examples, start=1): + if example.question_id in completed_ids: + continue + result = await self._run_example( + client, + example, + index=index, + total=len(examples), + ) + append_jsonl(self.results_path, result) + append_jsonl( + self.predictions_path, + { + "question_id": result["question_id"], + "hypothesis": result["prediction"], + }, + ) + + results = read_jsonl(self.results_path) + summary = summarize_results(results, group_field="question_type") + summary["dataset_path"] = str(self.config.dataset_path) + summary["api_base_url"] = self.config.api_base_url + summary["split"] = self.config.split + summary["duration_seconds"] = round(time.time() - run_started, 2) + write_json(self.summary_path, summary) + return summary + + async def _run_example( + self, + client: XMemApiClient, + example: BeamExample, + *, + index: int, + total: int, + ) -> dict[str, Any]: + user_id = f"{self.config.user_prefix}-{example.user_id_suffix}" + ingest_count = 0 + ingest_elapsed_ms = 0.0 + if not self.config.skip_ingest: + items = build_ingest_items( + example, + user_id=user_id, + effort_level=self.config.effort_level, + ) + ingest_count = len(items) + ingest_elapsed_ms = await self._ingest_items(client, items) + + retrieve = await client.retrieve( + {"query": example.question, "user_id": user_id, "top_k": self.config.top_k} + ) + prediction = str(retrieve.data.get("answer") or "") + result = { + "question_id": example.question_id, + "conversation_id": example.conversation_id, + "question_type": example.question_type or "unknown", + "split": example.split, + "question": example.question, + "reference_answer": example.answer, + "prediction": prediction, + "metrics": score_answer(prediction, example.answer), + "source_count": len(retrieve.data.get("sources") or []), + "confidence": retrieve.data.get("confidence"), + "user_id": user_id, + "ingest_count": ingest_count, + "ingest_elapsed_ms": round(ingest_elapsed_ms, 2), + "retrieve_elapsed_ms": retrieve.elapsed_ms, + "index": index, + "total": total, + } + print( + f"[{index}/{total}] {example.question_id}: " + f"f1={result['metrics']['token_f1']} retrieve_ms={retrieve.elapsed_ms}", + flush=True, + ) + return result + + async def _ingest_items(self, client: XMemApiClient, items: list[Any]) -> float: + if not items: + return 0.0 + elapsed_ms = 0.0 + processed = 0 + for start in range(0, len(items), self.config.batch_size): + chunk = items[start : start + self.config.batch_size] + payload = [item.__dict__ for item in chunk] + if self.config.ingest_api_version == "v1": + result = await client.batch_ingest_v1(payload) + elapsed_ms += result.elapsed_ms + else: + accepted = await client.batch_ingest_v2(payload) + elapsed_ms += accepted.elapsed_ms + status_url = str(accepted.data.get("status_url") or "") + if not status_url: + raise RuntimeError("XMem v2 batch ingest missing status_url") + status = await client.poll_job( + status_url, + interval_seconds=self.config.poll_interval_seconds, + timeout_seconds=self.config.poll_timeout_seconds, + ) + elapsed_ms += status.elapsed_ms + if str(status.data.get("status") or "").lower() != "succeeded": + raise RuntimeError(f"XMem batch ingest failed: {status.data}") + processed += len(chunk) + print(f"[INGEST] processed={processed}/{len(items)}", flush=True) + return elapsed_ms + + def _completed_question_ids(self) -> set[str]: + return {str(row.get("question_id")) for row in read_jsonl(self.results_path)} + + def _dry_run_summary(self, examples: list[BeamExample]) -> dict[str, Any]: + ingest_counts = [ + len( + build_ingest_items( + example, + user_id="dry-run", + effort_level=self.config.effort_level, + ) + ) + for example in examples + ] + summary = { + "dry_run": True, + "dataset_path": str(self.config.dataset_path), + "split": self.config.split, + "selected_examples": len(examples), + "total_ingest_items": sum(ingest_counts), + "min_ingest_items": min(ingest_counts) if ingest_counts else 0, + "max_ingest_items": max(ingest_counts) if ingest_counts else 0, + "question_types": sorted( + {example.question_type or "unknown" for example in examples} + ), + "conversations": sorted({example.conversation_id for example in examples}), + } + write_json(self.summary_path, summary) + return summary diff --git a/benchmarks/common/__init__.py b/benchmarks/common/__init__.py new file mode 100644 index 0000000..9a4b93c --- /dev/null +++ b/benchmarks/common/__init__.py @@ -0,0 +1 @@ +"""Shared helpers for XMem benchmark harnesses.""" diff --git a/benchmarks/common/io.py b/benchmarks/common/io.py new file mode 100644 index 0000000..dc6b97f --- /dev/null +++ b/benchmarks/common/io.py @@ -0,0 +1,86 @@ +"""Shared benchmark file and download helpers.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import httpx + + +def download_file( + url: str, + destination: Path, + *, + timeout_seconds: float = 120.0, +) -> Path: + destination.parent.mkdir(parents=True, exist_ok=True) + with httpx.stream( + "GET", + url, + follow_redirects=True, + timeout=timeout_seconds, + ) as response: + response.raise_for_status() + with destination.open("wb") as handle: + for chunk in response.iter_bytes(): + handle.write(chunk) + return destination + + +def read_records(path: Path) -> list[dict[str, Any]]: + suffix = path.suffix.lower() + if suffix == ".jsonl": + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line) for line in handle if line.strip()] + if suffix == ".json": + payload = json.loads(path.read_text(encoding="utf-8")) + if isinstance(payload, list): + return payload + if isinstance(payload, dict): + for key in ("data", "examples", "records", "questions"): + value = payload.get(key) + if isinstance(value, list): + return value + return [payload] + if suffix == ".parquet": + return read_parquet_records(path) + raise ValueError( + f"Unsupported dataset format for {path}. Expected JSON, JSONL, or Parquet." + ) + + +def read_parquet_records(path: Path) -> list[dict[str, Any]]: + try: + import pyarrow.parquet as pq + except ImportError as exc: + raise RuntimeError( + "Reading BEAM parquet files requires pyarrow. Install it with " + "`pip install pyarrow`, or convert the dataset to JSON/JSONL and " + "pass --dataset-path." + ) from exc + + table = pq.read_table(path) + return table.to_pylist() + + +def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + +def append_jsonl(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + return [] + with path.open("r", encoding="utf-8") as handle: + return [json.loads(line) for line in handle if line.strip()] diff --git a/benchmarks/common/metrics.py b/benchmarks/common/metrics.py new file mode 100644 index 0000000..707bbf7 --- /dev/null +++ b/benchmarks/common/metrics.py @@ -0,0 +1,86 @@ +"""Shared lightweight answer metrics for benchmark smoke summaries.""" + +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Any + + +def normalize_answer(text: str) -> str: + text = text.lower() + text = re.sub(r"[^a-z0-9\s]", " ", text) + text = re.sub(r"\b(a|an|the)\b", " ", text) + return " ".join(text.split()) + + +def token_f1(prediction: str, reference: str) -> float: + pred_tokens = normalize_answer(prediction).split() + ref_tokens = normalize_answer(reference).split() + if not pred_tokens or not ref_tokens: + return float(pred_tokens == ref_tokens) + common = set(pred_tokens) & set(ref_tokens) + overlap = sum( + min(pred_tokens.count(token), ref_tokens.count(token)) + for token in common + ) + if overlap == 0: + return 0.0 + precision = overlap / len(pred_tokens) + recall = overlap / len(ref_tokens) + return 2 * precision * recall / (precision + recall) + + +def score_answer(prediction: str, reference: str) -> dict[str, float | bool]: + normalized_prediction = normalize_answer(prediction) + normalized_reference = normalize_answer(reference) + return { + "exact_match": normalized_prediction == normalized_reference, + "contains": bool( + normalized_reference + and normalized_reference in normalized_prediction + ), + "token_f1": round(token_f1(prediction, reference), 4), + } + + +def summarize_results( + results: list[dict[str, Any]], + *, + group_field: str = "question_type", +) -> dict[str, Any]: + if not results: + return {"count": 0, "overall": {}, f"by_{group_field}": {}} + + buckets: dict[str, list[dict[str, Any]]] = defaultdict(list) + for result in results: + buckets[str(result.get(group_field) or "unknown")].append(result) + return { + "count": len(results), + "overall": _summarize_bucket(results), + f"by_{group_field}": { + label: _summarize_bucket(bucket) + for label, bucket in sorted(buckets.items()) + }, + } + + +def _summarize_bucket(results: list[dict[str, Any]]) -> dict[str, float | int]: + count = len(results) + exact = sum(1 for result in results if result.get("metrics", {}).get("exact_match")) + contains = sum(1 for result in results if result.get("metrics", {}).get("contains")) + f1_scores = [ + float(result.get("metrics", {}).get("token_f1") or 0.0) + for result in results + ] + avg_retrieve_ms = ( + sum(float(result.get("retrieve_elapsed_ms") or 0.0) for result in results) + / count + ) + return { + "count": count, + "exact_match": round(exact / count, 4), + "contains": round(contains / count, 4), + "token_f1": round(sum(f1_scores) / count, 4), + "avg_retrieve_ms": round(avg_retrieve_ms, 2), + } diff --git a/benchmarks/common/xmem.py b/benchmarks/common/xmem.py new file mode 100644 index 0000000..f9d19d5 --- /dev/null +++ b/benchmarks/common/xmem.py @@ -0,0 +1,126 @@ +"""Shared async HTTP client for XMem benchmark harnesses.""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from typing import Any + +import httpx + + +TERMINAL_JOB_STATUSES = {"succeeded", "dead_letter"} + + +@dataclass(frozen=True) +class ApiCallResult: + data: dict[str, Any] + elapsed_ms: float + + +class XMemApiClient: + """Small async client around the Python XMem API.""" + + def __init__( + self, + *, + base_url: str, + api_key: str, + timeout_seconds: float = 120.0, + max_retries: int = 3, + retry_backoff_seconds: float = 2.0, + ) -> None: + self.base_url = base_url.rstrip("/") + self.max_retries = max_retries + self.retry_backoff_seconds = retry_backoff_seconds + self._client = httpx.AsyncClient( + base_url=self.base_url, + timeout=httpx.Timeout(timeout_seconds), + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": "xmem-benchmark/1.0", + }, + ) + + async def __aenter__(self) -> "XMemApiClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + async def close(self) -> None: + await self._client.aclose() + + async def batch_ingest_v1(self, items: list[dict[str, Any]]) -> ApiCallResult: + return await self._post("/v1/memory/batch-ingest", {"items": items}) + + async def batch_ingest_v2(self, items: list[dict[str, Any]]) -> ApiCallResult: + return await self._post("/v2/memory/batch-ingest", {"items": items}) + + async def retrieve(self, payload: dict[str, Any]) -> ApiCallResult: + return await self._post("/v1/memory/retrieve", payload) + + async def job_status(self, status_url: str) -> ApiCallResult: + return await self._get(status_url) + + async def poll_job( + self, + status_url: str, + *, + interval_seconds: float, + timeout_seconds: float, + ) -> ApiCallResult: + deadline = time.monotonic() + timeout_seconds + last_result: ApiCallResult | None = None + while time.monotonic() < deadline: + last_result = await self.job_status(status_url) + status = str(last_result.data.get("status") or "").lower() + if status in TERMINAL_JOB_STATUSES: + return last_result + await asyncio.sleep(interval_seconds) + status = last_result.data.get("status") if last_result else "unknown" + raise TimeoutError(f"Timed out polling job {status_url}; last status={status}") + + async def _get(self, path: str) -> ApiCallResult: + return await self._request("GET", path) + + async def _post(self, path: str, payload: dict[str, Any]) -> ApiCallResult: + return await self._request("POST", path, json=payload) + + async def _request(self, method: str, path: str, **kwargs: Any) -> ApiCallResult: + request_path = self._request_path(path) + start = time.perf_counter() + response: httpx.Response | None = None + for attempt in range(self.max_retries + 1): + try: + response = await self._client.request(method, request_path, **kwargs) + if response.status_code < 500 and response.status_code != 429: + break + except httpx.HTTPError: + if attempt >= self.max_retries: + raise + if attempt < self.max_retries: + await asyncio.sleep(self.retry_backoff_seconds * (attempt + 1)) + + if response is None: + raise RuntimeError(f"No response from {method} {request_path}") + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + response.raise_for_status() + body = response.json() + if body.get("status") == "error": + error = body.get("error") or f"XMem API error from {request_path}" + raise RuntimeError(error) + data = body.get("data") + if data is None: + data = {} + if not isinstance(data, dict): + data = {"value": data} + return ApiCallResult(data=data, elapsed_ms=elapsed_ms) + + @staticmethod + def _request_path(path: str) -> str: + if path.startswith(("http://", "https://", "/")): + return path + return f"/{path}" diff --git a/benchmarks/locomo/README.md b/benchmarks/locomo/README.md new file mode 100644 index 0000000..895f74f --- /dev/null +++ b/benchmarks/locomo/README.md @@ -0,0 +1,36 @@ +# LoCoMo Benchmark for XMem Python + +This harness benchmarks the Python XMem API on LoCoMo only. It does not run or +compare the Go implementation. + +LoCoMo contains ten long conversations. Each sample includes chronological +sessions under `conversation` and annotated question-answer items under `qa` +with `question`, `answer`, `category`, and optional evidence dialog ids. + +## Smoke Check + +```bash +python -m benchmarks.locomo.run --download --dry-run --limit 2 +``` + +This downloads `locomo10.json`, validates parsing, counts ingest items, and +does not call XMem. + +## Run Against XMem + +```bash +export XMEM_API_KEY="..." + +python -m benchmarks.locomo.run \ + --dataset-path benchmarks/locomo/data/locomo10.json \ + --api-base-url https://api.xmem.in \ + --output-dir benchmarks/locomo/results/run-001 +``` + +Outputs: + +- `results.jsonl`: full per-question records with local proxy metrics +- `predictions.jsonl`: `question_id` and `hypothesis` +- `summary.json`: local exact/contains/token-F1 grouped by LoCoMo category + +Use LoCoMo's official/equivalent evaluator for publication-quality accuracy. diff --git a/benchmarks/locomo/__init__.py b/benchmarks/locomo/__init__.py new file mode 100644 index 0000000..2949211 --- /dev/null +++ b/benchmarks/locomo/__init__.py @@ -0,0 +1 @@ +"""LoCoMo benchmark harness for the Python XMem API.""" diff --git a/benchmarks/locomo/config.py b/benchmarks/locomo/config.py new file mode 100644 index 0000000..5a53ee2 --- /dev/null +++ b/benchmarks/locomo/config.py @@ -0,0 +1,51 @@ +"""Configuration for the LoCoMo benchmark harness.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + + +DEFAULT_API_BASE_URL = "https://api.xmem.in" +DEFAULT_API_KEY_ENV = "XMEM_API_KEY" +DEFAULT_DATASET_URL = ( + "https://raw.githubusercontent.com/snap-research/locomo/main/" + "data/locomo10.json" +) + + +@dataclass(frozen=True) +class BenchmarkConfig: + dataset_path: Path + output_dir: Path + api_base_url: str = DEFAULT_API_BASE_URL + api_key_env: str = DEFAULT_API_KEY_ENV + api_timeout_seconds: float = 120.0 + max_retries: int = 3 + retry_backoff_seconds: float = 2.0 + batch_size: int = 25 + ingest_api_version: str = "v2" + poll_interval_seconds: float = 2.0 + poll_timeout_seconds: float = 1800.0 + top_k: int = 10 + effort_level: str = "low" + user_prefix: str = "locomo" + limit: int | None = None + offset: int = 0 + category: str | None = None + skip_ingest: bool = False + resume: bool = True + dry_run: bool = False + + @property + def api_key(self) -> str: + return os.getenv(self.api_key_env, "").strip() + + def require_api_key(self) -> str: + api_key = self.api_key + if not api_key: + raise RuntimeError( + f"Missing API key. Set {self.api_key_env} before running LoCoMo." + ) + return api_key diff --git a/benchmarks/locomo/dataset.py b/benchmarks/locomo/dataset.py new file mode 100644 index 0000000..23b8975 --- /dev/null +++ b/benchmarks/locomo/dataset.py @@ -0,0 +1,199 @@ +"""Dataset loading and normalization for LoCoMo.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterable + +from benchmarks.common.io import download_file, read_records + +from .config import DEFAULT_DATASET_URL + + +@dataclass(frozen=True) +class ConversationTurn: + speaker: str + content: str + dialog_id: str = "" + + +@dataclass(frozen=True) +class ConversationSession: + session_id: str + date: str = "" + turns: list[ConversationTurn] = field(default_factory=list) + + +@dataclass(frozen=True) +class LoCoMoExample: + question_id: str + sample_id: str + question: str + answer: str + category: str = "" + evidence: list[str] = field(default_factory=list) + sessions: list[ConversationSession] = field(default_factory=list) + + @property + def user_id_suffix(self) -> str: + safe = "".join( + ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in self.question_id + ) + return safe.strip("_") or "example" + + +@dataclass(frozen=True) +class IngestItem: + user_query: str + agent_response: str + user_id: str + session_datetime: str = "" + effort_level: str = "low" + + +def download_dataset(destination: Path) -> Path: + return download_file(DEFAULT_DATASET_URL, destination) + + +def load_examples(path: Path) -> list[LoCoMoExample]: + records = read_records(path) + examples: list[LoCoMoExample] = [] + for sample_index, record in enumerate(records): + sample_id = str(record.get("sample_id") or f"sample-{sample_index}") + sessions = _parse_sessions(record.get("conversation") or {}) + qa_items = record.get("qa") or [] + if isinstance(qa_items, dict): + qa_items = list(qa_items.values()) + for qa_index, qa in enumerate(qa_items): + if not isinstance(qa, dict): + continue + question = str(qa.get("question") or "").strip() + if not question: + continue + examples.append( + LoCoMoExample( + question_id=str( + qa.get("question_id") + or qa.get("id") + or f"{sample_id}-qa-{qa_index}" + ), + sample_id=sample_id, + question=question, + answer=str(qa.get("answer") or "").strip(), + category=str(qa.get("category") or "unknown").strip(), + evidence=[str(item) for item in qa.get("evidence") or []], + sessions=sessions, + ) + ) + return examples + + +def select_examples( + examples: Iterable[LoCoMoExample], + *, + offset: int = 0, + limit: int | None = None, + category: str | None = None, +) -> list[LoCoMoExample]: + selected = list(examples) + if category: + selected = [ + example + for example in selected + if example.category.lower() == category.lower() + ] + if offset: + selected = selected[offset:] + if limit is not None: + selected = selected[:limit] + return selected + + +def build_ingest_items( + example: LoCoMoExample, + *, + user_id: str, + effort_level: str = "low", +) -> list[IngestItem]: + items: list[IngestItem] = [] + for session in example.sessions: + turns = session.turns + for index, turn in enumerate(turns): + next_turn = turns[index + 1] if index + 1 < len(turns) else None + user_query = _format_turn(turn) + agent_response = _format_turn(next_turn) if next_turn else "" + items.append( + IngestItem( + user_query=user_query, + agent_response=agent_response, + user_id=user_id, + session_datetime=session.date, + effort_level=effort_level, + ) + ) + return items + + +def _parse_sessions(conversation: Any) -> list[ConversationSession]: + if not isinstance(conversation, dict): + return [] + + session_numbers = sorted( + _session_number(key) + for key, value in conversation.items() + if key.startswith("session_") + and not key.endswith("_date_time") + and isinstance(value, list) + ) + sessions: list[ConversationSession] = [] + for number in session_numbers: + session_key = f"session_{number}" + date_key = f"session_{number}_date_time" + turns = _parse_turns(conversation.get(session_key) or []) + if turns: + sessions.append( + ConversationSession( + session_id=session_key, + date=str(conversation.get(date_key) or ""), + turns=turns, + ) + ) + return sessions + + +def _parse_turns(raw_turns: list[Any]) -> list[ConversationTurn]: + turns: list[ConversationTurn] = [] + for raw_turn in raw_turns: + if not isinstance(raw_turn, dict): + continue + content = str(raw_turn.get("text") or raw_turn.get("content") or "").strip() + caption = str(raw_turn.get("blip_caption") or "").strip() + if caption: + content = f"{content}\nImage caption: {caption}".strip() + if not content: + continue + turns.append( + ConversationTurn( + speaker=str(raw_turn.get("speaker") or "speaker").strip(), + content=content, + dialog_id=str(raw_turn.get("dia_id") or raw_turn.get("id") or ""), + ) + ) + return turns + + +def _format_turn(turn: ConversationTurn | None) -> str: + if turn is None: + return "" + prefix = f"{turn.speaker}" + if turn.dialog_id: + prefix = f"{prefix} ({turn.dialog_id})" + return f"{prefix}: {turn.content}" + + +def _session_number(key: str) -> int: + try: + return int(key.split("_", 1)[1]) + except (IndexError, ValueError): + return 0 diff --git a/benchmarks/locomo/run.py b/benchmarks/locomo/run.py new file mode 100644 index 0000000..dad12af --- /dev/null +++ b/benchmarks/locomo/run.py @@ -0,0 +1,100 @@ +"""Command line entrypoint for the LoCoMo benchmark.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +from pathlib import Path + +from .config import DEFAULT_API_BASE_URL, DEFAULT_API_KEY_ENV, BenchmarkConfig +from .dataset import download_dataset +from .runner import LoCoMoRunner + + +def main() -> None: + try: + args = parse_args() + dataset_path = prepare_dataset(args) + config = BenchmarkConfig( + dataset_path=dataset_path, + output_dir=args.output_dir, + api_base_url=args.api_base_url, + api_key_env=args.api_key_env, + api_timeout_seconds=args.api_timeout_seconds, + max_retries=args.max_retries, + retry_backoff_seconds=args.retry_backoff_seconds, + batch_size=args.batch_size, + ingest_api_version=args.ingest_api_version, + poll_interval_seconds=args.poll_interval_seconds, + poll_timeout_seconds=args.poll_timeout_seconds, + top_k=args.top_k, + effort_level=args.effort_level, + user_prefix=args.user_prefix, + limit=args.limit, + offset=args.offset, + category=args.category, + skip_ingest=args.skip_ingest, + resume=not args.no_resume, + dry_run=args.dry_run, + ) + print(json.dumps(asyncio.run(LoCoMoRunner(config).run()), indent=2)) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + raise SystemExit(1) from exc + + +def prepare_dataset(args: argparse.Namespace) -> Path: + if args.download: + try: + return download_dataset(args.dataset_path) + except Exception as exc: + raise RuntimeError( + "Failed to download LoCoMo. Check network access, or pass " + f"--dataset-path to a local file. Details: {exc}" + ) from exc + if not args.dataset_path.exists(): + raise FileNotFoundError( + f"Dataset file not found: {args.dataset_path}. " + "Run with --download, or pass --dataset-path." + ) + return args.dataset_path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run LoCoMo against XMem.") + parser.add_argument( + "--dataset-path", + type=Path, + default=Path("benchmarks/locomo/data/locomo10.json"), + ) + parser.add_argument("--download", action="store_true") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("benchmarks/locomo/results/latest"), + ) + parser.add_argument("--api-base-url", default=DEFAULT_API_BASE_URL) + parser.add_argument("--api-key-env", default=DEFAULT_API_KEY_ENV) + parser.add_argument("--api-timeout-seconds", type=float, default=120.0) + parser.add_argument("--max-retries", type=int, default=3) + parser.add_argument("--retry-backoff-seconds", type=float, default=2.0) + parser.add_argument("--batch-size", type=int, default=25) + parser.add_argument("--ingest-api-version", choices=("v1", "v2"), default="v2") + parser.add_argument("--poll-interval-seconds", type=float, default=2.0) + parser.add_argument("--poll-timeout-seconds", type=float, default=1800.0) + parser.add_argument("--top-k", type=int, default=10) + parser.add_argument("--effort-level", choices=("low", "high"), default="low") + parser.add_argument("--user-prefix", default="locomo") + parser.add_argument("--limit", type=int) + parser.add_argument("--offset", type=int, default=0) + parser.add_argument("--category") + parser.add_argument("--skip-ingest", action="store_true") + parser.add_argument("--no-resume", action="store_true") + parser.add_argument("--dry-run", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/locomo/runner.py b/benchmarks/locomo/runner.py new file mode 100644 index 0000000..db5c625 --- /dev/null +++ b/benchmarks/locomo/runner.py @@ -0,0 +1,172 @@ +"""LoCoMo benchmark orchestration against the Python XMem API.""" + +from __future__ import annotations + +import time +from typing import Any + +from benchmarks.common.io import append_jsonl, read_jsonl, write_json +from benchmarks.common.metrics import score_answer, summarize_results +from benchmarks.common.xmem import XMemApiClient + +from .config import BenchmarkConfig +from .dataset import LoCoMoExample, build_ingest_items, load_examples, select_examples + + +class LoCoMoRunner: + def __init__(self, config: BenchmarkConfig) -> None: + self.config = config + self.results_path = config.output_dir / "results.jsonl" + self.predictions_path = config.output_dir / "predictions.jsonl" + self.summary_path = config.output_dir / "summary.json" + + async def run(self) -> dict[str, Any]: + examples = select_examples( + load_examples(self.config.dataset_path), + offset=self.config.offset, + limit=self.config.limit, + category=self.config.category, + ) + if self.config.dry_run: + return self._dry_run_summary(examples) + + completed_ids = self._completed_question_ids() if self.config.resume else set() + run_started = time.time() + async with XMemApiClient( + base_url=self.config.api_base_url, + api_key=self.config.require_api_key(), + timeout_seconds=self.config.api_timeout_seconds, + max_retries=self.config.max_retries, + retry_backoff_seconds=self.config.retry_backoff_seconds, + ) as client: + for index, example in enumerate(examples, start=1): + if example.question_id in completed_ids: + continue + result = await self._run_example( + client, + example, + index=index, + total=len(examples), + ) + append_jsonl(self.results_path, result) + append_jsonl( + self.predictions_path, + { + "question_id": result["question_id"], + "hypothesis": result["prediction"], + }, + ) + + results = read_jsonl(self.results_path) + summary = summarize_results(results, group_field="category") + summary["dataset_path"] = str(self.config.dataset_path) + summary["api_base_url"] = self.config.api_base_url + summary["duration_seconds"] = round(time.time() - run_started, 2) + write_json(self.summary_path, summary) + return summary + + async def _run_example( + self, + client: XMemApiClient, + example: LoCoMoExample, + *, + index: int, + total: int, + ) -> dict[str, Any]: + user_id = f"{self.config.user_prefix}-{example.user_id_suffix}" + ingest_count = 0 + ingest_elapsed_ms = 0.0 + if not self.config.skip_ingest: + items = build_ingest_items( + example, + user_id=user_id, + effort_level=self.config.effort_level, + ) + ingest_count = len(items) + ingest_elapsed_ms = await self._ingest_items(client, items) + + retrieve = await client.retrieve( + {"query": example.question, "user_id": user_id, "top_k": self.config.top_k} + ) + prediction = str(retrieve.data.get("answer") or "") + result = { + "question_id": example.question_id, + "sample_id": example.sample_id, + "category": example.category or "unknown", + "question": example.question, + "reference_answer": example.answer, + "prediction": prediction, + "metrics": score_answer(prediction, example.answer), + "source_count": len(retrieve.data.get("sources") or []), + "confidence": retrieve.data.get("confidence"), + "user_id": user_id, + "ingest_count": ingest_count, + "ingest_elapsed_ms": round(ingest_elapsed_ms, 2), + "retrieve_elapsed_ms": retrieve.elapsed_ms, + "index": index, + "total": total, + } + print( + f"[{index}/{total}] {example.question_id}: " + f"f1={result['metrics']['token_f1']} retrieve_ms={retrieve.elapsed_ms}", + flush=True, + ) + return result + + async def _ingest_items(self, client: XMemApiClient, items: list[Any]) -> float: + if not items: + return 0.0 + elapsed_ms = 0.0 + processed = 0 + for start in range(0, len(items), self.config.batch_size): + chunk = items[start : start + self.config.batch_size] + payload = [item.__dict__ for item in chunk] + if self.config.ingest_api_version == "v1": + result = await client.batch_ingest_v1(payload) + elapsed_ms += result.elapsed_ms + else: + accepted = await client.batch_ingest_v2(payload) + elapsed_ms += accepted.elapsed_ms + status_url = str(accepted.data.get("status_url") or "") + if not status_url: + raise RuntimeError("XMem v2 batch ingest missing status_url") + status = await client.poll_job( + status_url, + interval_seconds=self.config.poll_interval_seconds, + timeout_seconds=self.config.poll_timeout_seconds, + ) + elapsed_ms += status.elapsed_ms + if str(status.data.get("status") or "").lower() != "succeeded": + raise RuntimeError(f"XMem batch ingest failed: {status.data}") + processed += len(chunk) + print(f"[INGEST] processed={processed}/{len(items)}", flush=True) + return elapsed_ms + + def _completed_question_ids(self) -> set[str]: + return {str(row.get("question_id")) for row in read_jsonl(self.results_path)} + + def _dry_run_summary(self, examples: list[LoCoMoExample]) -> dict[str, Any]: + ingest_counts = [ + len( + build_ingest_items( + example, + user_id="dry-run", + effort_level=self.config.effort_level, + ) + ) + for example in examples + ] + summary = { + "dry_run": True, + "dataset_path": str(self.config.dataset_path), + "selected_examples": len(examples), + "total_ingest_items": sum(ingest_counts), + "min_ingest_items": min(ingest_counts) if ingest_counts else 0, + "max_ingest_items": max(ingest_counts) if ingest_counts else 0, + "categories": sorted( + {example.category or "unknown" for example in examples} + ), + "samples": sorted({example.sample_id for example in examples}), + } + write_json(self.summary_path, summary) + return summary From 843c2a03d25b54545ab3f2884cf14d3d43700ded Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 14:22:28 +0530 Subject: [PATCH 2/5] Address benchmark review feedback --- benchmarks/beam/dataset.py | 8 ++++---- benchmarks/common/io.py | 26 ++++++++++++++++---------- benchmarks/common/xmem.py | 11 ++++++++--- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/benchmarks/beam/dataset.py b/benchmarks/beam/dataset.py index 2a94bba..4939955 100644 --- a/benchmarks/beam/dataset.py +++ b/benchmarks/beam/dataset.py @@ -196,11 +196,11 @@ def _coerce_literal(value: Any) -> Any: if not text: return [] try: - return ast.literal_eval(text) - except (SyntaxError, ValueError): + return json.loads(text) + except ValueError: try: - return json.loads(text) - except json.JSONDecodeError: + return ast.literal_eval(text) + except (SyntaxError, ValueError): return value diff --git a/benchmarks/common/io.py b/benchmarks/common/io.py index dc6b97f..4bb1a25 100644 --- a/benchmarks/common/io.py +++ b/benchmarks/common/io.py @@ -16,16 +16,22 @@ def download_file( timeout_seconds: float = 120.0, ) -> Path: destination.parent.mkdir(parents=True, exist_ok=True) - with httpx.stream( - "GET", - url, - follow_redirects=True, - timeout=timeout_seconds, - ) as response: - response.raise_for_status() - with destination.open("wb") as handle: - for chunk in response.iter_bytes(): - handle.write(chunk) + partial_destination = destination.with_suffix(destination.suffix + ".tmp") + try: + with httpx.stream( + "GET", + url, + follow_redirects=True, + timeout=timeout_seconds, + ) as response: + response.raise_for_status() + with partial_destination.open("wb") as handle: + for chunk in response.iter_bytes(): + handle.write(chunk) + partial_destination.replace(destination) + except BaseException: + partial_destination.unlink(missing_ok=True) + raise return destination diff --git a/benchmarks/common/xmem.py b/benchmarks/common/xmem.py index f9d19d5..fe8e10b 100644 --- a/benchmarks/common/xmem.py +++ b/benchmarks/common/xmem.py @@ -107,11 +107,16 @@ async def _request(self, method: str, path: str, **kwargs: Any) -> ApiCallResult if response is None: raise RuntimeError(f"No response from {method} {request_path}") elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - response.raise_for_status() - body = response.json() - if body.get("status") == "error": + try: + body = response.json() + except ValueError: + body = {} + if isinstance(body, dict) and body.get("status") == "error": error = body.get("error") or f"XMem API error from {request_path}" raise RuntimeError(error) + response.raise_for_status() + if not isinstance(body, dict): + body = {} data = body.get("data") if data is None: data = {} From 553540ceadce2f65fddd3498708c448ed6b2e928 Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 14:30:34 +0530 Subject: [PATCH 3/5] Fix benchmark ingest pairing --- benchmarks/beam/dataset.py | 45 +++++++++++++++++++++++++++++---- benchmarks/common/metrics.py | 8 ++---- benchmarks/locomo/dataset.py | 49 +++++++++++++++++++++++++++++------- 3 files changed, 82 insertions(+), 20 deletions(-) diff --git a/benchmarks/beam/dataset.py b/benchmarks/beam/dataset.py index 4939955..633d422 100644 --- a/benchmarks/beam/dataset.py +++ b/benchmarks/beam/dataset.py @@ -136,14 +136,14 @@ def build_ingest_items( ) -> list[IngestItem]: items: list[IngestItem] = [] for session in example.chat_sessions: - for index, turn in enumerate(session): - next_turn = session[index + 1] if index + 1 < len(session) else None + for user_turns, agent_turn in _exchange_pairs(session): + first_user_turn = user_turns[0] items.append( IngestItem( - user_query=_format_turn(turn), - agent_response=_format_turn(next_turn) if next_turn else "", + user_query="\n".join(_format_turn(turn) for turn in user_turns), + agent_response=_format_turn(agent_turn), user_id=user_id, - session_datetime=turn.time_anchor, + session_datetime=first_user_turn.time_anchor, effort_level=effort_level, ) ) @@ -221,3 +221,38 @@ def _format_turn(turn: BeamTurn | None) -> str: if turn.time_anchor: prefix = f"{prefix} [{turn.time_anchor}]" return f"{prefix}: {turn.content}" + + +def _exchange_pairs( + turns: list[BeamTurn], +) -> list[tuple[list[BeamTurn], BeamTurn]]: + if any(_is_assistant_role(turn.role) for turn in turns): + return _role_aware_exchange_pairs(turns) + + pairs: list[tuple[list[BeamTurn], BeamTurn]] = [] + for index in range(0, len(turns) - 1, 2): + pairs.append(([turns[index]], turns[index + 1])) + return pairs + + +def _role_aware_exchange_pairs( + turns: list[BeamTurn], +) -> list[tuple[list[BeamTurn], BeamTurn]]: + pairs: list[tuple[list[BeamTurn], BeamTurn]] = [] + pending_user_turns: list[BeamTurn] = [] + for turn in turns: + if _is_assistant_role(turn.role): + if pending_user_turns: + pairs.append((pending_user_turns, turn)) + pending_user_turns = [] + continue + pending_user_turns.append(turn) + return pairs + + +def _is_assistant_role(role: str) -> bool: + normalized = role.lower().replace("_", " ").replace("-", " ") + return any( + label in normalized + for label in ("assistant", "agent", "bot", "ai", "gpt") + ) diff --git a/benchmarks/common/metrics.py b/benchmarks/common/metrics.py index 707bbf7..b574682 100644 --- a/benchmarks/common/metrics.py +++ b/benchmarks/common/metrics.py @@ -3,7 +3,7 @@ from __future__ import annotations import re -from collections import defaultdict +from collections import Counter, defaultdict from typing import Any @@ -19,11 +19,7 @@ def token_f1(prediction: str, reference: str) -> float: ref_tokens = normalize_answer(reference).split() if not pred_tokens or not ref_tokens: return float(pred_tokens == ref_tokens) - common = set(pred_tokens) & set(ref_tokens) - overlap = sum( - min(pred_tokens.count(token), ref_tokens.count(token)) - for token in common - ) + overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values()) if overlap == 0: return 0.0 precision = overlap / len(pred_tokens) diff --git a/benchmarks/locomo/dataset.py b/benchmarks/locomo/dataset.py index 23b8975..2f58d67 100644 --- a/benchmarks/locomo/dataset.py +++ b/benchmarks/locomo/dataset.py @@ -118,15 +118,11 @@ def build_ingest_items( ) -> list[IngestItem]: items: list[IngestItem] = [] for session in example.sessions: - turns = session.turns - for index, turn in enumerate(turns): - next_turn = turns[index + 1] if index + 1 < len(turns) else None - user_query = _format_turn(turn) - agent_response = _format_turn(next_turn) if next_turn else "" + for user_turns, agent_turn in _exchange_pairs(session.turns): items.append( IngestItem( - user_query=user_query, - agent_response=agent_response, + user_query="\n".join(_format_turn(turn) for turn in user_turns), + agent_response=_format_turn(agent_turn), user_id=user_id, session_datetime=session.date, effort_level=effort_level, @@ -139,13 +135,13 @@ def _parse_sessions(conversation: Any) -> list[ConversationSession]: if not isinstance(conversation, dict): return [] - session_numbers = sorted( + session_numbers = sorted({ _session_number(key) for key, value in conversation.items() if key.startswith("session_") and not key.endswith("_date_time") and isinstance(value, list) - ) + }) sessions: list[ConversationSession] = [] for number in session_numbers: session_key = f"session_{number}" @@ -192,6 +188,41 @@ def _format_turn(turn: ConversationTurn | None) -> str: return f"{prefix}: {turn.content}" +def _exchange_pairs( + turns: list[ConversationTurn], +) -> list[tuple[list[ConversationTurn], ConversationTurn]]: + if any(_is_assistant_speaker(turn.speaker) for turn in turns): + return _role_aware_exchange_pairs(turns) + + pairs: list[tuple[list[ConversationTurn], ConversationTurn]] = [] + for index in range(0, len(turns) - 1, 2): + pairs.append(([turns[index]], turns[index + 1])) + return pairs + + +def _role_aware_exchange_pairs( + turns: list[ConversationTurn], +) -> list[tuple[list[ConversationTurn], ConversationTurn]]: + pairs: list[tuple[list[ConversationTurn], ConversationTurn]] = [] + pending_user_turns: list[ConversationTurn] = [] + for turn in turns: + if _is_assistant_speaker(turn.speaker): + if pending_user_turns: + pairs.append((pending_user_turns, turn)) + pending_user_turns = [] + continue + pending_user_turns.append(turn) + return pairs + + +def _is_assistant_speaker(speaker: str) -> bool: + normalized = speaker.lower().replace("_", " ").replace("-", " ") + return any( + label in normalized + for label in ("assistant", "agent", "bot", "ai", "gpt") + ) + + def _session_number(key: str) -> int: try: return int(key.split("_", 1)[1]) From 6338c229002136ea8624dbedd9f6bdaf5a9513dc Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 14:38:29 +0530 Subject: [PATCH 4/5] Document benchmark additions in changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1251ec3..3d5a1a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,5 +2,6 @@ ## Unreleased +- Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. - Add local XMem setup through `npx create-xmem@latest` and `npm run dev`. - Add local Docker storage, Chrome extension build patching, diagnostics, verification, and context export/import/sync commands. From f6fd5cdced0afc2fc9798a82e4fe9ccebfcad9ab Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 15:01:42 +0530 Subject: [PATCH 5/5] Fix BEAM sampling and evaluation --- benchmarks/beam/README.md | 35 ++++- benchmarks/beam/config.py | 3 + benchmarks/beam/dataset.py | 22 ++- benchmarks/beam/evaluate.py | 280 ++++++++++++++++++++++++++++++++++++ benchmarks/beam/run.py | 10 ++ benchmarks/beam/runner.py | 52 ++++++- 6 files changed, 395 insertions(+), 7 deletions(-) create mode 100644 benchmarks/beam/evaluate.py diff --git a/benchmarks/beam/README.md b/benchmarks/beam/README.md index ac7d4d1..d0bd80f 100644 --- a/benchmarks/beam/README.md +++ b/benchmarks/beam/README.md @@ -44,10 +44,43 @@ python -m benchmarks.beam.run \ --output-dir benchmarks/beam/results/beam-1m ``` +To run a balanced slice, sample an equal percentage from each BEAM question +type: + +```bash +python -m benchmarks.beam.run \ + --split 1M \ + --dataset-path benchmarks/beam/data/1M-00000-of-00001.parquet \ + --sample-percent-per-question-type 1 \ + --api-base-url https://api.xmem.in \ + --output-dir benchmarks/beam/results/beam-1m-1pct +``` + Outputs: - `results.jsonl`: full per-question records with local proxy metrics - `predictions.jsonl`: `question_id` and `hypothesis` - `summary.json`: local exact/contains/token-F1 grouped by BEAM question type -Use BEAM's official/equivalent evaluator for publication-quality accuracy. +## Judge Evaluation + +The benchmark runner writes the BEAM rubric for each question into +`results.jsonl`. To compute BEAM-style pass rate and average judge score with an +OpenAI judge model: + +```bash +export OPENAI_API_KEY="..." + +python -m benchmarks.beam.evaluate \ + --results-path benchmarks/beam/results/beam-1m-1pct/results.jsonl \ + --output-dir benchmarks/beam/results/beam-1m-1pct +``` + +This writes: + +- `evaluations.jsonl`: per-question rubric judge scores and reasons +- `evaluation_summary.json`: pass rate and average judge score overall and by + question type + +The pass threshold is `judge_score >= 0.5`, matching the usual BEAM pass-rate +interpretation over rubric scores. diff --git a/benchmarks/beam/config.py b/benchmarks/beam/config.py index 41f20f9..fc84acb 100644 --- a/benchmarks/beam/config.py +++ b/benchmarks/beam/config.py @@ -45,6 +45,9 @@ class BenchmarkConfig: limit: int | None = None offset: int = 0 question_type: str | None = None + sample_percent_per_question_type: float | None = None + sample_min_per_question_type: int = 1 + sample_seed: int = 13 split: str = DEFAULT_SPLIT skip_ingest: bool = False resume: bool = True diff --git a/benchmarks/beam/dataset.py b/benchmarks/beam/dataset.py index 633d422..dc5302c 100644 --- a/benchmarks/beam/dataset.py +++ b/benchmarks/beam/dataset.py @@ -34,6 +34,7 @@ class BeamExample: question: str answer: str question_type: str = "" + rubric: list[str] = field(default_factory=list) split: str = "1M" chat_sessions: list[list[BeamTurn]] = field(default_factory=list) @@ -83,7 +84,7 @@ def load_examples(path: Path, *, split: str = "1M") -> list[BeamExample]: continue answer = _first_text( question_record, - ("answer", "gold_answer", "reference"), + ("answer", "ideal_response", "gold_answer", "reference"), ) question_type = _first_text( question_record, @@ -100,6 +101,10 @@ def load_examples(path: Path, *, split: str = "1M") -> list[BeamExample]: question=question, answer=answer, question_type=question_type or "unknown", + rubric=[ + str(item) + for item in question_record.get("rubric") or [] + ], split=split, chat_sessions=chat_sessions, ) @@ -182,8 +187,19 @@ def _parse_turn(raw_turn: Any) -> BeamTurn | None: def _parse_probing_questions(raw_questions: Any) -> list[dict[str, Any]]: raw_questions = _coerce_literal(raw_questions) if isinstance(raw_questions, dict): - values = raw_questions.values() - return [item for item in values if isinstance(item, dict)] + questions: list[dict[str, Any]] = [] + for question_type, value in raw_questions.items(): + if isinstance(value, dict): + item = dict(value) + item["question_type"] = str(question_type) + questions.append(item) + elif isinstance(value, list): + for question in value: + if isinstance(question, dict): + item = dict(question) + item["question_type"] = str(question_type) + questions.append(item) + return questions if isinstance(raw_questions, list): return [item for item in raw_questions if isinstance(item, dict)] return [] diff --git a/benchmarks/beam/evaluate.py b/benchmarks/beam/evaluate.py new file mode 100644 index 0000000..30b2e36 --- /dev/null +++ b/benchmarks/beam/evaluate.py @@ -0,0 +1,280 @@ +"""LLM-as-judge evaluation for BEAM benchmark outputs.""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import Any + +from openai import OpenAI + +from benchmarks.common.io import append_jsonl, read_jsonl, write_json + + +DEFAULT_OPENAI_API_KEY_ENV = "OPENAI_API_KEY" +DEFAULT_JUDGE_MODEL = "gpt-4o" +PASS_THRESHOLD = 0.5 + +JUDGE_PROMPT = """ +You are an expert evaluator tasked with judging whether the LLM's response +demonstrates compliance with the specified RUBRIC CRITERION. + +## EVALUATION INPUTS +- QUESTION (what the user asked): +- RUBRIC CRITERION (what to check): +- RESPONSE TO EVALUATE: + +## EVALUATION RUBRIC: +The rubric defines a specific requirement, constraint, or expected behavior +that the LLM response should demonstrate. + +IMPORTANT: Pay careful attention to whether the rubric specifies positive +requirements or negative constraints. + +## RESPONSIVENESS REQUIREMENT +A compliant response must be on-topic with respect to the QUESTION and attempt +to answer it. If the response does not address the QUESTION, score 0.0. + +## SEMANTIC TOLERANCE RULES +Judge by meaning, not exact wording. Accept paraphrases and synonyms that +preserve intent. Ignore case, punctuation, and whitespace differences. +Numbers, currencies, dates, and durations may appear in equivalent forms. + +## STYLE NEUTRALITY +Ignore tone, politeness, length, and flourish unless the rubric explicitly +requires a format or structure. + +## SCORING SCALE +- 1.0: complete compliance. +- 0.5: partial compliance. +- 0.0: no compliance. + +## OUTPUT FORMAT +Return only a JSON object: + +{ + "score": 1.0, + "reason": "why the rubric criterion was or was not satisfied" +} +""".strip() + + +def main() -> None: + try: + args = parse_args() + evaluator = BeamEvaluator( + results_path=args.results_path, + output_dir=args.output_dir or args.results_path.parent, + model=args.judge_model, + api_key_env=args.openai_api_key_env, + max_retries=args.max_retries, + retry_backoff_seconds=args.retry_backoff_seconds, + ) + print(json.dumps(evaluator.run(), indent=2)) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + raise SystemExit(1) from exc + + +class BeamEvaluator: + def __init__( + self, + *, + results_path: Path, + output_dir: Path, + model: str, + api_key_env: str, + max_retries: int, + retry_backoff_seconds: float, + ) -> None: + self.results_path = results_path + self.output_dir = output_dir + self.model = model + self.api_key_env = api_key_env + self.max_retries = max_retries + self.retry_backoff_seconds = retry_backoff_seconds + self.evaluations_path = output_dir / "evaluations.jsonl" + self.summary_path = output_dir / "evaluation_summary.json" + + def run(self) -> dict[str, Any]: + api_key = os.getenv(self.api_key_env, "").strip() + if not api_key: + raise RuntimeError(f"Missing API key. Set {self.api_key_env}.") + + client = OpenAI(api_key=api_key) + completed = { + str(row.get("question_id")) + for row in read_jsonl(self.evaluations_path) + } + started = time.time() + for index, result in enumerate(read_jsonl(self.results_path), start=1): + question_id = str(result.get("question_id") or "") + if question_id in completed: + continue + evaluation = self._evaluate_result(client, result) + append_jsonl(self.evaluations_path, evaluation) + print( + f"[EVAL {index}] {question_id}: " + f"score={evaluation['judge_score']} " + f"passed={evaluation['passed']}", + flush=True, + ) + + evaluations = read_jsonl(self.evaluations_path) + summary = summarize_evaluations(evaluations) + summary["results_path"] = str(self.results_path) + summary["judge_model"] = self.model + summary["pass_threshold"] = PASS_THRESHOLD + summary["duration_seconds"] = round(time.time() - started, 2) + write_json(self.summary_path, summary) + return summary + + def _evaluate_result( + self, + client: OpenAI, + result: dict[str, Any], + ) -> dict[str, Any]: + rubric = [str(item) for item in result.get("rubric") or []] + if not rubric: + reference = str(result.get("reference_answer") or "") + rubric = [reference] if reference else ["The response answers correctly."] + + judge_items = [ + self._judge_rubric_item( + client, + question=str(result.get("question") or ""), + prediction=str(result.get("prediction") or ""), + rubric_item=item, + ) + for item in rubric + ] + score = sum(float(item["score"]) for item in judge_items) / len(judge_items) + return { + "question_id": result.get("question_id"), + "conversation_id": result.get("conversation_id"), + "question_type": result.get("question_type") or "unknown", + "question": result.get("question"), + "reference_answer": result.get("reference_answer"), + "prediction": result.get("prediction"), + "judge_score": round(score, 4), + "passed": score >= PASS_THRESHOLD, + "judge_items": judge_items, + } + + def _judge_rubric_item( + self, + client: OpenAI, + *, + question: str, + prediction: str, + rubric_item: str, + ) -> dict[str, Any]: + prompt = ( + JUDGE_PROMPT.replace("", question) + .replace("", rubric_item) + .replace("", prediction) + ) + last_error: Exception | None = None + for attempt in range(self.max_retries + 1): + try: + response = client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content or "{}" + return normalize_judge_response(content) + except Exception as exc: + last_error = exc + if attempt < self.max_retries: + time.sleep(self.retry_backoff_seconds * (attempt + 1)) + raise RuntimeError(f"Judge call failed: {last_error}") from last_error + + +def normalize_judge_response(content: str) -> dict[str, Any]: + payload = parse_json_response(content) + raw_score = payload.get("score", 0) + try: + score = float(raw_score) + except (TypeError, ValueError): + score = 0.0 + score = min(1.0, max(0.0, score)) + return { + "score": score, + "reason": str(payload.get("reason") or ""), + } + + +def parse_json_response(content: str) -> dict[str, Any]: + text = content.strip() + if text.startswith("```"): + match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", text, re.DOTALL) + if match: + text = match.group(1).strip() + try: + payload = json.loads(text) + except json.JSONDecodeError: + match = re.search(r"(\{.*\})", text, re.DOTALL) + if not match: + raise + payload = json.loads(match.group(1)) + if not isinstance(payload, dict): + raise ValueError("Judge response must be a JSON object.") + return payload + + +def summarize_evaluations(evaluations: list[dict[str, Any]]) -> dict[str, Any]: + if not evaluations: + return {"count": 0, "overall": {}, "by_question_type": {}} + + buckets: dict[str, list[dict[str, Any]]] = defaultdict(list) + for evaluation in evaluations: + buckets[str(evaluation.get("question_type") or "unknown")].append(evaluation) + + return { + "count": len(evaluations), + "overall": summarize_bucket(evaluations), + "by_question_type": { + label: summarize_bucket(bucket) + for label, bucket in sorted(buckets.items()) + }, + } + + +def summarize_bucket(evaluations: list[dict[str, Any]]) -> dict[str, float | int]: + count = len(evaluations) + passed = sum(1 for item in evaluations if item.get("passed")) + score = sum(float(item.get("judge_score") or 0.0) for item in evaluations) + return { + "count": count, + "passed": passed, + "pass_rate": round(passed / count, 4), + "avg_judge_score": round(score / count, 4), + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Evaluate BEAM XMem results.") + parser.add_argument( + "--results-path", + type=Path, + required=True, + help="Path to BEAM results.jsonl generated by benchmarks.beam.run.", + ) + parser.add_argument("--output-dir", type=Path) + parser.add_argument("--judge-model", default=DEFAULT_JUDGE_MODEL) + parser.add_argument("--openai-api-key-env", default=DEFAULT_OPENAI_API_KEY_ENV) + parser.add_argument("--max-retries", type=int, default=3) + parser.add_argument("--retry-backoff-seconds", type=float, default=2.0) + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/beam/run.py b/benchmarks/beam/run.py index a752e63..c1f1ae1 100644 --- a/benchmarks/beam/run.py +++ b/benchmarks/beam/run.py @@ -40,6 +40,9 @@ def main() -> None: limit=args.limit, offset=args.offset, question_type=args.question_type, + sample_percent_per_question_type=args.sample_percent_per_question_type, + sample_min_per_question_type=args.sample_min_per_question_type, + sample_seed=args.sample_seed, split=args.split, skip_ingest=args.skip_ingest, resume=not args.no_resume, @@ -101,6 +104,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--limit", type=int) parser.add_argument("--offset", type=int, default=0) parser.add_argument("--question-type") + parser.add_argument( + "--sample-percent-per-question-type", + type=float, + help="Select a balanced percent from each BEAM question_type.", + ) + parser.add_argument("--sample-min-per-question-type", type=int, default=1) + parser.add_argument("--sample-seed", type=int, default=13) parser.add_argument("--skip-ingest", action="store_true") parser.add_argument("--no-resume", action="store_true") parser.add_argument("--dry-run", action="store_true") diff --git a/benchmarks/beam/runner.py b/benchmarks/beam/runner.py index e9416e8..cd60a71 100644 --- a/benchmarks/beam/runner.py +++ b/benchmarks/beam/runner.py @@ -2,7 +2,10 @@ from __future__ import annotations +import math +import random import time +from collections import defaultdict from typing import Any from benchmarks.common.io import append_jsonl, read_jsonl, write_json @@ -27,6 +30,7 @@ async def run(self) -> dict[str, Any]: limit=self.config.limit, question_type=self.config.question_type, ) + examples = self._sample_examples(examples) if self.config.dry_run: return self._dry_run_summary(examples) @@ -62,6 +66,7 @@ async def run(self) -> dict[str, Any]: summary["dataset_path"] = str(self.config.dataset_path) summary["api_base_url"] = self.config.api_base_url summary["split"] = self.config.split + summary["sample"] = self._sample_summary() summary["duration_seconds"] = round(time.time() - run_started, 2) write_json(self.summary_path, summary) return summary @@ -98,6 +103,7 @@ async def _run_example( "question": example.question, "reference_answer": example.answer, "prediction": prediction, + "rubric": example.rubric, "metrics": score_answer(prediction, example.answer), "source_count": len(retrieve.data.get("sources") or []), "confidence": retrieve.data.get("confidence"), @@ -147,6 +153,27 @@ async def _ingest_items(self, client: XMemApiClient, items: list[Any]) -> float: def _completed_question_ids(self) -> set[str]: return {str(row.get("question_id")) for row in read_jsonl(self.results_path)} + def _sample_examples(self, examples: list[BeamExample]) -> list[BeamExample]: + percent = self.config.sample_percent_per_question_type + if percent is None: + return examples + if percent <= 0 or percent > 100: + raise ValueError("--sample-percent-per-question-type must be in (0, 100].") + + groups: dict[str, list[BeamExample]] = defaultdict(list) + for example in examples: + groups[example.question_type or "unknown"].append(example) + + rng = random.Random(self.config.sample_seed) + selected: list[BeamExample] = [] + for question_type in sorted(groups): + group = list(groups[question_type]) + rng.shuffle(group) + count = math.ceil(len(group) * percent / 100) + count = max(self.config.sample_min_per_question_type, count) + selected.extend(group[: min(count, len(group))]) + return selected + def _dry_run_summary(self, examples: list[BeamExample]) -> dict[str, Any]: ingest_counts = [ len( @@ -158,18 +185,37 @@ def _dry_run_summary(self, examples: list[BeamExample]) -> dict[str, Any]: ) for example in examples ] + by_question_type = self._count_by_question_type(examples) summary = { "dry_run": True, "dataset_path": str(self.config.dataset_path), "split": self.config.split, + "sample": self._sample_summary(), "selected_examples": len(examples), "total_ingest_items": sum(ingest_counts), "min_ingest_items": min(ingest_counts) if ingest_counts else 0, "max_ingest_items": max(ingest_counts) if ingest_counts else 0, - "question_types": sorted( - {example.question_type or "unknown" for example in examples} - ), + "by_question_type": by_question_type, + "question_types": sorted(by_question_type), "conversations": sorted({example.conversation_id for example in examples}), } write_json(self.summary_path, summary) return summary + + def _sample_summary(self) -> dict[str, Any]: + return { + "percent_per_question_type": ( + self.config.sample_percent_per_question_type + ), + "min_per_question_type": self.config.sample_min_per_question_type, + "seed": self.config.sample_seed, + } + + @staticmethod + def _count_by_question_type( + examples: list[BeamExample], + ) -> dict[str, int]: + counts: dict[str, int] = defaultdict(int) + for example in examples: + counts[example.question_type or "unknown"] += 1 + return dict(sorted(counts.items()))