diff --git a/README.md b/README.md index 306443d..ab67ba1 100644 --- a/README.md +++ b/README.md @@ -197,20 +197,36 @@ Task names follow [LMHarness](https://github.com/EleutherAI/lm-evaluation-harnes ### Generate + judge (pairwise) -| Task | Description | -|-----------------------|------------------------------------------------------------------------------------------------| -| `alpaca-eval` | General instruction-following benchmark | -| `arena-hard-v2.0` | Arena-Hard v2.0 from official `lmarena-ai/arena-hard-auto` source | -| `arena-hard-v0.1` | Legacy Arena-Hard v0.1 from official `lmarena-ai/arena-hard-auto` source | -| `m-arena-hard` | Translated version of Arena-Hard in 23 languages | -| `m-arena-hard-{lang}` | Language-specific variants (e.g., `ar`, `cs`, `de`) | -| `m-arena-hard-EU` | All EU languages combined | -| `mt-bench` | Multi-turn benchmark with FastChat-compatible pairwise judging | -| `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | +| Task | Description | +|------------------------------|------------------------------------------------------------------------------------------------| +| `alpaca-eval` | General instruction-following benchmark | +| `arena-hard-v2.0` | Arena-Hard v2.0 from official `lmarena-ai/arena-hard-auto` source | +| `arena-hard-v0.1` | Legacy Arena-Hard v0.1 from official `lmarena-ai/arena-hard-auto` source | +| `m-arena-hard-v0.1` | `CohereLabs/m-ArenaHard` (500 prompts, Google-Translate) across 23 languages | +| `m-arena-hard-v0.1-{lang}` | Language-specific v0.1 slice (e.g., `ar`, `cs`, `de`, `uk`, `zh`, `pl`) | +| `m-arena-hard-v0.1-EU` | All EU v0.1 languages combined | +| `m-arena-hard-v2.0` | `CohereLabs/m-ArenaHard-v2.0` (498 prompts, in-house translation) across 23 languages | +| `m-arena-hard-v2.0-{lang}` | Language-specific v2.0 slice | +| `m-arena-hard-v2.0-EU` | All EU v2.0 languages combined | +| `mt-bench` | Multi-turn benchmark with FastChat-compatible pairwise judging | +| `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | + +For MT-Bench, the default pairwise baseline is `gpt-4`. +We diverge from FastChat's own `pairwise-baseline` default (`gpt-3.5-turbo`) to keep +a stronger reference consistent with Arena-Hard v0.1; the `gpt-4.jsonl` completions +ship in the `lmsys/mt-bench` HF Space. Override per run with `--model_B`. For Arena-Hard, JudgeArena resolves baseline metadata by task version: - `arena-hard-v0.1`: `gpt-4-0314` -- `arena-hard-v2.0`: `o3-mini-2025-01-31` (standard prompts) +- `arena-hard-v2.0`: per-question baseline routed by `category`: + - `o3-mini-2025-01-31` for `hard_prompt`, `coding`, and `math` (500 prompts). + - `gemini-2.0-flash-001` for `creative_writing` (250 prompts). + +For m-Arena-Hard, baseline completions are tied to the benchmark release: +- `m-arena-hard-v0.1`: Aya Expanse 8B (`CohereLabs/aya-expanse-8b`), ingested + from `CohereLabs/deja-vu-pairwise-evals` (repeat 0) via + [`scripts/multilingual_arena_hard/ingest_deja_vu_aya_references.py`](scripts/multilingual_arena_hard/ingest_deja_vu_aya_references.py). +- `m-arena-hard-v2.0`: Gemini 2.5 Flash (`google/gemini-2.5-flash`). 
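+
+As a minimal sketch, the same baseline resolution can be exercised
+programmatically through the `CliArgs` dataclass exported by
+`judgearena.generate_and_evaluate` (the `main(...)` call signature and the
+model specs below are illustrative assumptions, not verified usage):
+
+```python
+from judgearena.generate_and_evaluate import CliArgs, main
+
+# Arena-Hard v2.0: model_B may stay None; the per-category baseline
+# (o3-mini-2025-01-31 / gemini-2.0-flash-001) is resolved from the task.
+main(CliArgs(task="arena-hard-v2.0", model_A="<your-model-spec>"))
+
+# MT-Bench: keep the gpt-4 default, or override the baseline per run.
+main(CliArgs(task="mt-bench", model_A="<your-model-spec>", model_B="<baseline-model-spec>"))
+```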
### ELO rating diff --git a/judgearena/chat_models/__init__.py b/judgearena/chat_models/__init__.py new file mode 100644 index 0000000..91aa08d --- /dev/null +++ b/judgearena/chat_models/__init__.py @@ -0,0 +1,15 @@ +"""Chat-model adapters with provider-specific hardening.""" + +from judgearena.chat_models.openrouter_gemini import ( + GEMINI_SAFETY_REFUSAL_MARKER, + OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON, + OpenRouterGeminiSafetyTolerantChatOpenAI, + is_openrouter_gemini_model, +) + +__all__ = [ + "GEMINI_SAFETY_REFUSAL_MARKER", + "OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON", + "OpenRouterGeminiSafetyTolerantChatOpenAI", + "is_openrouter_gemini_model", +] diff --git a/judgearena/chat_models/openrouter_gemini.py b/judgearena/chat_models/openrouter_gemini.py new file mode 100644 index 0000000..bce637d --- /dev/null +++ b/judgearena/chat_models/openrouter_gemini.py @@ -0,0 +1,102 @@ +"""ChatOpenAI subclass tolerant to Gemini's PROHIBITED_CONTENT hard-refusals. + +Google's core policy filter rejects a small fraction of prompts (e.g. graphic +violence, sexual content involving minors) with HTTP 403 ``PROHIBITED_CONTENT`` +*regardless* of the adjustable ``safety_settings`` thresholds. These refusals +are legitimate, reproducible model behavior that a benchmark like +``m-arena-hard-v2.0`` surfaces: the baseline should contain them so the judge +can score them, not crash the run. + +The subclass intercepts the error response before LangChain raises, returns +a stub assistant message with a clearly marked refusal payload and +``finish_reason="content_filter"``, and lets the rest of the pipeline proceed +unchanged. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_openai import ChatOpenAI + +GEMINI_SAFETY_REFUSAL_MARKER = ( + "[Gemini safety refusal: PROHIBITED_CONTENT — Google's core policy filter " + "blocked this prompt regardless of safety_settings.]" +) +OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON = "content_filter" + +_PROHIBITED_CONTENT_TOKEN = "PROHIBITED_CONTENT" + + +def is_openrouter_gemini_model(model_spec: str) -> bool: + """Return True when ``model_spec`` targets a Gemini model via OpenRouter. + + Matches ``OpenRouter/google/gemini-2.5-flash`` and related variants. + """ + provider, sep, model_name = model_spec.partition("/") + if not sep: + return False + lowered = model_name.lower() + return provider == "OpenRouter" and ( + lowered.startswith("google/gemini") or lowered.startswith("google/gemma") + ) + + +def _error_is_prohibited_content(error: object) -> bool: + if error is None: + return False + return _PROHIBITED_CONTENT_TOKEN in str(error) + + +def _build_prohibited_content_stub_payload( + *, original_response: dict[str, Any], model_name: str +) -> dict[str, Any]: + stub_message = { + "role": "assistant", + "content": GEMINI_SAFETY_REFUSAL_MARKER, + } + stub_choice = { + "index": 0, + "message": stub_message, + "finish_reason": OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON, + } + return { + "id": original_response.get("id") or "openrouter-gemini-safety-stub", + "object": "chat.completion", + "created": original_response.get("created") or 0, + "model": original_response.get("model") or model_name, + "choices": [stub_choice], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + +class OpenRouterGeminiSafetyTolerantChatOpenAI(ChatOpenAI): + """ChatOpenAI that converts Gemini PROHIBITED_CONTENT errors to stubs. 
+ + Only intercepts the specific OpenRouter error surface for Gemini's core + policy filter; all other errors propagate unchanged. The stub message has + ``content == GEMINI_SAFETY_REFUSAL_MARKER`` and ``finish_reason == + "content_filter"`` so upstream validators and judges see the refusal + explicitly rather than a silent drop. + """ + + def _create_chat_result( # type: ignore[override] + self, + response, + generation_info: dict | None = None, + ): + response_dict = ( + response if isinstance(response, dict) else response.model_dump() + ) + error = response_dict.get("error") + if _error_is_prohibited_content(error): + stub = _build_prohibited_content_stub_payload( + original_response=response_dict, + model_name=self.model_name, + ) + return super()._create_chat_result(stub, generation_info=generation_info) + return super()._create_chat_result(response, generation_info=generation_info) diff --git a/judgearena/cli.py b/judgearena/cli.py index eb94c83..649fe9e 100644 --- a/judgearena/cli.py +++ b/judgearena/cli.py @@ -17,7 +17,7 @@ ) from judgearena.estimate_elo_ratings import CliEloArgs from judgearena.estimate_elo_ratings import main as main_elo -from judgearena.generate_and_evaluate import CliArgs +from judgearena.generate_and_evaluate import CliArgs, native_pairwise_baseline from judgearena.generate_and_evaluate import main as main_generate_and_evaluate from judgearena.log import configure_logging, get_logger @@ -196,13 +196,21 @@ def _build_elo_args( provide_explanation=args.provide_explanation, swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, + judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, + battle_thinking_token_budget=args.battle_thinking_token_budget, + strip_thinking_before_judging=args.strip_thinking_before_judging, + skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, + truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, max_out_tokens_judge=args.max_out_tokens_judge, max_model_len=args.max_model_len, + max_judge_model_len=args.max_judge_model_len, chat_template=args.chat_template, result_folder=args.result_folder, engine_kwargs=parse_engine_kwargs(args.engine_kwargs), + judge_engine_kwargs=parse_engine_kwargs(args.judge_engine_kwargs), verbosity=resolve_verbosity(args), log_file=args.log_file, no_log_file=args.no_log_file, @@ -212,7 +220,9 @@ def _build_elo_args( def _build_generate_and_evaluate_args( args: argparse.Namespace, task: str, model_a: str | None ) -> CliArgs: - if model_a is None or args.model_B is None: + if model_a is None or ( + args.model_B is None and native_pairwise_baseline(task) is None + ): raise SystemExit(f"--model_A and --model_B are required for task {task!r}.") return CliArgs( task=task, @@ -224,13 +234,21 @@ def _build_generate_and_evaluate_args( provide_explanation=args.provide_explanation, swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, + judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, + battle_thinking_token_budget=args.battle_thinking_token_budget, + strip_thinking_before_judging=args.strip_thinking_before_judging, + skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, + truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, max_out_tokens_judge=args.max_out_tokens_judge, max_model_len=args.max_model_len, + max_judge_model_len=args.max_judge_model_len, 
chat_template=args.chat_template, result_folder=args.result_folder, engine_kwargs=parse_engine_kwargs(args.engine_kwargs), + judge_engine_kwargs=parse_engine_kwargs(args.judge_engine_kwargs), verbosity=resolve_verbosity(args), log_file=args.log_file, no_log_file=args.no_log_file, diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 58ce78b..6cfe627 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -11,6 +11,10 @@ import json from dataclasses import dataclass, field +from judgearena.judge_prompt_presets import JUDGE_PROMPT_PRESETS + +MT_BENCH_JUDGE_MODES = ("default", "fastchat_original") + @dataclass class BaseCliArgs: @@ -22,13 +26,21 @@ class BaseCliArgs: provide_explanation: bool = False swap_mode: str = "fixed" ignore_cache: bool = False + judge_prompt_preset: str = "default" + mt_bench_judge_mode: str = "default" + battle_thinking_token_budget: int | None = None + strip_thinking_before_judging: bool = False + skip_judging: bool = False truncate_all_input_chars: int = 8192 + truncate_judge_input_chars: int | None = None max_out_tokens_models: int = 32768 max_out_tokens_judge: int = 32768 max_model_len: int | None = None + max_judge_model_len: int | None = None chat_template: str | None = None result_folder: str = "results" engine_kwargs: dict = field(default_factory=dict) + judge_engine_kwargs: dict = field(default_factory=dict) verbosity: int = 0 log_file: str | None = None no_log_file: bool = False @@ -38,6 +50,22 @@ def __post_init__(self): assert self.swap_mode in supported_modes, ( f"Only {supported_modes} modes are supported but got {self.swap_mode}." ) + assert self.mt_bench_judge_mode in MT_BENCH_JUDGE_MODES, ( + "Only " + f"{list(MT_BENCH_JUDGE_MODES)} MT-Bench judge modes are supported but " + f"got {self.mt_bench_judge_mode!r}." + ) + + +def parse_optional_bool(raw: str | None) -> bool: + if raw is None: + return True + normalized = raw.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{raw}'.") def add_common_arguments(parser: argparse.ArgumentParser) -> None: @@ -61,7 +89,10 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: ) parser.add_argument( "--provide_explanation", - action="store_true", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, help=( "If specified, judge will provide explanation before making a " "judgement. Does not necessarily improve the accuracy of the judge " @@ -82,9 +113,72 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: ) parser.add_argument( "--ignore_cache", - action="store_true", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, help="If specified, ignore cache of previous completions.", ) + parser.add_argument( + "--judge_prompt_preset", + type=str, + choices=JUDGE_PROMPT_PRESETS, + default="default", + help=( + "Judge prompt preset to use. 'default' preserves the existing score-first " + "JudgeArena prompts, while 'skywork' enables an optional Skywork-style " + "verdict-first preset." + ), + ) + parser.add_argument( + "--mt_bench_judge_mode", + type=str, + choices=MT_BENCH_JUDGE_MODES, + default="default", + help=( + "MT-Bench-only judging mode. 'default' makes MT-Bench obey " + "--judge_prompt_preset like the other benchmarks, while " + "'fastchat_original' preserves the original FastChat-style " + "prompting and [[A]]/[[B]]/[[C]] verdict parsing." 
+ ), + ) + parser.add_argument( + "--battle_thinking_token_budget", + type=int, + required=False, + default=None, + help=( + "Optional reasoning-token sub-budget for battle-model generation. " + "This stays inside --max_out_tokens_models." + ), + ) + parser.add_argument( + "--strip_thinking_before_judging", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, + help=( + "If specified, strip visible reasoning traces from model completions " + "before sending them to the judge." + ), + ) + parser.add_argument( + "--skip_judging", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, + help=( + "If specified, generate battle-model completions and write a " + "generation-only summary (gen-results-.json with limit_events) " + "but skip judge-model construction and the judging loop entirely. " + "Useful for decoupling expensive paid-judge calls from the cheap " + "local generation phase: run once with --skip_judging=True to " + "materialize the completion cache, inspect cap rates, then rerun " + "with --skip_judging=False to judge from cache." + ), + ) parser.add_argument( "--result_folder", type=str, @@ -101,9 +195,19 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: required=False, default=8192, help=( - "Character-level truncation applied before tokenization: truncates " - "each instruction before model A/B generation and truncates each " - "completion before judge evaluation." + "Character-level truncation applied to generation-side inputs: " + "truncates each instruction before model A/B generation." + ), + ) + parser.add_argument( + "--truncate_judge_input_chars", + type=int, + required=False, + default=None, + help=( + "Character cap applied to judge-side inputs (completions, " + "reference, instruction) before judge evaluation. When omitted, " + "judge inputs are not character-truncated by this CLI setting." ), ) parser.add_argument( @@ -132,10 +236,23 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: required=False, default=None, help=( - "Optional total context window for VLLM models (prompt + generation). " - "This is independent from --max_out_tokens_models/--max_out_tokens_judge, " - "which only cap generated tokens. This is useful on smaller GPUs to " - "avoid OOM." + "Optional total context window for the battle-generation VLLM " + "instances (prompt + generation). Independent from " + "--max_out_tokens_models/--max_out_tokens_judge, which only cap " + "generated tokens." + ), + ) + parser.add_argument( + "--max_judge_model_len", + type=int, + required=False, + default=None, + help=( + "Optional total context window for the judge VLLM instance. When " + "omitted, no judge max_model_len override is passed. Set higher " + "than the battle model_len when the judge needs to see longer " + "prompts (e.g. long completions from both A and B) than the " + "battle generator can fit." ), ) parser.add_argument( @@ -160,6 +277,19 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: '\'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.' ), ) + parser.add_argument( + "--judge_engine_kwargs", + type=str, + required=False, + default="{}", + help=( + "Optional JSON dict of engine-specific kwargs that override " + "``--engine_kwargs`` only for the judge model. Useful when the " + "judge needs a different tensor-parallel or quantization config " + "than the battle models, e.g. a 70B judge on TP=2 while the " + "battle models run on TP=1 to dodge compile-time deadlocks." 
+ ), + ) parser.add_argument( "-v", "--verbose", diff --git a/judgearena/estimate_elo_ratings.py b/judgearena/estimate_elo_ratings.py index 51ba6e2..15222fc 100644 --- a/judgearena/estimate_elo_ratings.py +++ b/judgearena/estimate_elo_ratings.py @@ -269,8 +269,8 @@ def replace_slash(s: str) -> str: ] judge_extra_kwargs = {} - if args.max_model_len is not None: - judge_extra_kwargs["max_model_len"] = args.max_model_len + if args.max_judge_model_len is not None: + judge_extra_kwargs["max_model_len"] = args.max_judge_model_len if args.chat_template is not None: judge_extra_kwargs["chat_template"] = args.chat_template @@ -287,7 +287,7 @@ def run_judge() -> pd.DataFrame: completions_B=completions_B, swap_mode=args.swap_mode, provide_explanation=args.provide_explanation, - truncate_input_chars=args.truncate_all_input_chars, + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=use_tqdm, ) return pd.DataFrame( diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 7eb8599..47e608d 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -14,24 +15,46 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + ResolvedJudgePrompt, + resolve_pairwise_judge_prompt, +) from judgearena.log import get_logger +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + LimitEventTracker, compute_pref_summary, data_root, do_inference, download_hf, + infer_model_spec_from_instance, read_df, - truncate, + strip_thinking_tags, + strip_thinking_tags_with_metadata, + truncate_with_metadata, ) +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +_PREFLIGHT_MAX_ITERATIONS = 3 +_PREFLIGHT_RESERVED_TOKENS = 256 +_PREFLIGHT_MIN_COMPLETION_CHARS = 512 + logger = get_logger(__name__) class PairScore: - def __init__(self): + def __init__(self, *, parser_mode: str = "score"): super(PairScore).__init__() self.temperature = 0.3 + self.parser_mode = parser_mode def preference_from_scores(self, score_a: float, score_b: float) -> float: return 1 - np.exp(self.temperature * score_a) / ( @@ -39,17 +62,31 @@ def preference_from_scores(self, score_a: float, score_b: float) -> float: ) def parse_model_raw(self, judge_completion: str) -> float | None: - # lower case to avoid confusion, e.g. 
when "a" is used instead of "A" - score_a = self.get_regexp_match( - judge_completion.lower(), r'score.*?a[": *\n]*(-?\d+)' - ) - score_b = self.get_regexp_match( - judge_completion.lower(), r'score.*?b[": *\n]*(-?\d+)' - ) + judge_completion = strip_thinking_tags(judge_completion) + if self.parser_mode == "verdict": + return self._parse_bracketed_verdict(judge_completion) + if self.parser_mode == "score": + return self._parse_numeric_scores(judge_completion) + raise ValueError(f"Unsupported parser_mode '{self.parser_mode}'.") + + def _parse_numeric_scores(self, judge_completion: str) -> float | None: + lowered = judge_completion.lower() + score_a = self.get_regexp_match(lowered, r'score.*?a[": *\n]*(-?\d+)') + score_b = self.get_regexp_match(lowered, r'score.*?b[": *\n]*(-?\d+)') if score_a is None or score_b is None: return None - else: - return float(self.preference_from_scores(score_a, score_b)) + return float(self.preference_from_scores(score_a, score_b)) + + def _parse_bracketed_verdict(self, judge_completion: str) -> float | None: + verdict_match = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion) + if verdict_match is None: + return None + bracketed_verdict = verdict_match.group(1).lower() + return { + "a": 0.0, + "b": 1.0, + "c": 0.5, + }[bracketed_verdict] def get_regexp_match(self, s: str, regex: str, group_index: int = 1): m = re.search(re.compile(regex), s) @@ -59,54 +96,32 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): return float(m.group(group_index).strip(" ")) -_COMPLETION_LABEL_SINGLE = "Answer" -_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" -_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" -_SCORE_FENCE = "\n```" - - def load_judge_system_and_user_prompt( provide_explanation: bool = True, multi_turn: bool = False, ) -> tuple[str, str]: - prompts_dir = Path(__file__).parent / "prompts" - system_prompt = (prompts_dir / "system-prompt.txt").read_text() - - prompt_filename = ( - "prompt-with-explanation.txt" if provide_explanation else "prompt.txt" - ) - user_prompt_template = (prompts_dir / prompt_filename).read_text() - user_prompt_template = user_prompt_template.replace( - "{completion_label}", - _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, - ) - user_prompt_template = user_prompt_template.replace( - "{explanation_suffix}", - _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + resolved = resolve_pairwise_judge_prompt( + prompt_preset=DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation=provide_explanation, + multi_turn=multi_turn, ) - - return system_prompt, user_prompt_template + return resolved.system_prompt or "", resolved.user_prompt_template def resolve_judge_prompts( *, provide_explanation: bool, multi_turn: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, system_prompt: str | None = None, user_prompt_template: str | None = None, -) -> tuple[str, str]: - default_system_prompt, default_user_prompt_template = ( - load_judge_system_and_user_prompt( - provide_explanation=provide_explanation, multi_turn=multi_turn - ) - ) - return ( - system_prompt if system_prompt is not None else default_system_prompt, - ( - user_prompt_template - if user_prompt_template is not None - else default_user_prompt_template - ), +) -> ResolvedJudgePrompt: + return resolve_pairwise_judge_prompt( + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + multi_turn=multi_turn, + system_prompt=system_prompt, + user_prompt_template=user_prompt_template, 
) @@ -119,6 +134,8 @@ def evaluate_completions( use_tqdm: bool = False, truncate_input_chars: int | None = 8192, provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, ): """ :param dataset: @@ -144,8 +161,9 @@ def evaluate_completions( dataset=dataset, ).loc[:, "instruction"] - # A bit ugly, only loads if local path exist as we do not have a local path of completion for cases such as - # m-arena-hard. + # Only loads if the per-dataset local path exists; some datasets (e.g. + # language slices of m-arena-hard for which no baseline has been written + # yet) may not ship a local completions file. dataset_output_path = local_path_tables / "model_outputs" / f"{dataset}.csv.zip" if dataset_output_path.exists(): df_outputs = read_df(dataset_output_path) @@ -186,42 +204,63 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): from langchain_together.llms import Together judge_chat_model = Together(model="meta-llama/Llama-3.3-70B-Instruct-Turbo") + judge_model_spec = infer_model_spec_from_instance(judge_chat_model) + usage_tracker = OpenRouterReferencePricingTracker() + limit_event_tracker = LimitEventTracker() unique_string = dataset + "-" + datetime.now().strftime("%Y%m%d_%H%M%S") output_folder = data_root / "judge-evals" / unique_string logger.info("Saving results in %s", output_folder) output_folder.mkdir(parents=True, exist_ok=True) - ( - judge_system_prompt, - judge_user_prompt_template, - ) = resolve_judge_prompts(provide_explanation=provide_explanation) + resolved_prompt = resolve_judge_prompts( + provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + ) annotations = annotate_battles( judge_chat_model=judge_chat_model, instructions=instructions.tolist(), completions_A=completions_A.loc[instructions.index].tolist(), completions_B=completions_B.loc[instructions.index].tolist(), - system_prompt=judge_system_prompt, - user_prompt_template=judge_user_prompt_template, + case_ids=instructions.index.tolist(), + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=resolved_prompt.user_prompt_template, + prompt_preset=resolved_prompt.preset_name, use_tqdm=use_tqdm, truncate_input_chars=truncate_input_chars, provide_explanation=provide_explanation, + strip_thinking_before_judging=strip_thinking_before_judging, + usage_tracker=usage_tracker, + usage_phase="judge", + usage_model_spec=judge_model_spec, + limit_event_tracker=limit_event_tracker, ) # Pairwise judge results - score_parser = PairScore() + score_parser = PairScore(parser_mode=resolved_prompt.parser_mode) prefs = pd.Series( [ score_parser.parse_model_raw(annotation.judge_completion) for annotation in annotations ] ) - results = {**compute_pref_summary(prefs)} + results = { + **compute_pref_summary(prefs), + "judge_prompt_preset": resolved_prompt.preset_name, + "limit_events": limit_event_tracker.build_summary(), + } pd.DataFrame(annotations).to_csv(output_folder / "annotations.csv", index=False) logger.info("%s against %s:\n%s", method_A, method_B, results) with open(output_folder / "results.json", "w") as f: json.dump(_to_jsonable(results), f, allow_nan=False) + pricing_reference = None + if judge_model_spec is not None: + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={"judge": judge_model_spec}, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) run_metadata = { "dataset": dataset, @@ -232,6 +271,8 @@ def 
get_output(df_outputs: pd.DataFrame, dataset: str, method: str): "use_tqdm": use_tqdm, "truncate_input_chars": truncate_input_chars, "provide_explanation": provide_explanation, + "judge_prompt_preset": resolved_prompt.preset_name, + "strip_thinking_before_judging": strip_thinking_before_judging, } try: @@ -246,9 +287,10 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): "completions_A": completions_A.loc[instructions.index].tolist(), "completions_B": completions_B.loc[instructions.index].tolist(), }, - judge_system_prompt=judge_system_prompt, - judge_user_prompt_template=judge_user_prompt_template, + judge_system_prompt=resolved_prompt.system_prompt, + judge_user_prompt_template=resolved_prompt.user_prompt_template, started_at_utc=run_started_at, + pricing_reference=pricing_reference, ) except OSError as e: logger.warning("Failed to write run metadata: %s", e) @@ -261,6 +303,12 @@ class JudgeAnnotation: completion_B: str # completion of the second model judge_completion: str # output of the judge judge_input: str | None = None # input that was passed to the judge + completion_A_for_judge: str | None = None + completion_B_for_judge: str | None = None + completion_A_reasoning_stripped: bool = False + completion_B_reasoning_stripped: bool = False + completion_A_truncated_for_judge: bool = False + completion_B_truncated_for_judge: bool = False def annotate_battles( @@ -268,11 +316,21 @@ def annotate_battles( instructions: list[str], completions_A: list[str], completions_B: list[str], + case_ids: list[object] | None = None, system_prompt: str | None = None, user_prompt_template: str = None, truncate_input_chars: int | None = 8192, use_tqdm: bool = False, provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, + judge_tokenizer: "PreTrainedTokenizerBase | None" = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, ) -> list[JudgeAnnotation]: """ Directly evaluate from list of instructions and completions @@ -300,47 +358,156 @@ def annotate_battles( :param user_prompt_template: :param truncate_input_chars: Max characters to truncate completions before sending to judge. :param use_tqdm: + :param judge_tokenizer: Optional HF tokenizer matching the judge model; when + supplied together with ``max_judge_model_len`` triggers a preflight + tokenize-and-retry pass that shrinks per-completion character caps until + the rendered prompt fits the judge context window. Converts the hard + ``VLLMValidationError`` class into a soft ``judge_input_token_truncation`` + limit event. + :param max_judge_model_len: Judge-side ``max_model_len``; required for the + preflight pass to be active. + :param max_out_tokens_judge: Judge-side output budget subtracted from + ``max_judge_model_len`` to derive the per-request prompt budget. 
:return: """ # alternatively pass list of tuples assert len(instructions) == len(completions_A) == len(completions_B) + if case_ids is None: + case_ids = [None] * len(instructions) + assert len(case_ids) == len(instructions) - system_prompt, user_prompt_template = resolve_judge_prompts( + resolved_prompt = resolve_judge_prompts( provide_explanation=provide_explanation, + prompt_preset=prompt_preset, system_prompt=system_prompt, user_prompt_template=user_prompt_template, ) - - prompt_template = ChatPromptTemplate.from_messages( - [("system", system_prompt), ("user", user_prompt_template)] - ) - - inputs = prompt_template.batch( - [ + message_templates: list[tuple[str, str]] = [] + if resolved_prompt.system_prompt is not None: + message_templates.append(("system", resolved_prompt.system_prompt)) + message_templates.append(("user", resolved_prompt.user_prompt_template)) + + prompt_template = ChatPromptTemplate.from_messages(message_templates) + truncated_completion_count = 0 + input_payloads = [] + annotation_input_metadata: list[dict[str, object]] = [] + for case_id, user_prompt, completion_A, completion_B in zip( + case_ids, instructions, completions_A, completions_B, strict=True + ): + raw_completion_A = completion_A if isinstance(completion_A, str) else "" + raw_completion_B = completion_B if isinstance(completion_B, str) else "" + completion_A_for_judge = raw_completion_A + completion_B_for_judge = raw_completion_B + stripped_A = False + stripped_B = False + if strip_thinking_before_judging: + completion_A_for_judge, stripped_A = strip_thinking_tags_with_metadata( + completion_A_for_judge + ) + completion_B_for_judge, stripped_B = strip_thinking_tags_with_metadata( + completion_B_for_judge + ) + if stripped_A and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field="completion_A", + case_id=case_id, + original_length=len(raw_completion_A), + final_length=len(completion_A_for_judge), + ) + if stripped_B and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field="completion_B", + case_id=case_id, + original_length=len(raw_completion_B), + final_length=len(completion_B_for_judge), + ) + truncated_completion_A, truncated_A = truncate_with_metadata( + completion_A_for_judge, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="judge_input_char_truncation", + stage="judge_input", + field="completion_A", + case_id=case_id, + ) + truncated_completion_B, truncated_B = truncate_with_metadata( + completion_B_for_judge, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="judge_input_char_truncation", + stage="judge_input", + field="completion_B", + case_id=case_id, + ) + truncated_completion_count += int(truncated_A) + truncated_completion_count += int(truncated_B) + input_payloads.append( { "user_prompt": user_prompt, - "completion_A": truncate(completion_A, max_len=truncate_input_chars), - "completion_B": truncate(completion_B, max_len=truncate_input_chars), + "completion_A": truncated_completion_A, + "completion_B": truncated_completion_B, } - for user_prompt, completion_A, completion_B in zip( - instructions, completions_A, completions_B, strict=True - ) - ] - ) + ) + annotation_input_metadata.append( + { + "completion_A_for_judge": truncated_completion_A, + "completion_B_for_judge": truncated_completion_B, + "completion_A_reasoning_stripped": stripped_A, + "completion_B_reasoning_stripped": 
stripped_B, + "completion_A_truncated_for_judge": truncated_A, + "completion_B_truncated_for_judge": truncated_B, + } + ) + if truncated_completion_count: + logger.warning( + "Warning: truncated " + f"{truncated_completion_count} judge inputs to " + f"{truncate_input_chars} characters before evaluation." + ) + inputs = prompt_template.batch(input_payloads) + + if judge_tokenizer is not None and max_judge_model_len: + inputs = _preflight_shrink_to_judge_budget( + prompt_template=prompt_template, + inputs=inputs, + input_payloads=input_payloads, + annotation_input_metadata=annotation_input_metadata, + case_ids=case_ids, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + limit_event_tracker=limit_event_tracker, + ) + logger.info("Start LLM judge annotation (%d annotations).", len(inputs)) judge_completions = do_inference( chat_model=judge_chat_model, inputs=inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) annotations = [] - for judge_input, judge_completion, instruction, completion_A, completion_B in zip( + for ( + judge_input, + judge_completion, + instruction, + completion_A, + completion_B, + annotation_input_metadata_row, + ) in zip( inputs, judge_completions, instructions, completions_A, completions_B, + annotation_input_metadata, strict=True, ): annotations.append( @@ -350,6 +517,7 @@ def annotate_battles( instruction=instruction, completion_A=completion_A, completion_B=completion_B, + **annotation_input_metadata_row, ) ) return annotations @@ -360,12 +528,23 @@ def judge_and_parse_prefs( instructions: list[str], completions_A: list[str], completions_B: list[str], + case_ids: list[object] | None = None, swap_mode: str = "fixed", provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + parser_mode: str = "score", + strip_thinking_before_judging: bool = False, system_prompt: str | None = None, user_prompt_template: str | None = None, truncate_input_chars: int = 8192, use_tqdm: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, + judge_tokenizer: "PreTrainedTokenizerBase | None" = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, ) -> tuple[list[JudgeAnnotation], list[JudgeAnnotation] | None, pd.Series]: """Run judge annotation and parse preferences, handling swap_mode='both'. 
@@ -389,11 +568,21 @@ def judge_and_parse_prefs( instructions=instructions, completions_A=completions_A, completions_B=completions_B, + case_ids=case_ids, provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, system_prompt=system_prompt, user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, + limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, ) annotations_reversed = None @@ -403,17 +592,27 @@ def judge_and_parse_prefs( instructions=instructions, completions_A=completions_B, completions_B=completions_A, + case_ids=case_ids, provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, system_prompt=system_prompt, user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, + limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, ) def _none_to_nan(x): return float("nan") if x is None else x - score_parser = PairScore() + score_parser = PairScore(parser_mode=parser_mode) prefs = pd.Series( [score_parser.parse_model_raw(a.judge_completion) for a in annotations] ) @@ -429,3 +628,161 @@ def _none_to_nan(x): prefs = pd.concat([prefs, (1 - prefs_reversed)]).reset_index(drop=True) return annotations, annotations_reversed, prefs + + +_LC_ROLE_MAP = {"human": "user", "ai": "assistant", "system": "system"} + + +def _count_chat_tokens(prompt_value: Any, tokenizer: Any) -> int: + """Count tokens the way vLLM's ``llm.chat()`` tokenizes after applying the + tokenizer's chat template. Falls back to raw-string encoding for tokenizers + without a chat template or if template application raises.""" + if hasattr(prompt_value, "to_messages"): + messages = [ + { + "role": _LC_ROLE_MAP.get(msg.type, msg.type), + "content": msg.content, + } + for msg in prompt_value.to_messages() + ] + try: + return len(tokenizer.apply_chat_template(messages, tokenize=True)) + except Exception: + pass + if hasattr(prompt_value, "to_string"): + text = prompt_value.to_string() + else: + text = str(prompt_value) + return len(tokenizer.encode(text)) + + +def _find_token_overflows( + inputs: list[Any], tokenizer: Any, safe_budget: int +) -> list[tuple[int, int]]: + """Return ``(index, token_count)`` for inputs whose tokenized length exceeds + ``safe_budget``.""" + overflows: list[tuple[int, int]] = [] + for idx, item in enumerate(inputs): + token_count = _count_chat_tokens(item, tokenizer) + if token_count > safe_budget: + overflows.append((idx, token_count)) + return overflows + + +def _chars_per_token(text: str, tokenizer: Any) -> float: + """Return a conservative char-to-token ratio for ``text``, floored at 1.0. + + Short/empty inputs yield a low ratio, which under-truncates rather than + overflowing - the safe direction for the preflight shrink loop. 
+ """ + text = text if isinstance(text, str) else "" + if not text: + return 1.0 + token_count = max(1, len(tokenizer.encode(text))) + return max(1.0, len(text) / token_count) + + +def _render_with_empty_completions( + prompt_template: ChatPromptTemplate, user_prompt: str +) -> Any: + """Render the prompt template with empty completions so the fixed template + + user-prompt overhead can be measured per case. ``ChatPromptTemplate`` + uses ``str.format()`` on each message, so empty strings substitute cleanly + for both completion slots.""" + return prompt_template.invoke( + { + "user_prompt": user_prompt, + "completion_A": "", + "completion_B": "", + } + ) + + +def _preflight_shrink_to_judge_budget( + *, + prompt_template: ChatPromptTemplate, + inputs: list[Any], + input_payloads: list[dict[str, str]], + annotation_input_metadata: list[dict[str, object]], + case_ids: list[object], + judge_tokenizer: Any, + max_judge_model_len: int, + max_out_tokens_judge: int | None, + limit_event_tracker: LimitEventTracker | None, +) -> list[Any]: + """Bounded shrink-and-re-render loop that converts judge-context overflows + into soft ``judge_input_token_truncation`` limit events instead of a hard + ``VLLMValidationError`` at request time. + + The per-completion budget subtracts the case-specific template + user-prompt + overhead so that one iteration typically suffices; the 3-iteration bound is + a genuine safety net for the rare pathological case where the char-to-token + ratio shifts after truncation (e.g. dropping multi-byte glyphs). + """ + safe_budget = ( + max_judge_model_len - (max_out_tokens_judge or 0) - _PREFLIGHT_RESERVED_TOKENS + ) + for _ in range(_PREFLIGHT_MAX_ITERATIONS): + overflows = _find_token_overflows(inputs, judge_tokenizer, safe_budget) + if not overflows: + return inputs + for idx, _token_count in overflows: + payload = input_payloads[idx] + fixed_tokens = _count_chat_tokens( + _render_with_empty_completions(prompt_template, payload["user_prompt"]), + judge_tokenizer, + ) + per_completion_budget = max(256, (safe_budget - fixed_tokens) // 2) + ratio_A = _chars_per_token(payload["completion_A"], judge_tokenizer) + ratio_B = _chars_per_token(payload["completion_B"], judge_tokenizer) + new_cap_A = max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int(per_completion_budget * ratio_A * 0.9), + ) + new_cap_B = max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int(per_completion_budget * ratio_B * 0.9), + ) + payload["completion_A"], shrunk_A = truncate_with_metadata( + payload["completion_A"], + max_len=new_cap_A, + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field="completion_A", + case_id=case_ids[idx], + ) + payload["completion_B"], shrunk_B = truncate_with_metadata( + payload["completion_B"], + max_len=new_cap_B, + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field="completion_B", + case_id=case_ids[idx], + ) + metadata_row = annotation_input_metadata[idx] + metadata_row["completion_A_for_judge"] = payload["completion_A"] + metadata_row["completion_B_for_judge"] = payload["completion_B"] + if shrunk_A: + metadata_row["completion_A_truncated_for_judge"] = True + if shrunk_B: + metadata_row["completion_B_truncated_for_judge"] = True + inputs = prompt_template.batch(input_payloads) + + final_overflows = _find_token_overflows(inputs, judge_tokenizer, safe_budget) + for idx, token_count in final_overflows: + if limit_event_tracker is not None: + limit_event_tracker.record( + 
"judge_input_token_truncation_failed", + stage="judge_input", + case_id=case_ids[idx], + original_length=token_count, + final_length=safe_budget, + note=( + f"{_PREFLIGHT_MAX_ITERATIONS} shrink iterations did not " + f"bring tokens under {safe_budget}; falling through to " + "vLLM validation." + ), + ) + return inputs diff --git a/judgearena/generate.py b/judgearena/generate.py index 5fe1666..d33c865 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -1,13 +1,54 @@ +from typing import Any + import pandas as pd from langchain_core.prompts import ChatPromptTemplate from judgearena.utils import ( + LimitEventTracker, do_inference, make_model, - truncate, + strip_thinking_tags_with_metadata, + truncate_with_metadata, ) +def _record_generation_output_limit_events( + *, + metadata: list[dict[str, Any]], + case_ids: list[object], + field: str, + model_spec: str, + limit_event_tracker: LimitEventTracker | None, +) -> list[bool]: + hit_token_limit: list[bool] = [] + for case_id, metadata_row in zip(case_ids, metadata, strict=True): + row = metadata_row or {} + finish_reason = str(row.get("finish_reason") or "").lower() + reached_limit = finish_reason == "length" + hit_token_limit.append(reached_limit) + if limit_event_tracker is None: + continue + if reached_limit: + limit_event_tracker.record( + "generation_output_token_limit", + stage="generation_output", + field=field, + case_id=case_id, + model_spec=model_spec, + note=finish_reason, + ) + if row.get("thinking_budget_exhausted"): + limit_event_tracker.record( + "generation_thinking_token_budget", + stage="generation_output", + field=field, + case_id=case_id, + model_spec=model_spec, + note=str(row.get("thinking_token_budget")), + ) + return hit_token_limit + + def generate_instructions( instructions: pd.Series, model: str, @@ -15,9 +56,19 @@ def generate_instructions( max_tokens: int | None = 32768, use_tqdm: bool = True, system_prompt: str | None = None, + usage_tracker=None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, **engine_kwargs, ) -> pd.DataFrame: - chat_model = make_model(model, max_tokens=max_tokens, **engine_kwargs) + chat_model = make_model( + model, + max_tokens=max_tokens, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **engine_kwargs, + ) # TODO improve prompt to generate instructions if system_prompt is None: @@ -28,24 +79,50 @@ def generate_instructions( [("system", system_prompt), ("user", "{user_prompt}")] ) - inputs = prompt_template.batch( - [ - { - "user_prompt": truncate(user_prompt, max_len=truncate_input_chars), - } - for user_prompt in instructions - ] - ) + prompt_truncated: list[bool] = [] + input_payloads = [] + case_ids = instructions.index.tolist() + for case_id, user_prompt in zip(case_ids, instructions, strict=True): + truncated_user_prompt, was_truncated = truncate_with_metadata( + user_prompt, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="user_prompt", + case_id=case_id, + model_spec=model, + ) + prompt_truncated.append(was_truncated) + input_payloads.append({"user_prompt": truncated_user_prompt}) + inputs = prompt_template.batch(input_payloads) - completions = do_inference( + completions, completion_metadata = do_inference( chat_model=chat_model, inputs=inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, + 
return_metadata=True, + ) + hit_token_limit = _record_generation_output_limit_events( + metadata=completion_metadata, + case_ids=case_ids, + field="completion", + model_spec=model, + limit_event_tracker=limit_event_tracker, ) df_outputs = pd.DataFrame( data={ "completion": completions, - "instruction_index": instructions.index.tolist(), + "instruction_index": case_ids, + "generation_prompt_truncated": prompt_truncated, + "generation_output_finish_reason": [ + metadata_row.get("finish_reason") + for metadata_row in completion_metadata + ], + "generation_output_hit_token_limit": hit_token_limit, }, ) return df_outputs @@ -69,8 +146,11 @@ def _infer_grouped_by_temperature( inputs: list, temperatures: list[float], use_tqdm: bool, -) -> list[str]: + usage_tracker=None, + usage_phase: str | None = None, +) -> tuple[list[str], list[dict[str, Any]]]: outputs: list[str] = [""] * len(inputs) + outputs_metadata: list[dict[str, Any]] = [{} for _ in inputs] groups: dict[float, list[int]] = {} for idx, temp in enumerate(temperatures): groups.setdefault(float(temp), []).append(idx) @@ -87,15 +167,20 @@ def _infer_grouped_by_temperature( model_spec, max_tokens=max_tokens, temperature=temp, **model_kwargs ) - group_outs = do_inference( + group_outs, group_metadata = do_inference( chat_model=group_model, inputs=group_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model_spec, + return_metadata=True, ) - for i, out in zip(idxs, group_outs, strict=True): + for i, out, metadata_row in zip(idxs, group_outs, group_metadata, strict=True): outputs[i] = out + outputs_metadata[i] = metadata_row - return outputs + return outputs, outputs_metadata def generate_multiturn( @@ -105,6 +190,10 @@ def generate_multiturn( max_tokens: int | None = 8192, use_tqdm: bool = True, temperature_config: dict[str, float] | None = None, + usage_tracker=None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, + strip_thinking_before_turn_2_prompt: bool = False, **model_kwargs, ) -> pd.DataFrame: """Generate two-turn completions for MT-Bench style questions.""" @@ -114,10 +203,23 @@ def generate_multiturn( if use_category_temperatures and local_provider: chat_model = make_model( - model, max_tokens=max_tokens, temperature=0.0, **model_kwargs + model, + max_tokens=max_tokens, + temperature=0.0, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **model_kwargs, ) else: - chat_model = make_model(model, max_tokens=max_tokens, **model_kwargs) + chat_model = make_model( + model, + max_tokens=max_tokens, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **model_kwargs, + ) system_prompt = "You are a helpful assistant." 
idxs = questions.index.tolist() @@ -131,15 +233,25 @@ def generate_multiturn( turn1_template = ChatPromptTemplate.from_messages( [("system", system_prompt), ("user", "{user_prompt}")] ) - turn1_inputs = turn1_template.batch( - [ - {"user_prompt": truncate(row["turn_1"], max_len=truncate_input_chars)} - for _, row in questions.iterrows() - ] - ) + turn1_prompt_truncated: list[bool] = [] + turn1_payloads = [] + for question_id, row in questions.iterrows(): + truncated_turn_1, was_truncated = truncate_with_metadata( + row["turn_1"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1", + case_id=question_id, + model_spec=model, + ) + turn1_prompt_truncated.append(was_truncated) + turn1_payloads.append({"user_prompt": truncated_turn_1}) + turn1_inputs = turn1_template.batch(turn1_payloads) if use_category_temperatures: - completions_turn_1 = _infer_grouped_by_temperature( + completions_turn_1, turn1_metadata = _infer_grouped_by_temperature( model_spec=model, provider=provider, max_tokens=max_tokens, @@ -148,19 +260,40 @@ def generate_multiturn( inputs=turn1_inputs, temperatures=temperatures, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, ) else: - completions_turn_1 = do_inference( + completions_turn_1, turn1_metadata = do_inference( chat_model=chat_model, inputs=turn1_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, + return_metadata=True, ) + turn1_hit_token_limit = _record_generation_output_limit_events( + metadata=turn1_metadata, + case_ids=idxs, + field="completion_turn_1", + model_spec=model, + limit_event_tracker=limit_event_tracker, + ) turn2_inputs = [] - for (_, row), t1_answer in zip( + turn2_turn1_truncated: list[bool] = [] + turn2_answer_truncated: list[bool] = [] + turn2_prompt_truncated: list[bool] = [] + turn2_turn1_answer_thinking_stripped: list[bool] = [] + for (question_id, row), t1_answer in zip( questions.iterrows(), completions_turn_1, strict=True ): if row["turn_2"] is None: + turn2_turn1_truncated.append(False) + turn2_answer_truncated.append(False) + turn2_prompt_truncated.append(False) + turn2_turn1_answer_thinking_stripped.append(False) turn2_inputs.append( turn1_template.invoke({"user_prompt": "No follow-up question."}) ) @@ -173,20 +306,66 @@ def generate_multiturn( ("user", "{turn_2}"), ] ) + truncated_turn_1, turn1_was_truncated = truncate_with_metadata( + row["turn_1"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1_for_turn_2", + case_id=question_id, + model_spec=model, + ) + # Strip ... from the turn-1 answer before the + # character cap fires. Mirrors what the Qwen3 chat template does + # natively for historical assistant turns; applying it here + # ensures a 30K-char cap lands on the visible answer rather than + # deep inside a runaway reasoning block, which would silently + # destroy the closer and force the whole thinking + # fragment into the turn-2 prompt. 
+ t1_answer_str = str(t1_answer) + if strip_thinking_before_turn_2_prompt: + t1_answer_str, thinking_stripped = strip_thinking_tags_with_metadata( + t1_answer_str + ) + else: + thinking_stripped = False + turn2_turn1_answer_thinking_stripped.append(thinking_stripped) + truncated_turn_1_answer, answer_was_truncated = truncate_with_metadata( + t1_answer_str, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1_answer", + case_id=question_id, + model_spec=model, + ) + truncated_turn_2, turn2_was_truncated = truncate_with_metadata( + row["turn_2"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_2", + case_id=question_id, + model_spec=model, + ) + turn2_turn1_truncated.append(turn1_was_truncated) + turn2_answer_truncated.append(answer_was_truncated) + turn2_prompt_truncated.append(turn2_was_truncated) turn2_inputs.append( multi_turn_template.invoke( { - "turn_1": truncate(row["turn_1"], max_len=truncate_input_chars), - "turn_1_answer": truncate( - str(t1_answer), max_len=truncate_input_chars - ), - "turn_2": truncate(row["turn_2"], max_len=truncate_input_chars), + "turn_1": truncated_turn_1, + "turn_1_answer": truncated_turn_1_answer, + "turn_2": truncated_turn_2, } ) ) if use_category_temperatures: - completions_turn_2 = _infer_grouped_by_temperature( + completions_turn_2, turn2_metadata = _infer_grouped_by_temperature( model_spec=model, provider=provider, max_tokens=max_tokens, @@ -195,19 +374,47 @@ def generate_multiturn( inputs=turn2_inputs, temperatures=temperatures, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, ) else: - completions_turn_2 = do_inference( + completions_turn_2, turn2_metadata = do_inference( chat_model=chat_model, inputs=turn2_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, + return_metadata=True, ) + turn2_hit_token_limit = _record_generation_output_limit_events( + metadata=turn2_metadata, + case_ids=idxs, + field="completion_turn_2", + model_spec=model, + limit_event_tracker=limit_event_tracker, + ) return pd.DataFrame( data={ "instruction_index": idxs, "completion_turn_1": completions_turn_1, "completion_turn_2": completions_turn_2, + "generation_turn_1_prompt_truncated": turn1_prompt_truncated, + "generation_turn_1_finish_reason": [ + metadata_row.get("finish_reason") for metadata_row in turn1_metadata + ], + "generation_turn_1_hit_token_limit": turn1_hit_token_limit, + "generation_turn_2_turn_1_prompt_truncated": turn2_turn1_truncated, + "generation_turn_2_turn_1_answer_truncated": turn2_answer_truncated, + "generation_turn_2_turn_1_answer_thinking_stripped": ( + turn2_turn1_answer_thinking_stripped + ), + "generation_turn_2_prompt_truncated": turn2_prompt_truncated, + "generation_turn_2_finish_reason": [ + metadata_row.get("finish_reason") for metadata_row in turn2_metadata + ], + "generation_turn_2_hit_token_limit": turn2_hit_token_limit, }, ) @@ -218,25 +425,65 @@ def generate_base( truncate_input_chars: int | None = 8192, max_tokens: int | None = 32768, use_tqdm: bool = False, + usage_tracker=None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, **engine_kwargs, ) -> pd.DataFrame: - model = make_model(model, max_tokens=max_tokens, **engine_kwargs) + model_spec = model + model = make_model( + model_spec, + max_tokens=max_tokens, + 
limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model_spec, + **engine_kwargs, + ) - inputs = [ - truncate(instruction, max_len=truncate_input_chars) - for instruction in instructions - ] + prompt_truncated: list[bool] = [] + case_ids = instructions.index.tolist() + inputs = [] + for case_id, instruction in zip(case_ids, instructions, strict=True): + truncated_instruction, was_truncated = truncate_with_metadata( + instruction, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="instruction", + case_id=case_id, + model_spec=model_spec, + ) + prompt_truncated.append(was_truncated) + inputs.append(truncated_instruction) - completions = model.batch( + completions, completion_metadata = do_inference( + chat_model=model, inputs=inputs, - max_tokens=max_tokens, + use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model_spec, + return_metadata=True, + ) + hit_token_limit = _record_generation_output_limit_events( + metadata=completion_metadata, + case_ids=case_ids, + field="completion", + model_spec=model_spec, + limit_event_tracker=limit_event_tracker, ) - completions = [x.content if hasattr(x, "content") else x for x in completions] df_outputs = pd.DataFrame( data={ "completion": completions, - "instruction_index": instructions.index.tolist(), + "instruction_index": case_ids, + "generation_prompt_truncated": prompt_truncated, + "generation_output_finish_reason": [ + metadata_row.get("finish_reason") + for metadata_row in completion_metadata + ], + "generation_output_hit_token_limit": hit_token_limit, }, ) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 2919280..5a0b49b 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -3,40 +3,74 @@ and then evaluates them using a judge model. 
""" +import argparse +import hashlib import json +from collections.abc import Mapping from dataclasses import asdict, dataclass from datetime import UTC, datetime -from functools import partial from pathlib import Path import pandas as pd -from judgearena.cli_common import BaseCliArgs +from judgearena.cli_common import ( + BaseCliArgs, + add_common_arguments, + parse_engine_kwargs, + parse_optional_bool, + resolve_verbosity, +) from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions from judgearena.instruction_dataset.arena_hard import ( + ARENA_HARD_BASELINES, download_arena_hard, is_arena_hard_dataset, ) +from judgearena.instruction_dataset.m_arenahard import ( + M_ARENA_HARD_BASELINES, + split_m_arena_hard_dataset, +) +from judgearena.instruction_dataset.mt_bench import MT_BENCH_BASELINES +from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.log import ( attach_file_handler, get_logger, make_run_log_path, ) from judgearena.mt_bench.mt_bench_utils import run_mt_bench +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + LimitEventTracker, + build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, data_root, download_hf, + is_thinking_model, make_model, read_df, ) logger = get_logger(__name__) +ALPACA_EVAL_BASELINES: dict[str, str] = { + "alpaca-eval": "gpt4_1106_preview", +} + +PAIRWISE_BASELINES: dict[str, str | Mapping[str, str]] = { + **ALPACA_EVAL_BASELINES, + **ARENA_HARD_BASELINES, + **M_ARENA_HARD_BASELINES, + **MT_BENCH_BASELINES, +} + def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None @@ -89,12 +123,223 @@ class CliArgs(BaseCliArgs): model_B: str | None = None use_tqdm: bool = False + @classmethod + def parse_args(cls): + parser = argparse.ArgumentParser( + prog="Generate completion and evaluate with a judge", + ) + parser.add_argument( + "--dataset", + help="The dataset to use. For instance `alpaca-eval`, `arena-hard-v2.0`, " + "`arena-hard-v0.1`, `m-arena-hard-v0.1-EU`, `m-arena-hard-v2.0-uk` for " + "instruction tuning cases or `french-contexts`, `spanish-contexts` for " + "base models.", + ) + parser.add_argument( + "--model_A", + required=True, + help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`", + ) + parser.add_argument( + "--model_B", + default=None, + help=( + "Name of the baseline LLM for a generation. Optional for Arena-Hard " + "datasets (which ship a dataset-native default per category; see " + "`ARENA_HARD_BASELINES`) and MT-Bench (see `MT_BENCH_BASELINES`, " + "defaults to `gpt-4`). Required for every other dataset." 
+ ), + ) + parser.add_argument( + "--use_tqdm", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, + help="If specified, use tqdm, does not work with all model providers, vLLM in particular.", + ) + add_common_arguments(parser) + args = parser.parse_args() + + return cls( + task=args.dataset, + model_A=args.model_A, + model_B=args.model_B, + use_tqdm=args.use_tqdm, + judge_model=args.judge_model, + n_instructions=args.n_instructions, + provide_explanation=args.provide_explanation, + swap_mode=args.swap_mode, + ignore_cache=args.ignore_cache, + judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, + battle_thinking_token_budget=args.battle_thinking_token_budget, + strip_thinking_before_judging=args.strip_thinking_before_judging, + skip_judging=args.skip_judging, + truncate_all_input_chars=args.truncate_all_input_chars, + truncate_judge_input_chars=args.truncate_judge_input_chars, + max_out_tokens_models=args.max_out_tokens_models, + max_out_tokens_judge=args.max_out_tokens_judge, + max_model_len=args.max_model_len, + max_judge_model_len=args.max_judge_model_len, + chat_template=args.chat_template, + result_folder=args.result_folder, + engine_kwargs=parse_engine_kwargs(args.engine_kwargs), + judge_engine_kwargs=parse_engine_kwargs(args.judge_engine_kwargs), + verbosity=resolve_verbosity(args), + log_file=args.log_file, + no_log_file=args.no_log_file, + ) + + +@dataclass(frozen=True) +class BaselinePlan: + """Row-aligned baseline assignment for `--model_B`. + + Mirrors upstream's `JUDGE_SETTINGS[question["category"]]["baseline"]` lookup + in `arena-hard-auto/gen_judgment.py`: a flat plan assigns one baseline to + every row, a per-row plan assigns a different baseline per category (v2.0 + mixes `o3-mini-2025-01-31` on hard prompts with `gemini-2.0-flash-001` on + creative writing). + """ + + baseline_by_index: pd.Series + + @classmethod + def flat(cls, model: str, *, index: pd.Index) -> "BaselinePlan": + return cls( + baseline_by_index=pd.Series(model, index=index, name="model_B", dtype=str) + ) + + @classmethod + def per_row(cls, series: pd.Series) -> "BaselinePlan": + return cls(baseline_by_index=series.astype(str).rename("model_B")) + + @property + def unique_models(self) -> list[str]: + return sorted(self.baseline_by_index.dropna().unique().tolist()) + + @property + def is_flat(self) -> bool: + return len(self.unique_models) == 1 + + @property + def single_model(self) -> str: + if not self.is_flat: + raise ValueError( + "BaselinePlan is per-row; use baseline_by_index for row-level lookups" + ) + return self.unique_models[0] + + @property + def display_name(self) -> str: + return self.single_model if self.is_flat else "+".join(self.unique_models) + + def aligned_to(self, index: pd.Index) -> pd.Series: + return self.baseline_by_index.loc[index] + + +def _resolve_baseline_plan( + args: CliArgs, instructions_df: pd.DataFrame +) -> BaselinePlan: + """Explicit `--model_B` wins; otherwise fall back to the dataset-native + assignment. Non-arena-hard datasets without an override raise. + """ + if args.model_B is not None: + return BaselinePlan.flat(args.model_B, index=instructions_df.index) + native = native_pairwise_baseline(args.task) + if native is None: + raise ValueError( + f"--model_B is required for dataset '{args.task}'; no dataset-native " + "baseline is registered." 
+ ) + if isinstance(native, str): + return BaselinePlan.flat(native, index=instructions_df.index) + if isinstance(native, Mapping): + if "category" not in instructions_df.columns: + raise ValueError( + f"{args.task} requires a 'category' column for per-category " + "baseline routing; re-run dataset download to regenerate the " + "instructions table." + ) + per_row = instructions_df["category"].map(native) + if per_row.isna().any(): + unknown = sorted( + instructions_df.loc[per_row.isna(), "category"].unique().tolist() + ) + raise ValueError( + f"Unknown Arena-Hard categories for {args.task}: {unknown}. " + f"Known: {sorted(native.keys())}" + ) + return BaselinePlan.per_row(per_row) + raise ValueError(f"Unsupported baseline shape for dataset '{args.task}'.") + + +def native_pairwise_baseline(task: str) -> str | Mapping[str, str] | None: + """Return the dataset-native pairwise baseline, if the task defines one.""" + if task in PAIRWISE_BASELINES: + return PAIRWISE_BASELINES[task] + parsed_m_arena_hard = split_m_arena_hard_dataset(task) + if parsed_m_arena_hard is not None: + version_key, _lang_or_subset = parsed_m_arena_hard + return PAIRWISE_BASELINES[version_key] + return None + def load_contexts(dataset: str) -> pd.Series: path = data_root / "contexts" / dataset return pd.read_csv(path).loc[:, "instruction"] +def _build_generation_model_kwargs( + *, args: CliArgs, model_spec: str +) -> dict[str, object]: + generation_model_kwargs = dict(args.engine_kwargs) + provider, _, model_name = model_spec.partition("/") + if ( + args.battle_thinking_token_budget is not None + and provider == "VLLM" + and is_thinking_model(model_name) + ): + generation_model_kwargs["thinking_token_budget"] = min( + int(args.battle_thinking_token_budget), + int(args.max_out_tokens_models), + ) + return generation_model_kwargs + + +def _build_judge_model_kwargs( + *, args: CliArgs, limit_event_tracker: LimitEventTracker | None +) -> dict[str, object]: + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, + args.engine_kwargs, + judge_engine_kwargs_override=args.judge_engine_kwargs, + ) + if limit_event_tracker is not None: + judge_model_kwargs["limit_event_tracker"] = limit_event_tracker + judge_model_kwargs["limit_event_stage"] = "judge_model_init" + judge_model_kwargs["limit_event_model_spec"] = args.judge_model + return judge_model_kwargs + + +def _generation_cache_name(args: CliArgs, *, model_spec: str) -> str: + generation_config = { + "truncate_all_input_chars": args.truncate_all_input_chars, + "max_out_tokens_models": args.max_out_tokens_models, + "max_model_len": args.max_model_len, + "chat_template": args.chat_template, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "engine_kwargs": _build_generation_model_kwargs( + args=args, model_spec=model_spec + ), + } + generation_config_hash = hashlib.sha256( + json.dumps(generation_config, sort_keys=True, default=str).encode("utf-8") + ).hexdigest()[:12] + return f"{args.task}_{model_spec}_{args.n_instructions}_{generation_config_hash}" + + def print_results(results): """Print battle results in a nice formatted way""" @@ -127,24 +372,8 @@ def main(args: CliArgs): """ run_started_at = datetime.now(UTC) - - # Build the result folder early so the file handler captures the entire run. - # Include a timestamp so each run gets its own unique directory. 
- name = f"{args.task}-{args.model_A}-{args.model_B}-{args.judge_model}" - name += f"-{args.swap_mode}" - name = name.replace("/", "_") - run_ts = run_started_at.strftime("%Y%m%d_%H%M%S") - res_folder = Path(args.result_folder) / f"{name}-{run_ts}" - res_folder.mkdir(parents=True, exist_ok=True) - if not args.no_log_file: - attach_file_handler(make_run_log_path(res_folder)) - - logger.info( - "Using task %s and evaluating models %s and %s.", - args.task, - args.model_A, - args.model_B, - ) + usage_tracker = OpenRouterReferencePricingTracker() + limit_event_tracker = LimitEventTracker() # Not working with vllm, not detecting model changes and serving the same cache for two different models... # if not args.ignore_cache: @@ -152,6 +381,16 @@ def main(args: CliArgs): ignore_cache = args.ignore_cache if args.task == "mt-bench": + name = ( + f"{args.task}-{args.model_A}-{args.model_B or 'native'}-{args.judge_model}" + ) + name += f"-{args.swap_mode}" + name = name.replace("/", "_") + run_ts = run_started_at.strftime("%Y%m%d_%H%M%S") + res_folder = Path(args.result_folder) / f"{name}-{run_ts}" + res_folder.mkdir(parents=True, exist_ok=True) + if not args.no_log_file: + attach_file_handler(make_run_log_path(res_folder)) return run_mt_bench( args, ignore_cache, @@ -166,94 +405,145 @@ def main(args: CliArgs): # to match files in https://huggingface.co/datasets/geoalgo/multilingual-contexts-to-be-completed lang = args.task.split("-")[-1] instructions = load_contexts(f"{lang}-contexts.csv") + instructions_df = pd.DataFrame({"instruction": instructions.values}) + instructions_df.index = instructions.index else: - instructions = load_instructions( + instructions_df = load_instructions( dataset=args.task, n_instructions=args.n_instructions - ).loc[:, "instruction"] + ) + instructions = instructions_df["instruction"] n_instructions = args.n_instructions if args.n_instructions else len(instructions) if args.n_instructions is not None: - instructions = instructions[:n_instructions] + instructions_df = instructions_df.head(n_instructions) + instructions = instructions.head(n_instructions) + + baseline_plan = _resolve_baseline_plan(args=args, instructions_df=instructions_df) + name = f"{args.task}-{args.model_A}-{baseline_plan.display_name}-{args.judge_model}" + name += f"-{args.swap_mode}" + name = name.replace("/", "_") + run_ts = run_started_at.strftime("%Y%m%d_%H%M%S") + res_folder = Path(args.result_folder) / f"{name}-{run_ts}" + res_folder.mkdir(parents=True, exist_ok=True) + if not args.no_log_file: + attach_file_handler(make_run_log_path(res_folder)) + + logger.info( + "Using task %s and evaluating %s vs baseline %s.", + args.task, + args.model_A, + baseline_plan.display_name, + ) logger.info( - "Generating completions for task %s with model %s and %s " + "Generating completions for task %s with model %s and baseline %s " "(or loading them directly if present)", args.task, args.model_A, - args.model_B, + baseline_plan.display_name, ) - # TODO currently we just support base models for fluency, we could also support instruction-tuned models - gen_fun = ( - partial( - generate_base, - truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - use_tqdm=args.use_tqdm, - **args.engine_kwargs, - ) - if is_fluency_task - else partial( - generate_instructions, + generation_function = generate_base if is_fluency_task else generate_instructions + + def _run_generation(model_spec: str, usage_phase: str) -> 
pd.DataFrame: + return generation_function( + instructions=instructions, + model=model_spec, truncate_input_chars=args.truncate_all_input_chars, max_tokens=args.max_out_tokens_models, max_model_len=args.max_model_len, chat_template=args.chat_template, use_tqdm=args.use_tqdm, - **args.engine_kwargs, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + limit_event_tracker=limit_event_tracker, + **_build_generation_model_kwargs(args=args, model_spec=model_spec), ) - ) - dataset_completions_A = try_load_dataset_completions( - args.task, args.model_A, n_instructions - ) - if dataset_completions_A is not None: - completions_A = dataset_completions_A.set_index("instruction_index").loc[ - :, "completion" - ] - else: - completions_A = cache_function_dataframe( - lambda: gen_fun( - instructions=instructions, - model=args.model_A, - use_tqdm=args.use_tqdm, - ), - ignore_cache=ignore_cache, - cache_name=f"{args.task}_{args.model_A}_{args.n_instructions}", - ).set_index("instruction_index") - completions_A = completions_A.loc[:, "completion"] - dataset_completions_B = try_load_dataset_completions( - args.task, args.model_B, n_instructions - ) - if dataset_completions_B is not None: - completions_B = dataset_completions_B.set_index("instruction_index").loc[ - :, "completion" + def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: + return df.set_index("instruction_index").loc[instructions.index].reset_index() + + def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Series: + preloaded = try_load_dataset_completions(args.task, model_spec, n_instructions) + if preloaded is not None: + aligned = _align_completion_dataframe(preloaded) + else: + aligned = _align_completion_dataframe( + cache_function_dataframe( + lambda: _run_generation(model_spec, usage_phase), + ignore_cache=ignore_cache, + cache_name=_generation_cache_name(args, model_spec=model_spec), + ) + ) + return aligned.set_index("instruction_index").loc[ + instructions.index, "completion" ] + + completions_A = _load_or_generate_completions(args.model_A, "generation_model_A") + + baseline_per_index = baseline_plan.aligned_to(instructions.index) + if baseline_plan.is_flat: + completions_B = _load_or_generate_completions( + baseline_plan.single_model, "generation_model_B" + ) else: - completions_B = cache_function_dataframe( - lambda: gen_fun( - instructions=instructions, - model=args.model_B, - use_tqdm=args.use_tqdm, - ), - ignore_cache=ignore_cache, - cache_name=f"{args.task}_{args.model_B}_{args.n_instructions}", - ).set_index("instruction_index") - completions_B = completions_B.loc[:, "completion"] + # Per-row plan: fetch one completion set per unique baseline, then stitch + # them together so completions_B[uid] uses the baseline that + # ARENA_HARD_BASELINES routes uid's category to. 
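+        # Illustrative sketch (hypothetical uids) of the v2.0 routing defined in
+        # ARENA_HARD_BASELINES:
+        #   uid "q-001", category "hard_prompt"      -> o3-mini-2025-01-31
+        #   uid "q-002", category "creative_writing" -> gemini-2.0-flash-001
+        # The stitched completions_B therefore mixes both baselines' outputs
+        # while staying row-aligned with instructions.index.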
+ per_baseline_completions: dict[str, pd.Series] = {} + for baseline_model in baseline_plan.unique_models: + per_baseline_completions[baseline_model] = _load_or_generate_completions( + baseline_model, f"generation_model_B::{baseline_model}" + ) + completions_B = pd.Series( + [ + per_baseline_completions[model].loc[uid] + for uid, model in baseline_per_index.items() + ], + index=instructions.index, + name="completion", + ) + logger.debug("First instruction/context: %s", instructions.values[0]) logger.debug("First completion of %s:\n%s", args.model_A, completions_A.values[0]) - logger.debug("First completion of %s:\n%s", args.model_B, completions_B.values[0]) + logger.debug( + "First completion of %s:\n%s", + baseline_plan.display_name, + completions_B.values[0], + ) + if args.skip_judging: + with open(res_folder / f"args-{name}.json", "w") as f: + json.dump(asdict(args), f, indent=2) + generation_summary = { + "task": args.task, + "model_A": args.model_A, + "model_B": baseline_plan.display_name, + "baseline_assignment": "per-row" if not baseline_plan.is_flat else "flat", + "baseline_models": baseline_plan.unique_models, + "judge_model": args.judge_model, + "n_instructions": n_instructions, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "limit_events": limit_event_tracker.build_summary(), + "skip_judging": True, + } + with open(res_folder / f"gen-results-{name}.json", "w") as f: + json.dump(_to_jsonable(generation_summary), f, indent=2, allow_nan=False) + logger.info( + "skip_judging=True: wrote gen-results-%s.json and returning before judge construction.", + name, + ) + return None logger.info("Evaluating completions with judge %s.", args.judge_model) judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, - max_model_len=args.max_model_len, + max_model_len=args.max_judge_model_len, chat_template=args.chat_template, - **args.engine_kwargs, + **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) + judge_tokenizer = getattr(judge_chat_model, "tokenizer", None) # save argument for results analysis with open(res_folder / f"args-{name}.json", "w") as f: @@ -269,11 +559,9 @@ def main(args: CliArgs): else: # the default system prompt of annotate is to compare instruction tuned models. 
system_prompt = None - ( - effective_judge_system_prompt, - judge_user_prompt_template, - ) = resolve_judge_prompts( + resolved_prompt = resolve_judge_prompts( provide_explanation=args.provide_explanation, + prompt_preset=args.judge_prompt_preset or DEFAULT_JUDGE_PROMPT_PRESET, system_prompt=system_prompt, ) @@ -282,50 +570,85 @@ def main(args: CliArgs): instructions=instructions.head(n_instructions).tolist(), completions_A=completions_A.head(n_instructions).tolist(), completions_B=completions_B.head(n_instructions).tolist(), + case_ids=instructions.head(n_instructions).index.tolist(), swap_mode=args.swap_mode, provide_explanation=args.provide_explanation, - system_prompt=effective_judge_system_prompt, - user_prompt_template=judge_user_prompt_template, - truncate_input_chars=args.truncate_all_input_chars, + prompt_preset=resolved_prompt.preset_name, + parser_mode=resolved_prompt.parser_mode, + strip_thinking_before_judging=args.strip_thinking_before_judging, + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=resolved_prompt.user_prompt_template, + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=args.use_tqdm, + usage_tracker=usage_tracker, + usage_phase="judge", + usage_model_spec=args.judge_model, + limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=args.max_judge_model_len, + max_out_tokens_judge=args.max_out_tokens_judge, ) + eval_instruction_index = instructions.head(n_instructions).index.tolist() + baseline_per_eval = baseline_per_index.loc[eval_instruction_index] + df = pd.DataFrame(annotations) - df["instruction_index"] = instructions.head(n_instructions).index.tolist() + df["instruction_index"] = eval_instruction_index df["model_A"] = args.model_A - df["model_B"] = args.model_B + df["model_B"] = baseline_per_eval.tolist() df["judge"] = args.judge_model if args.swap_mode == "both": df_reversed = pd.DataFrame(annotations_reversed) - df_reversed["instruction_index"] = instructions.head( - n_instructions - ).index.tolist() - df_reversed["model_A"] = args.model_B + df_reversed["instruction_index"] = eval_instruction_index + df_reversed["model_A"] = baseline_per_eval.tolist() df_reversed["model_B"] = args.model_A df_reversed["judge"] = args.judge_model df = pd.concat([df, df_reversed]) df.to_csv(res_folder / f"{name}-annotations.csv", index=False) - # compute and report statistics summary = compute_pref_summary(prefs) results = { "task": args.task, "model_A": args.model_A, - "model_B": args.model_B, + "model_B": baseline_plan.display_name, + "baseline_assignment": "per-row" if not baseline_plan.is_flat else "flat", + "baseline_models": baseline_plan.unique_models, "judge_model": args.judge_model, + "judge_prompt_preset": resolved_prompt.preset_name, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, **summary, + "limit_events": limit_event_tracker.build_summary(), "preferences": prefs.tolist(), } - logger.info("%s vs %s judged by %s", args.model_A, args.model_B, args.judge_model) + logger.info( + "%s vs %s judged by %s", + args.model_A, + baseline_plan.display_name, + args.judge_model, + ) print_results(results) + phase_model_specs: dict[str, str] = { + "generation_model_A": args.model_A, + "judge": args.judge_model, + } + if baseline_plan.is_flat: + phase_model_specs["generation_model_B"] = baseline_plan.single_model + else: + for baseline_model in baseline_plan.unique_models: + 
phase_model_specs[f"generation_model_B::{baseline_model}"] = baseline_model + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs=phase_model_specs, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) with open(res_folder / f"results-{name}.json", "w") as f: json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) - eval_instruction_index = instructions.head(n_instructions).index.tolist() eval_instructions = instructions.head(n_instructions).tolist() eval_completions_A = completions_A.head(n_instructions).tolist() eval_completions_B = completions_B.head(n_instructions).tolist() @@ -341,10 +664,12 @@ def main(args: CliArgs): "instructions": eval_instructions, "completions_A": eval_completions_A, "completions_B": eval_completions_B, + "baseline_model_B": baseline_per_eval.tolist(), }, - judge_system_prompt=effective_judge_system_prompt, - judge_user_prompt_template=judge_user_prompt_template, + judge_system_prompt=resolved_prompt.system_prompt, + judge_user_prompt_template=resolved_prompt.user_prompt_template, started_at_utc=run_started_at, + pricing_reference=pricing_reference, ) except OSError as e: logger.warning("Failed to write run metadata: %s", e) diff --git a/judgearena/instruction_dataset/__init__.py b/judgearena/instruction_dataset/__init__.py index 5f5d9f7..7d3185c 100644 --- a/judgearena/instruction_dataset/__init__.py +++ b/judgearena/instruction_dataset/__init__.py @@ -4,9 +4,11 @@ download_arena_hard, is_arena_hard_dataset, ) -from judgearena.instruction_dataset.m_arenahard import load_m_arenahard +from judgearena.instruction_dataset.m_arenahard import ( + load_m_arenahard, + split_m_arena_hard_dataset, +) from judgearena.log import get_logger -from judgearena.utils import data_root, download_hf, read_df logger = get_logger(__name__) @@ -17,44 +19,18 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat df_instructions = load_mt_bench() - elif "m-arena-hard" in dataset: - if dataset == "m-arena-hard": - language = None - else: - # read the suffix part "m-arena-hard-EU" -> "EU" - language = dataset.split("-")[-1] - assert language in [ - None, - "ar", - "cs", - "de", - "el", - "en", - "es", - "fa", - "fr", - "he", - "hi", - "id", - "it", - "ja", - "ko", - "nl", - "pl", - "pt", - "ro", - "ru", - "tr", - "uk", - "vi", - "zh", - "EU", - ] + elif (parsed := split_m_arena_hard_dataset(dataset)) is not None: + from judgearena.utils import data_root + + version_key, lang_or_subset = parsed logger.info( - "Loading m-arena-hard with language specification set to %s", language + "Loading %s with language specification set to %s", + version_key, + lang_or_subset, + ) + df_instructions = load_m_arenahard( + local_path=data_root, version=version_key, language=lang_or_subset ) - df_instructions = load_m_arenahard(local_path=data_root, language=language) - # sort by question_id, then language so that we get multiple languages if we truncate df_instructions.sort_values(["question_id", "lang"], inplace=True) df_instructions.rename( @@ -72,6 +48,8 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat "arena-hard-v0.1", "arena-hard-v2.0", ] + from judgearena.utils import data_root, download_hf, read_df + local_path_tables = data_root / "tables" if is_arena_hard_dataset(dataset): download_arena_hard(dataset=dataset, local_tables_path=local_path_tables) diff --git a/judgearena/instruction_dataset/arena_hard.py 
b/judgearena/instruction_dataset/arena_hard.py index 414d2f8..5a59db0 100644 --- a/judgearena/instruction_dataset/arena_hard.py +++ b/judgearena/instruction_dataset/arena_hard.py @@ -1,104 +1,103 @@ -from dataclasses import dataclass +from collections.abc import Mapping from pathlib import Path +from typing import Any import pandas as pd -from datasets import Dataset, DatasetDict, IterableDataset, load_dataset +from huggingface_hub import snapshot_download ARENA_HARD_HF_REPO_ID = "lmarena-ai/arena-hard-auto" - -@dataclass(frozen=True) -class ArenaHardSpec: - hf_variant: str - baseline_model: str - - -ARENA_HARD_DATASETS: dict[str, ArenaHardSpec] = { - "arena-hard-v0.1": ArenaHardSpec( - hf_variant="arena-hard-v0.1", - baseline_model="gpt-4-0314", - ), - "arena-hard-v2.0": ArenaHardSpec( - hf_variant="arena-hard-v2.0", - baseline_model="o3-mini-2025-01-31", - ), +# Mirrors upstream's `JUDGE_SETTINGS` baseline assignment in +# `arena-hard-auto/utils/judge_utils.py` verbatim: v0.1 has a single flat +# baseline, v2.0 routes per question category. `is_arena_hard_dataset` and +# the dispatcher in `generate_and_evaluate.py` key off this map. +# +# Note: the released v2.0 `question.jsonl` only tags rows as `hard_prompt` +# (500) or `creative_writing` (250); `coding` and `math` are inert keys +# upstream ships for forward compatibility (no question carries those +# labels, so the dispatcher never looks them up). We keep them so any +# future re-tagging upstream lights up automatically without a code +# change here. +ARENA_HARD_BASELINES: dict[str, str | Mapping[str, str]] = { + "arena-hard-v0.1": "gpt-4-0314", + "arena-hard-v2.0": { + "hard_prompt": "o3-mini-2025-01-31", + "coding": "o3-mini-2025-01-31", + "math": "o3-mini-2025-01-31", + "creative_writing": "gemini-2.0-flash-001", + }, } - -def resolve_arena_hard_spec(dataset: str) -> ArenaHardSpec | None: - return ARENA_HARD_DATASETS.get(dataset) +# Dataset name -> upstream HF `data//` directory. Kept private so the +# public API of this module is just the baseline map and helpers below. +_ARENA_HARD_HF_VARIANTS: dict[str, str] = { + "arena-hard-v0.1": "arena-hard-v0.1", + "arena-hard-v2.0": "arena-hard-v2.0", +} def is_arena_hard_dataset(dataset: str) -> bool: - return resolve_arena_hard_spec(dataset) is not None + return dataset in ARENA_HARD_BASELINES -def arena_hard_baseline_model(dataset: str) -> str | None: - spec = resolve_arena_hard_spec(dataset) - if spec is None: - return None - return spec.baseline_model +def arena_hard_native_baseline( + dataset: str, +) -> str | Mapping[str, str] | None: + """Dataset-native baseline assignment. - -def _load_official_arena_hard_dataset(spec: ArenaHardSpec) -> pd.DataFrame: - data = load_dataset( - path=ARENA_HARD_HF_REPO_ID, - data_dir=f"data/{spec.hf_variant}", - ) - return _dataset_like_to_dataframe(data) - - -def _dataset_like_to_dataframe( - data: Dataset | DatasetDict | IterableDataset, -) -> pd.DataFrame: - if isinstance(data, DatasetDict): - if "train" in data: - return data["train"].to_pandas() - first_split = next(iter(data.keys())) - return data[first_split].to_pandas() - if isinstance(data, Dataset): - return data.to_pandas() - if isinstance(data, IterableDataset): - return pd.DataFrame(list(data)) - raise TypeError(f"Unsupported dataset object type: {type(data)}") + Returns a plain string for flat datasets (v0.1), a `{category: model}` + mapping for per-category datasets (v2.0), or `None` for datasets that + don't ship a native baseline. 
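+
+    Illustrative lookups (values straight from ``ARENA_HARD_BASELINES`` above)::
+
+        >>> arena_hard_native_baseline("arena-hard-v0.1")
+        'gpt-4-0314'
+        >>> arena_hard_native_baseline("arena-hard-v2.0")["creative_writing"]
+        'gemini-2.0-flash-001'
+        >>> arena_hard_native_baseline("alpaca-eval") is None
+        True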
+ """ + return ARENA_HARD_BASELINES.get(dataset) def normalize_official_arena_hard( raw_df: pd.DataFrame, dataset: str ) -> tuple[pd.DataFrame, pd.DataFrame | None]: - spec = resolve_arena_hard_spec(dataset) - if spec is None: + if dataset not in _ARENA_HARD_HF_VARIANTS: raise ValueError(f"Unsupported Arena-Hard dataset: {dataset}") - - instruction_index = _pick_instruction_index(raw_df) - instruction = _pick_instruction(raw_df) - df_instructions = pd.DataFrame( - { - "instruction_index": instruction_index, - "instruction": instruction, - } - ) - df_instructions = df_instructions.dropna( - subset=["instruction_index", "instruction"] - ) - df_instructions = df_instructions.drop_duplicates(subset=["instruction_index"]) - df_instructions = df_instructions.sort_values("instruction_index") - + df_instructions = _build_instructions(raw_df) df_model_outputs = _build_model_outputs(raw_df) return df_instructions, df_model_outputs def download_arena_hard(dataset: str, local_tables_path: Path) -> None: - """Load Arena-Hard from the Hub if instruction and model-output files are missing.""" - spec = resolve_arena_hard_spec(dataset) - if spec is None: + """Populate `{dataset}.csv` and `{dataset}.csv.zip` on disk if missing. + + Pulls the raw jsonl files directly via `snapshot_download` and reads them + with pandas: upstream's per-row `messages[].content` oscillates between + string and dict across answer files, so `datasets.load_dataset` can't + materialize them into a single Arrow schema. + + Re-downloads when the instructions table is stale - currently only v2.0 + detects this, because routing by category requires the `category` column + that older caches were written without. + """ + if dataset not in _ARENA_HARD_HF_VARIANTS: return instructions_path = local_tables_path / "instructions" / f"{dataset}.csv" model_outputs_path = local_tables_path / "model_outputs" / f"{dataset}.csv.zip" - if instructions_path.exists() and model_outputs_path.exists(): + if ( + instructions_path.exists() + and model_outputs_path.exists() + and _instructions_cache_is_fresh(instructions_path, dataset) + ): return - raw_df = _load_official_arena_hard_dataset(spec) + variant = _ARENA_HARD_HF_VARIANTS[dataset] + snapshot_root = snapshot_download( + repo_id=ARENA_HARD_HF_REPO_ID, + repo_type="dataset", + allow_patterns=[ + f"data/{variant}/question.jsonl", + f"data/{variant}/model_answer/*.jsonl", + ], + force_download=False, + ) + raw_df = _read_arena_hard_jsonl_frames( + variant_dir=Path(snapshot_root) / "data" / variant + ) df_instructions, df_model_outputs = normalize_official_arena_hard( raw_df=raw_df, dataset=dataset ) @@ -109,11 +108,100 @@ def download_arena_hard(dataset: str, local_tables_path: Path) -> None: df_model_outputs.to_csv(model_outputs_path, index=False) +def _instructions_cache_is_fresh(instructions_path: Path, dataset: str) -> bool: + """Category-aware datasets need a `category` column; older caches lack it.""" + native = arena_hard_native_baseline(dataset) + if not isinstance(native, Mapping): + return True + cached_columns = pd.read_csv(instructions_path, nrows=0).columns + return "category" in cached_columns + + +def _read_arena_hard_jsonl_frames(variant_dir: Path) -> pd.DataFrame: + frames: list[pd.DataFrame] = [] + question_path = variant_dir / "question.jsonl" + if question_path.exists(): + frames.append(pd.read_json(question_path, lines=True)) + answer_dir = variant_dir / "model_answer" + if answer_dir.exists(): + for jsonl_path in sorted(answer_dir.glob("*.jsonl")): + 
frames.append(pd.read_json(jsonl_path, lines=True)) + if not frames: + raise FileNotFoundError(f"No Arena-Hard jsonl files found under {variant_dir}") + return pd.concat(frames, ignore_index=True, sort=False) + + +def _build_instructions(raw_df: pd.DataFrame) -> pd.DataFrame: + # Question rows are the ones with a prompt; model-answer rows don't have + # one and must not leak into the instructions table. + if "prompt" in raw_df.columns: + question_rows = raw_df[raw_df["prompt"].notna()].reset_index(drop=True) + else: + question_rows = raw_df.reset_index(drop=True) + + if len(question_rows) == 0: + return pd.DataFrame(columns=["instruction_index", "instruction"]) + + columns: dict[str, pd.Series] = { + "instruction_index": _pick_instruction_index(question_rows), + "instruction": _pick_instruction(question_rows), + } + if "category" in question_rows.columns: + columns["category"] = question_rows["category"] + df = pd.DataFrame(columns) + df = df.dropna(subset=["instruction_index", "instruction"]) + df["instruction"] = df["instruction"].astype(str) + df = df.drop_duplicates(subset=["instruction_index"]) + df = df.sort_values("instruction_index").reset_index(drop=True) + return df + + +def _build_model_outputs(raw_df: pd.DataFrame) -> pd.DataFrame | None: + if "model" not in raw_df.columns: + return None + extracted_output = raw_df.apply(_extract_assistant_output, axis=1) + instruction_index = _pick_instruction_index(raw_df) + df = pd.DataFrame( + { + "instruction_index": instruction_index, + "model": raw_df["model"], + "output": extracted_output, + } + ) + df = df[df["model"].notna() & df["output"].notna()] + df = df.dropna(subset=["instruction_index"]) + if df.empty: + return None + df["instruction_index"] = df["instruction_index"].astype(str) + df["model"] = df["model"].astype(str) + df["output"] = df["output"].astype(str) + return df.reset_index(drop=True) + + +def _extract_assistant_output(row: pd.Series) -> str | None: + """Pull the assistant response out of either a flat `output` column or + upstream's nested `messages[-1].content.answer` shape. + """ + output_value = row.get("output") + if isinstance(output_value, str) and output_value: + return output_value + messages = row.get("messages") + if isinstance(messages, list) and messages: + last = messages[-1] + content = last.get("content") if isinstance(last, dict) else None + if isinstance(content, dict): + answer = content.get("answer") + return answer if isinstance(answer, str) and answer else None + if isinstance(content, str) and content: + return content + return None + + def _pick_instruction_index(raw_df: pd.DataFrame) -> pd.Series: - for col in ["instruction_index", "question_id", "id"]: + for col in ["instruction_index", "uid", "question_id", "id"]: if col in raw_df.columns: return raw_df[col].astype(str) - return pd.Series(range(len(raw_df)), dtype=str) + return pd.Series(range(len(raw_df)), dtype=str, index=raw_df.index) def _pick_instruction(raw_df: pd.DataFrame) -> pd.Series: @@ -121,13 +209,14 @@ def _pick_instruction(raw_df: pd.DataFrame) -> pd.Series: if col in raw_df.columns: if col == "turns": return raw_df[col].apply(_turns_to_text) - return raw_df[col].astype(str) + return raw_df[col] raise ValueError( - f"Unable to infer instruction text column from Arena-Hard data. Available columns: {raw_df.columns.tolist()}" + "Unable to infer instruction text column from Arena-Hard data. 
" + f"Available columns: {raw_df.columns.tolist()}" ) -def _turns_to_text(turns_value) -> str: +def _turns_to_text(turns_value: Any) -> str: if isinstance(turns_value, list): if not turns_value: return "" @@ -142,17 +231,3 @@ def _turns_to_text(turns_value) -> str: if key in turns_value: return str(turns_value[key]) return str(turns_value) - - -def _build_model_outputs(raw_df: pd.DataFrame) -> pd.DataFrame | None: - if not {"model", "output"}.issubset(raw_df.columns): - return None - instruction_index = _pick_instruction_index(raw_df) - df_outputs = pd.DataFrame( - { - "instruction_index": instruction_index, - "model": raw_df["model"].astype(str), - "output": raw_df["output"].fillna("").astype(str), - } - ) - return df_outputs diff --git a/judgearena/instruction_dataset/m_arenahard.py b/judgearena/instruction_dataset/m_arenahard.py index 3d7b919..81b4f2e 100644 --- a/judgearena/instruction_dataset/m_arenahard.py +++ b/judgearena/instruction_dataset/m_arenahard.py @@ -1,39 +1,132 @@ +"""Version-aware m-ArenaHard loader. + +Mirrors ``judgearena/instruction_dataset/arena_hard.py``: each supported +``m-arena-hard-v{X.Y}`` maps to its dataset-native baseline, and a parallel +private dict carries the upstream HF repo id. The dispatcher in +``judgearena/instruction_dataset/__init__.py`` uses +``split_m_arena_hard_dataset`` to parse ``m-arena-hard-v{X.Y}[-{lang}|-EU]`` +and then calls ``load_m_arenahard``. +""" + from pathlib import Path import pandas as pd from huggingface_hub import snapshot_download -from judgearena.utils import data_root +EU_LANGUAGES: tuple[str, ...] = ( + "cs", + "de", + "el", + "en", + "es", + "fr", + "it", + "nl", + "pl", + "pt", + "ro", + "uk", +) + +NON_EU_LANGUAGES: tuple[str, ...] = ( + "ar", + "fa", + "he", + "hi", + "id", + "ja", + "ko", + "ru", + "tr", + "vi", + "zh", +) + +ALL_LANGUAGES: tuple[str, ...] = (*EU_LANGUAGES, *NON_EU_LANGUAGES) + +# Dataset name -> dataset-native baseline model. Shape mirrors +# `ARENA_HARD_BASELINES` in `arena_hard.py`. v0.1 uses Aya Expanse 8B (free +# completions from CohereLabs/deja-vu-pairwise-evals); v2.0 uses Gemini 2.5 Flash. +M_ARENA_HARD_BASELINES: dict[str, str] = { + "m-arena-hard-v0.1": "CohereLabs/aya-expanse-8b", + "m-arena-hard-v2.0": "google/gemini-2.5-flash", +} + +# Dataset name -> upstream HF repo id. Kept private; the on-disk cache subdir +# is derived from the repo's short name. +_M_ARENA_HARD_HF_REPOS: dict[str, str] = { + "m-arena-hard-v0.1": "CohereLabs/m-ArenaHard", + "m-arena-hard-v2.0": "CohereLabs/m-ArenaHard-v2.0", +} + + +def is_m_arena_hard_dataset(dataset: str) -> bool: + return split_m_arena_hard_dataset(dataset) is not None + + +def split_m_arena_hard_dataset(dataset: str) -> tuple[str, str | None] | None: + """Parse ``m-arena-hard-v{X.Y}[-{lang}|-EU]`` into ``(version, suffix)``. + Returns ``None`` for any name that doesn't match a known version or that + carries an unknown suffix. ``suffix`` is ``None`` for the all-languages + variant, ``"EU"`` for the EU subset, or a 2-letter code in + :data:`ALL_LANGUAGES`. Versioned names only -- the unversioned + ``m-arena-hard`` alias is deliberately not accepted. 
+ """ + for version in M_ARENA_HARD_BASELINES: + if dataset == version: + return version, None + if dataset.startswith(f"{version}-"): + suffix = dataset[len(version) + 1 :] + if suffix == "EU" or suffix in ALL_LANGUAGES: + return version, suffix + return None + return None -def load_m_arenahard(local_path, language: str | None = None): + +def m_arena_hard_native_baseline(dataset: str) -> str | None: + """Baseline for a dataset name, or ``None`` if it isn't m-arena-hard.""" + parsed = split_m_arena_hard_dataset(dataset) + if parsed is None: + return None + return M_ARENA_HARD_BASELINES[parsed[0]] + + +def load_m_arenahard( + local_path: Path, + version: str, + language: str | None = None, +) -> pd.DataFrame: + """Load m-ArenaHard prompts for the requested version and language subset. + + ``version`` must be a key in :data:`M_ARENA_HARD_BASELINES`. ``language`` + is ``None`` for the full 23-language union, ``"EU"`` for the EU subset, + or a 2-letter language code for a single-language slice. + + The returned DataFrame carries the upstream columns plus a ``lang`` + column, with ``question_id`` rewritten to ``f"{question_id}-{lang}"`` so + multi-language slices have unique identifiers. + """ + if version not in _M_ARENA_HARD_HF_REPOS: + raise ValueError( + f"Unsupported m-ArenaHard version: {version!r}. " + f"Known versions: {sorted(_M_ARENA_HARD_HF_REPOS)}." + ) + repo_id = _M_ARENA_HARD_HF_REPOS[version] + local_subdir = repo_id.split("/", 1)[1] snapshot_download( - repo_id="CohereLabs/m-ArenaHard", + repo_id=repo_id, repo_type="dataset", allow_patterns="*", - local_dir=local_path / "m-ArenaHard", + local_dir=local_path / local_subdir, force_download=False, ) + m_arena_root = local_path / local_subdir - df_union = [] - m_arena_root = Path(local_path / "m-ArenaHard") - eu_languages = [ - "cs", - "de", - "el", - "en", - "es", - "fr", - "it", - "nl", - "pl", - "pt", - "ro", - "uk", - ] - for path in sorted(Path(m_arena_root).rglob("*.parquet")): + df_union: list[pd.DataFrame] = [] + for path in sorted(m_arena_root.rglob("*.parquet")): lg = path.parent.name - if language == "EU" and lg in eu_languages: + if language == "EU" and lg in EU_LANGUAGES: df = pd.read_parquet(path) df["lang"] = lg df_union.append(df) @@ -42,16 +135,17 @@ def load_m_arenahard(local_path, language: str | None = None): df["lang"] = lg df_union.append(df) - assert len(df_union) > 0, f"Invalid language passed {language}" + assert len(df_union) > 0, ( + f"No parquet matched under {m_arena_root} for language={language!r}." 
+ ) df_res = pd.concat(df_union, ignore_index=True) - - # update index to still be unique by appendix language as a suffix df_res["question_id"] = df_res.apply( lambda row: f"{row['question_id']}-{row['lang']}", axis=1 ) - return df_res if __name__ == "__main__": - load_m_arenahard(local_path=data_root, language="EU") + from judgearena.utils import data_root + + load_m_arenahard(local_path=data_root, version="m-arena-hard-v0.1", language="EU") diff --git a/judgearena/instruction_dataset/mt_bench.py b/judgearena/instruction_dataset/mt_bench.py index e2a4233..291e13e 100644 --- a/judgearena/instruction_dataset/mt_bench.py +++ b/judgearena/instruction_dataset/mt_bench.py @@ -7,11 +7,67 @@ from judgearena.utils import data_root +MT_BENCH_SPACE_ID = "lmsys/mt-bench" +MT_BENCH_QUESTION_PATTERN = "data/mt_bench/question.jsonl" +MT_BENCH_MODEL_ANSWER_DIR = Path("data") / "mt_bench" / "model_answer" FASTCHAT_GPT4_REFERENCE_URL = ( "https://raw.githubusercontent.com/lm-sys/FastChat/main/" "fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl" ) +# Mirrors ``ARENA_HARD_BASELINES`` / ``M_ARENA_HARD_BASELINES``: dataset name -> +# dataset-native pairwise baseline. MT-Bench ships only one variant today, and +# ``gpt-4`` is the stronger-reference choice (FastChat's own ``pairwise-baseline`` +# default is ``gpt-3.5-turbo``; we deliberately diverge here). +MT_BENCH_BASELINES: dict[str, str] = { + "mt-bench": "gpt-4", +} + + +def is_mt_bench_dataset(dataset: str) -> bool: + return dataset in MT_BENCH_BASELINES + + +def mt_bench_native_baseline(dataset: str) -> str | None: + """Baseline for a dataset name, or ``None`` if it isn't mt-bench.""" + return MT_BENCH_BASELINES.get(dataset) + + +def _normalize_question_id(question_id: object) -> object: + try: + return int(question_id) + except Exception: + return question_id + + +def _snapshot_mt_bench_files( + *, + local_dir: Path, + allow_patterns: list[str], + expected_path: Path, + description: str, +) -> None: + try: + snapshot_download( + repo_id=MT_BENCH_SPACE_ID, + repo_type="space", + allow_patterns=allow_patterns, + local_dir=local_dir, + force_download=False, + ) + except Exception as e: + raise RuntimeError( + f"Failed to download {description} from HuggingFace space " + f"'{MT_BENCH_SPACE_ID}'. If you're in an offline / restricted-network " + f"environment, pre-download the space snapshot and place the file at " + f"{expected_path}, or set OPENJURY_DATA to point to that directory." + ) from e + if not expected_path.exists(): + raise FileNotFoundError( + f"Could not locate {description} after download. " + f"Expected file at {expected_path}." + ) + def _download_gpt4_references(local_dir: Path) -> Path | None: reference_dir = local_dir / "reference_answer" @@ -46,34 +102,103 @@ def download_mt_bench(local_dir: Path | None = None) -> tuple[Path, Path | None] question_path = local_dir / "data" / "mt_bench" / "question.jsonl" if not question_path.exists(): - try: - snapshot_download( - repo_id="lmsys/mt-bench", - repo_type="space", - allow_patterns=[ - "data/mt_bench/question.jsonl", - ], - local_dir=local_dir, - force_download=False, - ) - except Exception as e: - raise RuntimeError( - "Failed to download MT-Bench questions from HuggingFace space " - "'lmsys/mt-bench'. If you're in an offline / restricted-network " - "environment, pre-download the space snapshot and place the " - f"questions file at {question_path}, or set OPENJURY_DATA to " - "point to that directory." 
- ) from e - if not question_path.exists(): - raise FileNotFoundError( - "Could not locate MT-Bench questions after download. " - f"Expected file at {question_path}." + _snapshot_mt_bench_files( + local_dir=local_dir, + allow_patterns=[MT_BENCH_QUESTION_PATTERN], + expected_path=question_path, + description="MT-Bench questions", ) gpt4_reference_path = _download_gpt4_references(local_dir) return question_path, gpt4_reference_path +def download_mt_bench_model_answer( + model_id: str, local_dir: Path | None = None +) -> Path: + """Download a cached MT-Bench baseline answer file if missing.""" + if local_dir is None: + local_dir = data_root / "mt-bench" + answer_path = local_dir / MT_BENCH_MODEL_ANSWER_DIR / f"{model_id}.jsonl" + if answer_path.exists(): + return answer_path + answer_path.parent.mkdir(parents=True, exist_ok=True) + allow_pattern = (MT_BENCH_MODEL_ANSWER_DIR / f"{model_id}.jsonl").as_posix() + _snapshot_mt_bench_files( + local_dir=local_dir, + allow_patterns=[allow_pattern], + expected_path=answer_path, + description=f"MT-Bench model answers for '{model_id}'", + ) + return answer_path + + +def _extract_answer_turns(record: dict, source_name: str) -> tuple[object, list[str]]: + question_id = record.get("question_id", record.get("id")) + if question_id is None: + raise ValueError( + f"MT-Bench answer record from {source_name} is missing question_id/id." + ) + choices = record.get("choices") + if not (isinstance(choices, list) and choices): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} is " + "missing a non-empty choices list." + ) + first_choice = choices[0] + if not isinstance(first_choice, dict): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} has " + "a malformed first choice entry." + ) + turns = first_choice.get("turns") + if not isinstance(turns, list): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} is " + "missing a turns list." + ) + return _normalize_question_id(question_id), turns + + +def load_mt_bench_model_answers( + model: str, + n_instructions: int | None = None, + local_dir: Path | None = None, +) -> pd.DataFrame | None: + """Load pre-generated MT-Bench answers from a local file or cached model id.""" + local_path = Path(model) + if local_path.exists(): + answer_path = local_path + elif "/" not in model: + answer_path = download_mt_bench_model_answer( + model_id=model, local_dir=local_dir + ) + else: + return None + + answer_records = pd.read_json(answer_path, lines=True).to_dict(orient="records") + rows = [] + for rec in answer_records: + question_id, turns = _extract_answer_turns(rec, str(answer_path)) + rows.append( + { + "instruction_index": question_id, + "completion_turn_1": turns[0] if len(turns) > 0 else "", + "completion_turn_2": turns[1] if len(turns) > 1 else "", + } + ) + + df_answers = pd.DataFrame(rows) + if df_answers.empty: + raise ValueError( + f"MT-Bench answer file {answer_path} did not contain any rows." + ) + df_answers.sort_values("instruction_index", inplace=True) + if n_instructions is not None: + df_answers = df_answers.head(n_instructions) + return df_answers + + def load_mt_bench() -> pd.DataFrame: """Load MT-Bench questions and reference answers. 
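For orientation, a minimal sketch of the answer-record shape that `_extract_answer_turns` and `load_mt_bench_model_answers` assume for each line of a `data/mt_bench/model_answer/*.jsonl` file; only `question_id` and `choices[0]["turns"]` are read, and the other fields shown here are illustrative placeholders the loader ignores:

    # One JSON line from e.g. data/mt_bench/model_answer/gpt-4.jsonl (sketch)
    record = {
        "question_id": 81,      # becomes instruction_index (normalized to int when possible)
        "model_id": "gpt-4",    # illustrative extra field, not read by the loader
        "choices": [
            {"index": 0, "turns": ["<answer to turn 1>", "<answer to turn 2>"]}
        ],
    }
    # _extract_answer_turns(record, source_name) -> (81, ["<answer to turn 1>", "<answer to turn 2>"])
    # load_mt_bench_model_answers("gpt-4") flattens each record into one row:
    #   instruction_index=81, completion_turn_1=turns[0], completion_turn_2=turns[1]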
@@ -126,10 +251,7 @@ def load_mt_bench() -> pd.DataFrame: raise ValueError( f"MT-Bench question record missing question_id/id: keys={list(rec.keys())}" ) - try: - qid = int(qid_raw) - except Exception: - qid = qid_raw + qid = _normalize_question_id(qid_raw) category = rec.get("category") turns = rec.get("turns") diff --git a/judgearena/judge_prompt_presets.py b/judgearena/judge_prompt_presets.py new file mode 100644 index 0000000..00c3ea3 --- /dev/null +++ b/judgearena/judge_prompt_presets.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +JudgeParserMode = Literal["score", "verdict"] + +DEFAULT_JUDGE_PROMPT_PRESET = "default" +SKYWORK_JUDGE_PROMPT_PRESET = "skywork" +JUDGE_PROMPT_PRESETS = ( + DEFAULT_JUDGE_PROMPT_PRESET, + SKYWORK_JUDGE_PROMPT_PRESET, +) + +_PROMPTS_DIR = Path(__file__).resolve().parent / "prompts" +_COMPLETION_LABEL_SINGLE = "Answer" +_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" +_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" +_SCORE_FENCE = "\n```" + + +@dataclass(frozen=True) +class PairwiseJudgePromptPreset: + name: str + parser_mode: JudgeParserMode + system_prompt_filename: str | None + user_prompt_filename: str + user_prompt_with_explanation_filename: str + + +@dataclass(frozen=True) +class ResolvedJudgePrompt: + preset_name: str + parser_mode: JudgeParserMode + system_prompt: str | None + user_prompt_template: str + + +_PAIRWISE_PROMPT_PRESETS: dict[str, PairwiseJudgePromptPreset] = { + DEFAULT_JUDGE_PROMPT_PRESET: PairwiseJudgePromptPreset( + name=DEFAULT_JUDGE_PROMPT_PRESET, + parser_mode="score", + system_prompt_filename="system-prompt.txt", + user_prompt_filename="prompt.txt", + user_prompt_with_explanation_filename="prompt-with-explanation.txt", + ), + SKYWORK_JUDGE_PROMPT_PRESET: PairwiseJudgePromptPreset( + name=SKYWORK_JUDGE_PROMPT_PRESET, + parser_mode="verdict", + system_prompt_filename=None, + user_prompt_filename="skywork-prompt.txt", + user_prompt_with_explanation_filename="skywork-prompt-with-explanation.txt", + ), +} + + +def _render_user_prompt_template( + raw_template: str, *, provide_explanation: bool, multi_turn: bool +) -> str: + template = raw_template.replace( + "{completion_label}", + _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, + ) + template = template.replace( + "{explanation_suffix}", + _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + ) + return template + + +def resolve_pairwise_judge_prompt( + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool, + multi_turn: bool = False, + system_prompt: str | None = None, + user_prompt_template: str | None = None, +) -> ResolvedJudgePrompt: + preset = _PAIRWISE_PROMPT_PRESETS.get(prompt_preset) + if preset is None: + supported = ", ".join(sorted(_PAIRWISE_PROMPT_PRESETS)) + raise ValueError( + f"Unsupported judge prompt preset '{prompt_preset}'. Choose from: {supported}." 
+ ) + + prompt_filename = ( + preset.user_prompt_with_explanation_filename + if provide_explanation + else preset.user_prompt_filename + ) + default_system_prompt = ( + (_PROMPTS_DIR / preset.system_prompt_filename).read_text(encoding="utf-8") + if preset.system_prompt_filename is not None + else None + ) + default_user_prompt_template = _render_user_prompt_template( + (_PROMPTS_DIR / prompt_filename).read_text(encoding="utf-8"), + provide_explanation=provide_explanation, + multi_turn=multi_turn, + ) + return ResolvedJudgePrompt( + preset_name=preset.name, + parser_mode=preset.parser_mode, + system_prompt=system_prompt + if system_prompt is not None + else default_system_prompt, + user_prompt_template=user_prompt_template + if user_prompt_template is not None + else default_user_prompt_template, + ) diff --git a/judgearena/mt_bench/common.py b/judgearena/mt_bench/common.py index d676e05..9c5b095 100644 --- a/judgearena/mt_bench/common.py +++ b/judgearena/mt_bench/common.py @@ -5,7 +5,14 @@ import pandas as pd -from judgearena.utils import safe_text +from judgearena.utils import safe_text_with_metadata + +MT_BENCH_REFERENCE_CATEGORIES: set[str] = { + "math", + "reasoning", + "coding", + "arena-hard-200", +} @dataclass(frozen=True) @@ -20,6 +27,14 @@ class MTBenchPairwiseRow: answer_b_2: str ref_1: str ref_2: str + turn_1_question_truncated: bool = False + turn_2_question_truncated: bool = False + answer_a_1_truncated: bool = False + answer_a_2_truncated: bool = False + answer_b_1_truncated: bool = False + answer_b_2_truncated: bool = False + ref_1_truncated: bool = False + ref_2_truncated: bool = False def iter_mt_bench_pairwise_rows( @@ -41,27 +56,55 @@ def iter_mt_bench_pairwise_rows( if question_id in completions_b.index else completions_b.iloc[0] ) + turn_1_question, turn_1_question_truncated = safe_text_with_metadata( + row.get("turn_1"), + truncate_input_chars, + ) + turn_2_question, turn_2_question_truncated = safe_text_with_metadata( + row.get("turn_2"), + truncate_input_chars, + ) + answer_a_1, answer_a_1_truncated = safe_text_with_metadata( + comp_a_row.get("completion_turn_1", ""), + truncate_input_chars, + ) + answer_a_2, answer_a_2_truncated = safe_text_with_metadata( + comp_a_row.get("completion_turn_2", ""), + truncate_input_chars, + ) + answer_b_1, answer_b_1_truncated = safe_text_with_metadata( + comp_b_row.get("completion_turn_1", ""), + truncate_input_chars, + ) + answer_b_2, answer_b_2_truncated = safe_text_with_metadata( + comp_b_row.get("completion_turn_2", ""), + truncate_input_chars, + ) + ref_1, ref_1_truncated = safe_text_with_metadata( + row.get("reference_turn_1"), + truncate_input_chars, + ) + ref_2, ref_2_truncated = safe_text_with_metadata( + row.get("reference_turn_2"), + truncate_input_chars, + ) yield MTBenchPairwiseRow( question_id=question_id, category=row.get("category"), - turn_1_question=safe_text(row.get("turn_1"), truncate_input_chars), - turn_2_question=safe_text(row.get("turn_2"), truncate_input_chars), - answer_a_1=safe_text( - comp_a_row.get("completion_turn_1", ""), - truncate_input_chars, - ), - answer_a_2=safe_text( - comp_a_row.get("completion_turn_2", ""), - truncate_input_chars, - ), - answer_b_1=safe_text( - comp_b_row.get("completion_turn_1", ""), - truncate_input_chars, - ), - answer_b_2=safe_text( - comp_b_row.get("completion_turn_2", ""), - truncate_input_chars, - ), - ref_1=safe_text(row.get("reference_turn_1"), truncate_input_chars), - ref_2=safe_text(row.get("reference_turn_2"), truncate_input_chars), + 
turn_1_question=turn_1_question, + turn_2_question=turn_2_question, + answer_a_1=answer_a_1, + answer_a_2=answer_a_2, + answer_b_1=answer_b_1, + answer_b_2=answer_b_2, + ref_1=ref_1, + ref_2=ref_2, + turn_1_question_truncated=turn_1_question_truncated, + turn_2_question_truncated=turn_2_question_truncated, + answer_a_1_truncated=answer_a_1_truncated, + answer_a_2_truncated=answer_a_2_truncated, + answer_b_1_truncated=answer_b_1_truncated, + answer_b_2_truncated=answer_b_2_truncated, + ref_1_truncated=ref_1_truncated, + ref_2_truncated=ref_2_truncated, ) diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 3b0e7ec..d7c4c84 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -4,14 +4,30 @@ import math from dataclasses import dataclass -from pathlib import Path from typing import Any, Literal import pandas as pd from langchain_core.prompts import ChatPromptTemplate -from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows -from judgearena.utils import do_inference +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + SKYWORK_JUDGE_PROMPT_PRESET, +) +from judgearena.mt_bench.common import ( + MT_BENCH_REFERENCE_CATEGORIES, + iter_mt_bench_pairwise_rows, +) +from judgearena.mt_bench.prompt_templates import ( + build_mt_bench_user_prompt_template, + render_mt_bench_prompt_text, +) +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker +from judgearena.utils import ( + LimitEventTracker, + do_inference, + strip_thinking_tags, + strip_thinking_tags_with_metadata, +) FASTCHAT_TEMPERATURE_CONFIG: dict[str, float] = { "writing": 0.7, @@ -25,13 +41,6 @@ "arena-hard-200": 0.0, } -FASTCHAT_NEED_REF_CATS: set[str] = { - "math", - "reasoning", - "coding", - "arena-hard-200", -} - FastChatVerdict = Literal["A", "B", "tie", "error"] PairwiseWinner = Literal["model_A", "model_B", "tie", "error"] @@ -45,21 +54,7 @@ class FastChatPairwisePrompt: ref_based: bool -_PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" _SYSTEM_BASE_FILE = "system-base.txt" -_USER_SINGLE_BASE_FILE = "user-single-base.txt" -_USER_MULTI_BASE_FILE = "user-multi-base.txt" -_USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" -_USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" - - -def _load_prompt_text(filename: str) -> str: - path = _PROMPTS_DIR / filename - return path.read_text(encoding="utf-8") - - -def _render_prompt_text(filename: str, **kwargs: str) -> str: - return _load_prompt_text(filename).format(**kwargs) def _build_system_prompt( @@ -70,7 +65,7 @@ def _build_system_prompt( focus_line: str = "", ) -> str: focus_segment = f"{focus_line} " if focus_line else "" - return _render_prompt_text( + return render_mt_bench_prompt_text( _SYSTEM_BASE_FILE, user_subject=user_subject, task_description=task_description, @@ -79,17 +74,6 @@ def _build_system_prompt( ) -def _build_user_prompt_template(*, multi_turn: bool, ref_based: bool) -> str: - base_filename = _USER_MULTI_BASE_FILE if multi_turn else _USER_SINGLE_BASE_FILE - reference_block = "" - if ref_based: - ref_block_filename = ( - _USER_MULTI_REF_BLOCK_FILE if multi_turn else _USER_SINGLE_REF_BLOCK_FILE - ) - reference_block = _load_prompt_text(ref_block_filename).rstrip("\n") + "\n\n" - return _render_prompt_text(base_filename, reference_block=reference_block) - - def _load_pairwise_prompt( *, name: str, @@ -110,7 +94,7 @@ def _load_pairwise_prompt( 
begin_instruction=system_begin_instruction, focus_line=system_focus_line, ), - user_prompt_template=_build_user_prompt_template( + user_prompt_template=build_mt_bench_user_prompt_template( multi_turn=multi_turn, ref_based=ref_based, ), @@ -179,12 +163,86 @@ def _load_pairwise_prompt( ) +_SKYWORK_PAIR_V2 = _load_pairwise_prompt( + name="skywork-pair-v2", + multi_turn=False, + ref_based=False, + system_user_subject="prompt displayed below", + system_task_description=( + "You should choose the assistant that follows the user's instructions and " + "answers the user's prompt better. Your evaluation should consider factors " + "such as helpfulness, relevance, accuracy, depth, creativity, and level " + "of detail of the responses." + ), + system_begin_instruction="carefully comparing the two responses", +) + +_SKYWORK_PAIR_V2_MULTI = _load_pairwise_prompt( + name="skywork-pair-v2-multi-turn", + multi_turn=True, + ref_based=False, + system_user_subject="questions", + system_task_description=( + "You should choose the assistant that follows the user's instructions and " + "answers the user's questions better. Your evaluation should consider " + "factors such as helpfulness, relevance, accuracy, depth, creativity, and " + "level of detail of the responses." + ), + system_focus_line=( + "You should focus on which assistant better answers the second user question." + ), + system_begin_instruction="carefully comparing the two conversations", +) + +_SKYWORK_PAIR_MATH_V1 = _load_pairwise_prompt( + name="skywork-pair-math-v1", + multi_turn=False, + ref_based=True, + system_user_subject="prompt displayed below", + system_task_description=( + "You will be given a reference answer, assistant A's answer, and " + "assistant B's answer. Your evaluation should focus on correctness and " + "helpfulness while deciding which assistant is better." + ), + system_begin_instruction="carefully comparing both assistants' answers with the reference answer", +) + +_SKYWORK_PAIR_MATH_V1_MULTI = _load_pairwise_prompt( + name="skywork-pair-math-v1-multi-turn", + multi_turn=True, + ref_based=True, + system_user_subject="questions", + system_task_description=( + "You will be given reference answers together with assistant A's and " + "assistant B's answers. Your evaluation should focus on correctness and " + "helpfulness while deciding which assistant better answers the second user question." 
+ ), + system_begin_instruction="carefully comparing both assistants' answers with the reference answers", +) + +_FASTCHAT_PROMPT_PRESET_REGISTRY: dict[str, dict[str, FastChatPairwisePrompt]] = { + DEFAULT_JUDGE_PROMPT_PRESET: { + "single": _PAIR_V2, + "multi": _PAIR_V2_MULTI, + "single_ref": _PAIR_MATH_V1, + "multi_ref": _PAIR_MATH_V1_MULTI, + }, + SKYWORK_JUDGE_PROMPT_PRESET: { + "single": _SKYWORK_PAIR_V2, + "multi": _SKYWORK_PAIR_V2_MULTI, + "single_ref": _SKYWORK_PAIR_MATH_V1, + "multi_ref": _SKYWORK_PAIR_MATH_V1_MULTI, + }, +} + + def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: - if "[[A]]" in judgment: + stripped = strip_thinking_tags(judgment).strip() + if "[[A]]" in stripped: return "A" - if "[[B]]" in judgment: + if "[[B]]" in stripped: return "B" - if "[[C]]" in judgment: + if "[[C]]" in stripped: return "tie" return "error" @@ -225,15 +283,26 @@ def _winner_to_preference(winner: PairwiseWinner) -> float: return math.nan -def _select_prompt(category: str | None, multi_turn: bool) -> FastChatPairwisePrompt: - needs_ref = (category or "") in FASTCHAT_NEED_REF_CATS +def _select_prompt( + category: str | None, + multi_turn: bool, + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, +) -> FastChatPairwisePrompt: + prompt_variants = _FASTCHAT_PROMPT_PRESET_REGISTRY.get(prompt_preset) + if prompt_variants is None: + supported = ", ".join(sorted(_FASTCHAT_PROMPT_PRESET_REGISTRY)) + raise ValueError( + f"Unsupported MT-Bench prompt preset '{prompt_preset}'. Choose from: {supported}." + ) + needs_ref = (category or "") in MT_BENCH_REFERENCE_CATEGORIES if needs_ref and multi_turn: - return _PAIR_MATH_V1_MULTI + return prompt_variants["multi_ref"] if needs_ref: - return _PAIR_MATH_V1 + return prompt_variants["single_ref"] if multi_turn: - return _PAIR_V2_MULTI - return _PAIR_V2 + return prompt_variants["multi"] + return prompt_variants["single"] def _group_indices_by_prompt( @@ -267,6 +336,9 @@ def _infer_by_prompt_groups( items: list[dict[str, Any]], use_tqdm: bool, swap_answers: bool, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, ) -> list[str]: """Run judge inference, grouping by prompt variant for batching.""" grouped_indices = _group_indices_by_prompt(items) @@ -290,6 +362,9 @@ def _infer_by_prompt_groups( chat_model=judge_chat_model, inputs=prompt_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) for i, out in zip(idxs, outs, strict=True): judgments[i] = str(out) @@ -304,8 +379,38 @@ def _build_fastchat_judge_items( eval_single: bool, eval_multi: bool, truncate_input_chars: int | None, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, + limit_event_tracker: LimitEventTracker | None = None, ) -> list[dict[str, Any]]: items: list[dict[str, Any]] = [] + + def _record_mt_bench_truncation( + *, case_id: str, field: str, truncated: bool + ) -> None: + if truncated and limit_event_tracker is not None: + limit_event_tracker.record( + "mt_bench_field_char_truncation", + stage="judge_input", + field=field, + case_id=case_id, + ) + + def _prepare_answer(answer: str, *, case_id: str, field: str) -> tuple[str, bool]: + if not strip_thinking_before_judging: + return answer, False + stripped_answer, stripped = strip_thinking_tags_with_metadata(answer) + if stripped and limit_event_tracker is not None: + limit_event_tracker.record( + 
"thinking_trace_stripped_before_judging", + stage="judge_input", + field=field, + case_id=case_id, + original_length=len(answer), + final_length=len(stripped_answer), + ) + return stripped_answer, stripped + for pair_row in iter_mt_bench_pairwise_rows( questions=questions, completions_a=completions_a, @@ -314,14 +419,51 @@ def _build_fastchat_judge_items( ): category = pair_row.category if eval_single: - prompt = _select_prompt(category, multi_turn=False) + case_id = f"{pair_row.question_id}:turn1" + prompt = _select_prompt( + category, multi_turn=False, prompt_preset=prompt_preset + ) + answer_a, answer_a_stripped = _prepare_answer( + pair_row.answer_a_1, case_id=case_id, field="answer_a_1" + ) + answer_b, answer_b_stripped = _prepare_answer( + pair_row.answer_b_1, case_id=case_id, field="answer_b_1" + ) + _record_mt_bench_truncation( + case_id=case_id, + field="turn_1_question", + truncated=pair_row.turn_1_question_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_a_1", + truncated=pair_row.answer_a_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_b_1", + truncated=pair_row.answer_b_1_truncated, + ) kwargs: dict[str, str] = { "question": pair_row.turn_1_question, - "answer_a": pair_row.answer_a_1, - "answer_b": pair_row.answer_b_1, + "answer_a": answer_a, + "answer_b": answer_b, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_a_1_reasoning_stripped": answer_a_stripped, + "answer_b_1_reasoning_stripped": answer_b_stripped, } if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) kwargs["ref_answer_1"] = pair_row.ref_1 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated items.append( { "question_id": pair_row.question_id, @@ -330,22 +472,73 @@ def _build_fastchat_judge_items( "prompt": prompt, "prompt_name": prompt.name, "prompt_kwargs": kwargs, + "limit_flags": limit_flags, } ) if eval_multi and pair_row.turn_2_question: - prompt = _select_prompt(category, multi_turn=True) + case_id = f"{pair_row.question_id}:turn2" + prompt = _select_prompt( + category, multi_turn=True, prompt_preset=prompt_preset + ) + answer_a_1, answer_a_1_stripped = _prepare_answer( + pair_row.answer_a_1, case_id=case_id, field="answer_a_1" + ) + answer_a_2, answer_a_2_stripped = _prepare_answer( + pair_row.answer_a_2, case_id=case_id, field="answer_a_2" + ) + answer_b_1, answer_b_1_stripped = _prepare_answer( + pair_row.answer_b_1, case_id=case_id, field="answer_b_1" + ) + answer_b_2, answer_b_2_stripped = _prepare_answer( + pair_row.answer_b_2, case_id=case_id, field="answer_b_2" + ) + for field, truncated in ( + ("turn_1_question", pair_row.turn_1_question_truncated), + ("turn_2_question", pair_row.turn_2_question_truncated), + ("answer_a_1", pair_row.answer_a_1_truncated), + ("answer_a_2", pair_row.answer_a_2_truncated), + ("answer_b_1", pair_row.answer_b_1_truncated), + ("answer_b_2", pair_row.answer_b_2_truncated), + ): + _record_mt_bench_truncation( + case_id=case_id, field=field, truncated=truncated + ) kwargs = { "question_1": pair_row.turn_1_question, "question_2": pair_row.turn_2_question, - "answer_a_1": pair_row.answer_a_1, - "answer_a_2": pair_row.answer_a_2, - "answer_b_1": pair_row.answer_b_1, - "answer_b_2": pair_row.answer_b_2, + "answer_a_1": answer_a_1, + "answer_a_2": answer_a_2, + 
"answer_b_1": answer_b_1, + "answer_b_2": answer_b_2, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "turn_2_question_truncated": pair_row.turn_2_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_a_2_truncated": pair_row.answer_a_2_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_b_2_truncated": pair_row.answer_b_2_truncated, + "answer_a_1_reasoning_stripped": answer_a_1_stripped, + "answer_a_2_reasoning_stripped": answer_a_2_stripped, + "answer_b_1_reasoning_stripped": answer_b_1_stripped, + "answer_b_2_reasoning_stripped": answer_b_2_stripped, } if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="ref_2", + truncated=pair_row.ref_2_truncated, + ) kwargs["ref_answer_1"] = pair_row.ref_1 kwargs["ref_answer_2"] = pair_row.ref_2 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + limit_flags["ref_2_truncated"] = pair_row.ref_2_truncated items.append( { "question_id": pair_row.question_id, @@ -354,6 +547,7 @@ def _build_fastchat_judge_items( "prompt": prompt, "prompt_name": prompt.name, "prompt_kwargs": kwargs, + "limit_flags": limit_flags, } ) return items @@ -390,6 +584,7 @@ def _resolve_fastchat_item_result( "g1_verdict": g1_verdict, "g1_winner": g1_winner, } + annotation_row.update(item.get("limit_flags", {})) if g2_raw is not None: g2_verdict = _parse_fastchat_verdict(g2_raw) @@ -434,8 +629,13 @@ def judge_mt_bench_pairwise_fastchat( swap_mode: str, truncate_input_chars: int | None, use_tqdm: bool, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: - """Pairwise MT-Bench judging compatible with FastChat's `[[A]]/[[B]]/[[C]]` format.""" + """Run FastChat-style MT-Bench pairwise judging with bracketed verdict outputs.""" assert turns_mode in ("both", "single", "multi") assert swap_mode in ("fixed", "both") @@ -449,6 +649,9 @@ def judge_mt_bench_pairwise_fastchat( eval_single=eval_single, eval_multi=eval_multi, truncate_input_chars=truncate_input_chars, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, + limit_event_tracker=limit_event_tracker, ) g1_judgments = _infer_by_prompt_groups( @@ -456,6 +659,9 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=False, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, ) g2_judgments: list[str] | None = None @@ -465,6 +671,9 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=True, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, ) annotations: list[dict[str, Any]] = [] diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index b28f859..67afefa 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -6,10 +6,11 @@ from __future__ import annotations +import hashlib import json import os from dataclasses import asdict -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING @@ -18,13 +19,31 @@ 
from judgearena.eval_utils import _compute_grouped_stats, print_results from judgearena.generate import generate_multiturn from judgearena.instruction_dataset import load_instructions +from judgearena.instruction_dataset.mt_bench import ( + load_mt_bench_model_answers, + mt_bench_native_baseline, +) +from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.log import get_logger from judgearena.mt_bench.fastchat_compat import ( FASTCHAT_TEMPERATURE_CONFIG, judge_mt_bench_pairwise_fastchat, ) -from judgearena.repro import _to_jsonable -from judgearena.utils import cache_function_dataframe, compute_pref_summary, make_model +from judgearena.mt_bench.preset_judging import judge_mt_bench_with_preset +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) +from judgearena.repro import _to_jsonable, write_run_metadata +from judgearena.utils import ( + LimitEventTracker, + build_default_judge_model_kwargs, + cache_function_dataframe, + compute_pref_summary, + is_thinking_model, + make_model, +) logger = get_logger(__name__) @@ -32,14 +51,86 @@ from judgearena.generate_and_evaluate import CliArgs +# Original MT-Bench prompts include a visible explanation before the final verdict. +_MIN_MT_BENCH_JUDGE_TOKENS = 24576 +_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN = 28672 + + +def _build_mt_bench_generation_kwargs( + *, args: CliArgs, model_spec: str +) -> dict[str, object]: + generation_model_kwargs = dict(args.engine_kwargs) + provider, _, model_name = model_spec.partition("/") + if ( + args.battle_thinking_token_budget is not None + and provider == "VLLM" + and is_thinking_model(model_name) + ): + generation_model_kwargs["thinking_token_budget"] = min( + int(args.battle_thinking_token_budget), + int(args.max_out_tokens_models), + ) + return generation_model_kwargs + + +def _build_mt_bench_judge_model_kwargs( + *, args: CliArgs, limit_event_tracker: LimitEventTracker | None +) -> dict[str, object]: + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, + args.engine_kwargs, + judge_engine_kwargs_override=args.judge_engine_kwargs, + ) + if limit_event_tracker is not None: + judge_model_kwargs["limit_event_tracker"] = limit_event_tracker + judge_model_kwargs["limit_event_stage"] = "judge_model_init" + judge_model_kwargs["limit_event_model_spec"] = args.judge_model + return judge_model_kwargs + + +def _mt_bench_generation_cache_name(args: CliArgs, *, model_name: str) -> str: + generation_config = { + "truncate_all_input_chars": args.truncate_all_input_chars, + "max_out_tokens_models": args.max_out_tokens_models, + "max_model_len": args.max_model_len, + "chat_template": args.chat_template, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "engine_kwargs": _build_mt_bench_generation_kwargs( + args=args, model_spec=model_name + ), + } + generation_config_hash = hashlib.sha256( + json.dumps(generation_config, sort_keys=True, default=str).encode("utf-8") + ).hexdigest()[:12] + return f"mt-bench_{model_name}_{args.n_instructions}_{generation_config_hash}" + + +def _align_mt_bench_completions( + *, questions_df: pd.DataFrame, completions: pd.DataFrame, model_name: str +) -> pd.DataFrame: + """Align cached or generated MT-Bench completions to the question order.""" + indexed = completions.set_index("instruction_index") + missing_ids = 
questions_df.index.difference(indexed.index) + if not missing_ids.empty: + missing_ids_preview = ", ".join(str(x) for x in missing_ids[:5]) + raise ValueError( + f"MT-Bench completions for '{model_name}' are missing " + f"{len(missing_ids)} question(s). First missing ids: {missing_ids_preview}." + ) + return indexed.loc[questions_df.index] + + def _generate_mt_bench_completions( args: CliArgs, questions_df: pd.DataFrame, ignore_cache: bool, + usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, ) -> tuple[pd.DataFrame, pd.DataFrame]: - cache_prefix = "mt-bench" + """Load baseline MT-Bench answers or generate fresh multi-turn outputs.""" - def _run_generation(model_name: str) -> pd.DataFrame: + def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: return generate_multiturn( questions=questions_df, model=model_name, @@ -49,23 +140,42 @@ def _run_generation(model_name: str) -> pd.DataFrame: max_model_len=args.max_model_len, chat_template=args.chat_template, temperature_config=FASTCHAT_TEMPERATURE_CONFIG, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + limit_event_tracker=limit_event_tracker, + strip_thinking_before_turn_2_prompt=args.strip_thinking_before_judging, + **_build_mt_bench_generation_kwargs(args=args, model_spec=model_name), ) - completions_a = cache_function_dataframe( - lambda: _run_generation(args.model_A), - ignore_cache=ignore_cache, - cache_name=f"{cache_prefix}_{args.model_A}_{args.n_instructions}", - ).set_index("instruction_index") + def _load_or_generate(model_name: str, usage_phase: str) -> pd.DataFrame: + loaded_answers = load_mt_bench_model_answers( + model_name, n_instructions=args.n_instructions + ) + if loaded_answers is not None: + print(f"Using pre-generated MT-Bench answers for '{model_name}'.") + return _align_mt_bench_completions( + questions_df=questions_df, + completions=loaded_answers, + model_name=model_name, + ) + generated_answers = cache_function_dataframe( + lambda: _run_generation(model_name, usage_phase), + ignore_cache=ignore_cache, + cache_name=_mt_bench_generation_cache_name(args, model_name=model_name), + ) + return _align_mt_bench_completions( + questions_df=questions_df, + completions=generated_answers, + model_name=model_name, + ) - completions_b = cache_function_dataframe( - lambda: _run_generation(args.model_B), - ignore_cache=ignore_cache, - cache_name=f"{cache_prefix}_{args.model_B}_{args.n_instructions}", - ).set_index("instruction_index") + completions_a = _load_or_generate(args.model_A, "generation_model_A") + completions_b = _load_or_generate(args.model_B, "generation_model_B") return completions_a, completions_b def _build_mt_bench_result_name(args: CliArgs, suffix: str | None = None) -> str: + """Build a filesystem-safe MT-Bench result artifact prefix.""" name = f"{args.task}-{args.model_A}-{args.model_B}-{args.judge_model}" name += f"-{args.swap_mode}" if suffix: @@ -73,6 +183,23 @@ def _build_mt_bench_result_name(args: CliArgs, suffix: str | None = None) -> str return name.replace("/", "_") +def _build_mt_bench_input_payloads( + *, + questions_df: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, +) -> dict[str, object]: + return { + "instruction_index": questions_df.index.tolist(), + "turn_1": questions_df["turn_1"].tolist(), + "turn_2": questions_df["turn_2"].tolist(), + "completion_turn_1_A": completions_a["completion_turn_1"].tolist(), + "completion_turn_2_A": completions_a["completion_turn_2"].tolist(), + "completion_turn_1_B": 
completions_b["completion_turn_1"].tolist(), + "completion_turn_2_B": completions_b["completion_turn_2"].tolist(), + } + + def _save_mt_bench_results( *, args: CliArgs, @@ -80,7 +207,16 @@ def _save_mt_bench_results( result_name: str, results: dict[str, object], annotations_df: pd.DataFrame, + questions_df: pd.DataFrame, + pricing_reference: dict[str, object] | None, + started_at_utc: datetime, + input_payloads: dict[str, object] | None = None, + judge_system_prompt: str | None = None, + judge_user_prompt_template: str | None = None, ) -> None: + """Persist MT-Bench arguments, annotations, and aggregate results.""" + res_folder.mkdir(parents=True, exist_ok=True) + with open(res_folder / f"args-{result_name}.json", "w") as f: json.dump(_to_jsonable(asdict(args)), f, indent=2, allow_nan=False) @@ -89,6 +225,23 @@ def _save_mt_bench_results( with open(res_folder / f"results-{result_name}.json", "w") as f: json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) + write_run_metadata( + output_dir=res_folder, + entrypoint="judgearena.mt_bench.mt_bench_utils.run_mt_bench", + run=asdict(args), + results=results, + input_payloads=input_payloads + or { + "instruction_index": questions_df.index.tolist(), + "turn_1": questions_df["turn_1"].tolist(), + "turn_2": questions_df["turn_2"].tolist(), + }, + judge_system_prompt=judge_system_prompt, + judge_user_prompt_template=judge_user_prompt_template, + started_at_utc=started_at_utc, + pricing_reference=pricing_reference, + ) + def _run_mt_bench_fastchat( *, @@ -99,7 +252,12 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, + prompt_preset: str, + usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, + started_at_utc: datetime, ) -> pd.Series: + """Run FastChat-style MT-Bench judging and save the resulting artifacts.""" prefs, annotations, combined_metadata, num_inconsistent = ( judge_mt_bench_pairwise_fastchat( judge_chat_model=judge_chat_model, @@ -111,8 +269,13 @@ def _run_mt_bench_fastchat( model_b=args.model_B, turns_mode="both", swap_mode=args.swap_mode, - truncate_input_chars=args.truncate_all_input_chars, + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=args.use_tqdm, + prompt_preset=prompt_preset, + strip_thinking_before_judging=args.strip_thinking_before_judging, + usage_tracker=usage_tracker, + usage_phase="judge", + limit_event_tracker=limit_event_tracker, ) ) @@ -122,8 +285,15 @@ def _run_mt_bench_fastchat( "model_A": args.model_A, "model_B": args.model_B, "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, "num_inconsistent": num_inconsistent, **stats, + "limit_events": limit_event_tracker.build_summary() + if limit_event_tracker is not None + else {}, "per_category": _compute_grouped_stats(prefs, combined_metadata, "category"), "per_turn": _compute_grouped_stats(prefs, combined_metadata, "turn"), "preferences": prefs.tolist(), @@ -131,12 +301,132 @@ def _run_mt_bench_fastchat( "user": os.getenv("USER", ""), } print_results(results) + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={ + "generation_model_A": args.model_A, + "generation_model_B": args.model_B, + "judge": args.judge_model, + }, + ) + 
print(format_openrouter_reference_pricing_summary(pricing_reference)) _save_mt_bench_results( args=args, res_folder=res_folder, result_name=result_name, results=results, annotations_df=pd.DataFrame(annotations), + questions_df=questions_df, + pricing_reference=pricing_reference, + started_at_utc=started_at_utc, + input_payloads=_build_mt_bench_input_payloads( + questions_df=questions_df, + completions_a=completions_a, + completions_b=completions_b, + ), + ) + return prefs + + +def _run_mt_bench_preset( + *, + args: CliArgs, + res_folder: Path, + result_name: str, + questions_df: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + judge_chat_model, + prompt_preset: str, + usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, + started_at_utc: datetime, +) -> pd.Series: + prefs, annotations, combined_metadata, _num_inconsistent = ( + judge_mt_bench_with_preset( + judge_chat_model=judge_chat_model, + judge_model=args.judge_model, + questions=questions_df, + completions_a=completions_a, + completions_b=completions_b, + model_a=args.model_A, + model_b=args.model_B, + turns_mode="both", + swap_mode=args.swap_mode, + truncate_input_chars=args.truncate_judge_input_chars, + use_tqdm=args.use_tqdm, + prompt_preset=prompt_preset, + provide_explanation=args.provide_explanation, + strip_thinking_before_judging=args.strip_thinking_before_judging, + judge_tokenizer=getattr(judge_chat_model, "tokenizer", None), + max_judge_model_len=args.max_judge_model_len, + max_out_tokens_judge=args.max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase="judge", + limit_event_tracker=limit_event_tracker, + ) + ) + + stats = compute_pref_summary(prefs) + results = { + "task": args.task, + "model_A": args.model_A, + "model_B": args.model_B, + "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + **stats, + "limit_events": limit_event_tracker.build_summary() + if limit_event_tracker is not None + else {}, + "per_category": _compute_grouped_stats(prefs, combined_metadata, "category"), + "per_turn": _compute_grouped_stats(prefs, combined_metadata, "turn"), + "preferences": prefs.tolist(), + "date": str(datetime.now().isoformat()), + "user": os.getenv("USER", ""), + } + print_results(results) + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={ + "generation_model_A": args.model_A, + "generation_model_B": args.model_B, + "judge": args.judge_model, + }, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) + unique_system_prompts = { + row.get("system_prompt") + for row in annotations + if row.get("system_prompt") is not None + } + unique_user_templates = { + row.get("user_prompt_template") + for row in annotations + if row.get("user_prompt_template") is not None + } + _save_mt_bench_results( + args=args, + res_folder=res_folder, + result_name=result_name, + results=results, + annotations_df=pd.DataFrame(annotations), + questions_df=questions_df, + pricing_reference=pricing_reference, + started_at_utc=started_at_utc, + input_payloads=_build_mt_bench_input_payloads( + questions_df=questions_df, + completions_a=completions_a, + completions_b=completions_b, + ), + judge_system_prompt=next(iter(unique_system_prompts), None) + if len(unique_system_prompts) == 1 
+ else None, + judge_user_prompt_template=next(iter(unique_user_templates), None) + if len(unique_user_templates) == 1 + else None, ) return prefs @@ -145,10 +435,44 @@ def run_mt_bench( args: CliArgs, ignore_cache: bool, *, - res_folder: Path, - result_name: str, + res_folder: Path | None = None, + result_name: str | None = None, ): - """MT-Bench pipeline with FastChat-compatible pairwise judging.""" + """MT-Bench pipeline with preset or FastChat-original pairwise judging.""" + run_started_at = datetime.now(UTC) + usage_tracker = OpenRouterReferencePricingTracker() + limit_event_tracker = LimitEventTracker() + prompt_preset = args.judge_prompt_preset or DEFAULT_JUDGE_PROMPT_PRESET + fastchat_mode = args.mt_bench_judge_mode == "fastchat_original" + if ( + fastchat_mode + and prompt_preset == DEFAULT_JUDGE_PROMPT_PRESET + and not args.provide_explanation + ): + logger.info( + "MT-Bench ignores provide_explanation=False and keeps the original " + "FastChat-style explanation-plus-verdict prompt." + ) + if args.model_B is None: + args.model_B = mt_bench_native_baseline(args.task) + if args.model_B is None: + raise ValueError( + f"--model_B is required for dataset '{args.task}'; " + "no dataset-native baseline registered." + ) + if result_name is None: + result_name = _build_mt_bench_result_name(args, suffix="mtbench") + if res_folder is None: + res_folder = Path(args.result_folder) / result_name + res_folder.mkdir(parents=True, exist_ok=True) + if fastchat_mode and args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: + logger.warning( + "MT-Bench judge prompts request an explanation before the final " + "verdict; max_out_tokens_judge=%s may be too small " + "(recommended >= %s).", + args.max_out_tokens_judge, + _MIN_MT_BENCH_JUDGE_TOKENS, + ) questions_df = load_instructions("mt-bench", n_instructions=args.n_instructions) logger.info( "Generating multi-turn completions for MT-Bench with %s and %s.", @@ -159,15 +483,63 @@ def run_mt_bench( args=args, questions_df=questions_df, ignore_cache=ignore_cache, + usage_tracker=usage_tracker, + limit_event_tracker=limit_event_tracker, ) + if args.skip_judging: + res_folder.mkdir(parents=True, exist_ok=True) + with open(res_folder / f"args-{result_name}.json", "w") as f: + json.dump(_to_jsonable(asdict(args)), f, indent=2, allow_nan=False) + generation_summary = { + "task": args.task, + "model_A": args.model_A, + "model_B": args.model_B, + "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, + "n_instructions": args.n_instructions + if args.n_instructions is not None + else len(questions_df), + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "limit_events": limit_event_tracker.build_summary(), + "skip_judging": True, + } + with open(res_folder / f"gen-results-{result_name}.json", "w") as f: + json.dump(_to_jsonable(generation_summary), f, indent=2, allow_nan=False) + logger.info( + "skip_judging=True: wrote gen-results-%s.json and returning before judge model construction.", + result_name, + ) + return None + if fastchat_mode and ( + args.max_judge_model_len is not None + and args.max_judge_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + ): + logger.warning( + "MT-Bench judge prompts request an explanation before the final " + "verdict; max_judge_model_len=%s may be too small for prompt plus " + "completion (recommended >= %s).", + args.max_judge_model_len, + 
_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN, + ) + judge_model_kwargs = { + "model": args.judge_model, + "max_tokens": args.max_out_tokens_judge, + "max_model_len": args.max_judge_model_len, + "chat_template": args.chat_template, + **_build_mt_bench_judge_model_kwargs( + args=args, + limit_event_tracker=limit_event_tracker, + ), + } + if fastchat_mode: + judge_model_kwargs["temperature"] = 0.0 judge_chat_model = make_model( - model=args.judge_model, - max_tokens=args.max_out_tokens_judge, - temperature=0.0, - max_model_len=args.max_model_len, - chat_template=args.chat_template, + **judge_model_kwargs, ) - return _run_mt_bench_fastchat( + runner = _run_mt_bench_fastchat if fastchat_mode else _run_mt_bench_preset + return runner( args=args, res_folder=res_folder, result_name=result_name, @@ -175,4 +547,8 @@ def run_mt_bench( completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, + prompt_preset=prompt_preset, + usage_tracker=usage_tracker, + limit_event_tracker=limit_event_tracker, + started_at_utc=run_started_at, ) diff --git a/judgearena/mt_bench/preset_judging.py b/judgearena/mt_bench/preset_judging.py new file mode 100644 index 0000000..ded587b --- /dev/null +++ b/judgearena/mt_bench/preset_judging.py @@ -0,0 +1,645 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + +import pandas as pd +from langchain_core.prompts import ChatPromptTemplate + +from judgearena.evaluate import ( + _PREFLIGHT_MAX_ITERATIONS, + _PREFLIGHT_MIN_COMPLETION_CHARS, + _PREFLIGHT_RESERVED_TOKENS, + PairScore, + _chars_per_token, + _count_chat_tokens, + _find_token_overflows, +) +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + ResolvedJudgePrompt, + resolve_pairwise_judge_prompt, +) +from judgearena.log import get_logger +from judgearena.mt_bench.common import ( + MT_BENCH_REFERENCE_CATEGORIES, + iter_mt_bench_pairwise_rows, +) +from judgearena.mt_bench.prompt_templates import build_mt_bench_user_prompt_template +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker +from judgearena.utils import ( + LimitEventTracker, + do_inference, + strip_thinking_tags_with_metadata, + truncate_with_metadata, +) + +logger = get_logger(__name__) + + +@dataclass(frozen=True) +class MTBenchPresetPrompt: + name: str + preset_name: str + parser_mode: str + system_prompt: str | None + user_prompt_template: str + multi_turn: bool + ref_based: bool + + +def _extract_output_section(user_prompt_template: str) -> str: + marker = "# Your output" + marker_index = user_prompt_template.find(marker) + if marker_index < 0: + raise ValueError("Could not find '# Your output' section in preset template.") + return user_prompt_template[marker_index:].lstrip() + + +def _extract_user_preamble(user_prompt_template: str) -> str: + marker = "[User Question]" + marker_index = user_prompt_template.find(marker) + if marker_index < 0: + raise ValueError("Could not find '[User Question]' section in preset template.") + return user_prompt_template[:marker_index].rstrip() + + +def _build_mt_bench_preset_user_prompt_template( + *, + resolved_prompt: ResolvedJudgePrompt, + multi_turn: bool, + ref_based: bool, +) -> str: + base_template = build_mt_bench_user_prompt_template( + multi_turn=multi_turn, + ref_based=ref_based, + ) + if resolved_prompt.system_prompt is None: + user_preamble = _extract_user_preamble(resolved_prompt.user_prompt_template) + return f"{user_preamble}\n\n{base_template}" + output_section = 
_extract_output_section(resolved_prompt.user_prompt_template) + return f"{base_template}\n\n{output_section}" + + +def _select_preset_prompt( + category: str | None, + multi_turn: bool, + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool, +) -> MTBenchPresetPrompt: + ref_based = (category or "") in MT_BENCH_REFERENCE_CATEGORIES + resolved_prompt = resolve_pairwise_judge_prompt( + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + multi_turn=multi_turn, + ) + suffix = "multi" if multi_turn else "single" + if ref_based: + suffix += "_ref" + return MTBenchPresetPrompt( + name=f"{resolved_prompt.preset_name}-{suffix}", + preset_name=resolved_prompt.preset_name, + parser_mode=resolved_prompt.parser_mode, + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=_build_mt_bench_preset_user_prompt_template( + resolved_prompt=resolved_prompt, + multi_turn=multi_turn, + ref_based=ref_based, + ), + multi_turn=multi_turn, + ref_based=ref_based, + ) + + +def _group_indices_by_prompt(items: list[dict[str, Any]]) -> dict[str, list[int]]: + grouped: dict[str, list[int]] = {} + for idx, item in enumerate(items): + grouped.setdefault(item["prompt_name"], []).append(idx) + return grouped + + +def _swap_prompt_kwargs(kwargs: dict[str, str], *, multi_turn: bool) -> dict[str, str]: + swapped = dict(kwargs) + if multi_turn: + swapped["answer_a_1"], swapped["answer_b_1"] = ( + swapped["answer_b_1"], + swapped["answer_a_1"], + ) + swapped["answer_a_2"], swapped["answer_b_2"] = ( + swapped["answer_b_2"], + swapped["answer_a_2"], + ) + return swapped + swapped["answer_a"], swapped["answer_b"] = swapped["answer_b"], swapped["answer_a"] + return swapped + + +def _build_chat_prompt_template(prompt: MTBenchPresetPrompt) -> ChatPromptTemplate: + message_templates: list[tuple[str, str]] = [] + if prompt.system_prompt is not None: + message_templates.append(("system", prompt.system_prompt)) + message_templates.append(("user", prompt.user_prompt_template)) + return ChatPromptTemplate.from_messages(message_templates) + + +def _answer_field_names(prompt: MTBenchPresetPrompt) -> tuple[str, ...]: + if prompt.multi_turn: + return ("answer_a_1", "answer_a_2", "answer_b_1", "answer_b_2") + return ("answer_a", "answer_b") + + +def _truncation_flag_name(field: str) -> str: + if field == "answer_a": + return "answer_a_1_truncated" + if field == "answer_b": + return "answer_b_1_truncated" + return f"{field}_truncated" + + +def _preflight_prompt_group_to_judge_budget( + *, + prompt_template: ChatPromptTemplate, + prompt_kwargs_batch: list[dict[str, str]], + batch_items: list[dict[str, Any]], + judge_tokenizer: Any, + max_judge_model_len: int, + max_out_tokens_judge: int | None, + limit_event_tracker: LimitEventTracker | None, +) -> list[Any]: + prompt_inputs = prompt_template.batch(prompt_kwargs_batch) + safe_budget = ( + max_judge_model_len - (max_out_tokens_judge or 0) - _PREFLIGHT_RESERVED_TOKENS + ) + + for _ in range(_PREFLIGHT_MAX_ITERATIONS): + overflows = _find_token_overflows(prompt_inputs, judge_tokenizer, safe_budget) + if not overflows: + return prompt_inputs + + for idx, _token_count in overflows: + prompt_kwargs = prompt_kwargs_batch[idx] + item = batch_items[idx] + answer_fields = item["answer_fields"] + if not answer_fields: + continue + + empty_kwargs = dict(prompt_kwargs) + for field in answer_fields: + empty_kwargs[field] = "" + fixed_tokens = _count_chat_tokens( + prompt_template.invoke(empty_kwargs), + judge_tokenizer, + ) + per_answer_budget = 
max( + 256, (safe_budget - fixed_tokens) // len(answer_fields) + ) + + for field in answer_fields: + prompt_kwargs[field], shrunk = truncate_with_metadata( + prompt_kwargs[field], + max_len=max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int( + per_answer_budget + * _chars_per_token(prompt_kwargs[field], judge_tokenizer) + * 0.9 + ), + ), + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field=field, + case_id=item["case_id"], + ) + if shrunk: + item["limit_flags"][_truncation_flag_name(field)] = True + + prompt_inputs = prompt_template.batch(prompt_kwargs_batch) + + final_overflows = _find_token_overflows(prompt_inputs, judge_tokenizer, safe_budget) + for idx, token_count in final_overflows: + if limit_event_tracker is not None: + limit_event_tracker.record( + "judge_input_token_truncation_failed", + stage="judge_input", + case_id=batch_items[idx]["case_id"], + original_length=token_count, + final_length=safe_budget, + note=( + f"{_PREFLIGHT_MAX_ITERATIONS} shrink iterations did not " + f"bring tokens under {safe_budget}; falling through to " + "vLLM validation." + ), + ) + return prompt_inputs + + +def _infer_by_prompt_groups( + *, + judge_chat_model, + items: list[dict[str, Any]], + use_tqdm: bool, + swap_answers: bool, + judge_tokenizer: Any | None = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, +) -> tuple[list[str], list[dict[str, str]]]: + judgments: list[str] = [""] * len(items) + used_prompt_kwargs: list[dict[str, str]] = [{} for _ in items] + for idxs in _group_indices_by_prompt(items).values(): + prompt: MTBenchPresetPrompt = items[idxs[0]]["prompt"] + prompt_template = _build_chat_prompt_template(prompt) + + batch_kwargs: list[dict[str, str]] = [] + batch_items = [items[item_index] for item_index in idxs] + for item_index in idxs: + prompt_kwargs = dict(items[item_index]["prompt_kwargs"]) + if swap_answers: + prompt_kwargs = _swap_prompt_kwargs( + prompt_kwargs, + multi_turn=prompt.multi_turn, + ) + batch_kwargs.append(prompt_kwargs) + + if judge_tokenizer is not None and max_judge_model_len is not None: + prompt_inputs = _preflight_prompt_group_to_judge_budget( + prompt_template=prompt_template, + prompt_kwargs_batch=batch_kwargs, + batch_items=batch_items, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + limit_event_tracker=items[idxs[0]].get("limit_event_tracker"), + ) + else: + prompt_inputs = prompt_template.batch(batch_kwargs) + outputs = do_inference( + chat_model=judge_chat_model, + inputs=prompt_inputs, + use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, + ) + for item_index, output, prompt_kwargs in zip( + idxs, outputs, batch_kwargs, strict=True + ): + judgments[item_index] = str(output) + used_prompt_kwargs[item_index] = prompt_kwargs + return judgments, used_prompt_kwargs + + +def _build_mt_bench_preset_items( + *, + questions: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + eval_single: bool, + eval_multi: bool, + truncate_input_chars: int | None, + prompt_preset: str, + provide_explanation: bool, + strip_thinking_before_judging: bool, + limit_event_tracker: LimitEventTracker | None, +) -> list[dict[str, Any]]: + items: list[dict[str, Any]] = [] + truncated_field_count = 0 + + 
def _record_mt_bench_truncation( + *, + case_id: str, + field: str, + truncated: bool, + ) -> None: + nonlocal truncated_field_count + if truncated and limit_event_tracker is not None: + limit_event_tracker.record( + "judge_input_char_truncation", + stage="judge_input", + field=field, + case_id=case_id, + ) + truncated_field_count += int(truncated) + + def _prepare_answer(answer: str, *, case_id: str, field: str) -> tuple[str, bool]: + if not strip_thinking_before_judging: + return answer, False + stripped_answer, stripped = strip_thinking_tags_with_metadata(answer) + if stripped and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field=field, + case_id=case_id, + original_length=len(answer), + final_length=len(stripped_answer), + ) + return stripped_answer, stripped + + for pair_row in iter_mt_bench_pairwise_rows( + questions=questions, + completions_a=completions_a, + completions_b=completions_b, + truncate_input_chars=truncate_input_chars, + ): + category = pair_row.category + if eval_single: + case_id = f"{pair_row.question_id}:turn1" + prompt = _select_preset_prompt( + category, + multi_turn=False, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + ) + answer_a, answer_a_stripped = _prepare_answer( + pair_row.answer_a_1, + case_id=case_id, + field="answer_a_1", + ) + answer_b, answer_b_stripped = _prepare_answer( + pair_row.answer_b_1, + case_id=case_id, + field="answer_b_1", + ) + _record_mt_bench_truncation( + case_id=case_id, + field="turn_1_question", + truncated=pair_row.turn_1_question_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_a_1", + truncated=pair_row.answer_a_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_b_1", + truncated=pair_row.answer_b_1_truncated, + ) + prompt_kwargs: dict[str, str] = { + "question": pair_row.turn_1_question, + "answer_a": answer_a, + "answer_b": answer_b, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_a_1_reasoning_stripped": answer_a_stripped, + "answer_b_1_reasoning_stripped": answer_b_stripped, + } + if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + prompt_kwargs["ref_answer_1"] = pair_row.ref_1 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + items.append( + { + "case_id": case_id, + "question_id": pair_row.question_id, + "category": category, + "turn": 1, + "prompt": prompt, + "prompt_name": prompt.name, + "prompt_kwargs": prompt_kwargs, + "answer_fields": _answer_field_names(prompt), + "limit_flags": limit_flags, + "limit_event_tracker": limit_event_tracker, + } + ) + + if eval_multi and pair_row.turn_2_question: + case_id = f"{pair_row.question_id}:turn2" + prompt = _select_preset_prompt( + category, + multi_turn=True, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + ) + answer_a_1, answer_a_1_stripped = _prepare_answer( + pair_row.answer_a_1, + case_id=case_id, + field="answer_a_1", + ) + answer_a_2, answer_a_2_stripped = _prepare_answer( + pair_row.answer_a_2, + case_id=case_id, + field="answer_a_2", + ) + answer_b_1, answer_b_1_stripped = _prepare_answer( + pair_row.answer_b_1, + case_id=case_id, + field="answer_b_1", + ) + answer_b_2, answer_b_2_stripped = _prepare_answer( + 
pair_row.answer_b_2, + case_id=case_id, + field="answer_b_2", + ) + for field, truncated in ( + ("turn_1_question", pair_row.turn_1_question_truncated), + ("turn_2_question", pair_row.turn_2_question_truncated), + ("answer_a_1", pair_row.answer_a_1_truncated), + ("answer_a_2", pair_row.answer_a_2_truncated), + ("answer_b_1", pair_row.answer_b_1_truncated), + ("answer_b_2", pair_row.answer_b_2_truncated), + ): + _record_mt_bench_truncation( + case_id=case_id, + field=field, + truncated=truncated, + ) + prompt_kwargs = { + "question_1": pair_row.turn_1_question, + "question_2": pair_row.turn_2_question, + "answer_a_1": answer_a_1, + "answer_a_2": answer_a_2, + "answer_b_1": answer_b_1, + "answer_b_2": answer_b_2, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "turn_2_question_truncated": pair_row.turn_2_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_a_2_truncated": pair_row.answer_a_2_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_b_2_truncated": pair_row.answer_b_2_truncated, + "answer_a_1_reasoning_stripped": answer_a_1_stripped, + "answer_a_2_reasoning_stripped": answer_a_2_stripped, + "answer_b_1_reasoning_stripped": answer_b_1_stripped, + "answer_b_2_reasoning_stripped": answer_b_2_stripped, + } + if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="ref_2", + truncated=pair_row.ref_2_truncated, + ) + prompt_kwargs["ref_answer_1"] = pair_row.ref_1 + prompt_kwargs["ref_answer_2"] = pair_row.ref_2 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + limit_flags["ref_2_truncated"] = pair_row.ref_2_truncated + items.append( + { + "case_id": case_id, + "question_id": pair_row.question_id, + "category": category, + "turn": 2, + "prompt": prompt, + "prompt_name": prompt.name, + "prompt_kwargs": prompt_kwargs, + "answer_fields": _answer_field_names(prompt), + "limit_flags": limit_flags, + "limit_event_tracker": limit_event_tracker, + } + ) + if truncated_field_count: + logger.warning( + "Warning: truncated %s judge inputs to %s characters before evaluation.", + truncated_field_count, + truncate_input_chars, + ) + return items + + +def _normalize_preference(preference: float | None, *, swapped: bool) -> float: + if preference is None: + return math.nan + return 1.0 - preference if swapped else float(preference) + + +def judge_mt_bench_with_preset( + *, + judge_chat_model, + judge_model: str, + questions: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + model_a: str, + model_b: str, + turns_mode: str, + swap_mode: str, + truncate_input_chars: int | None, + use_tqdm: bool, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool = False, + strip_thinking_before_judging: bool = False, + judge_tokenizer: Any | None = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, +) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: + assert turns_mode in ("both", "single", "multi") + assert swap_mode in ("fixed", "both") + + eval_single = turns_mode in ("both", "single") + eval_multi = turns_mode in ("both", "multi") + + items = _build_mt_bench_preset_items( + questions=questions, + 
completions_a=completions_a, + completions_b=completions_b, + eval_single=eval_single, + eval_multi=eval_multi, + truncate_input_chars=truncate_input_chars, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + strip_thinking_before_judging=strip_thinking_before_judging, + limit_event_tracker=limit_event_tracker, + ) + + judgments, prompt_kwargs_used = _infer_by_prompt_groups( + judge_chat_model=judge_chat_model, + items=items, + use_tqdm=use_tqdm, + swap_answers=False, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, + ) + + annotations: list[dict[str, Any]] = [] + metadata: list[dict[str, object]] = [] + preferences: list[float] = [] + + def _append_results( + raw_judgments: list[str], + used_prompt_kwargs: list[dict[str, str]], + *, + swapped: bool, + ) -> None: + for item, raw_judgment, prompt_kwargs in zip( + items, raw_judgments, used_prompt_kwargs, strict=True + ): + prompt: MTBenchPresetPrompt = item["prompt"] + parsed_preference = PairScore( + parser_mode=prompt.parser_mode + ).parse_model_raw(raw_judgment) + normalized_preference = _normalize_preference( + parsed_preference, + swapped=swapped, + ) + annotation_row = { + "question_id": item["question_id"], + "category": item["category"], + "turn": item["turn"], + "model_A": model_b if swapped else model_a, + "model_B": model_a if swapped else model_b, + "judge": judge_model, + "prompt_name": prompt.name, + "prompt_preset": prompt.preset_name, + "parser_mode": prompt.parser_mode, + "system_prompt": prompt.system_prompt, + "user_prompt_template": prompt.user_prompt_template, + "user_prompt": prompt.user_prompt_template.format(**prompt_kwargs), + "judge_completion": raw_judgment, + "preference": normalized_preference, + "swapped": swapped, + } + annotation_row.update(item.get("limit_flags", {})) + annotations.append(annotation_row) + metadata.append( + { + "question_id": item["question_id"], + "category": item["category"], + "turn": item["turn"], + } + ) + preferences.append(normalized_preference) + + _append_results(judgments, prompt_kwargs_used, swapped=False) + + if swap_mode == "both": + swapped_judgments, swapped_prompt_kwargs = _infer_by_prompt_groups( + judge_chat_model=judge_chat_model, + items=items, + use_tqdm=use_tqdm, + swap_answers=True, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, + ) + _append_results(swapped_judgments, swapped_prompt_kwargs, swapped=True) + + return pd.Series(preferences, dtype=float), annotations, metadata, 0 diff --git a/judgearena/mt_bench/prompt_templates.py b/judgearena/mt_bench/prompt_templates.py new file mode 100644 index 0000000..edef887 --- /dev/null +++ b/judgearena/mt_bench/prompt_templates.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from pathlib import Path + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" +_USER_SINGLE_BASE_FILE = "user-single-base.txt" +_USER_MULTI_BASE_FILE = "user-multi-base.txt" +_USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" +_USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" + + +def load_mt_bench_prompt_text(filename: str) -> str: + path = _PROMPTS_DIR / filename + return path.read_text(encoding="utf-8") + + +def 
render_mt_bench_prompt_text(filename: str, **kwargs: str) -> str: + return load_mt_bench_prompt_text(filename).format(**kwargs) + + +def build_mt_bench_user_prompt_template(*, multi_turn: bool, ref_based: bool) -> str: + base_filename = _USER_MULTI_BASE_FILE if multi_turn else _USER_SINGLE_BASE_FILE + reference_block = "" + if ref_based: + ref_block_filename = ( + _USER_MULTI_REF_BLOCK_FILE if multi_turn else _USER_SINGLE_REF_BLOCK_FILE + ) + reference_block = ( + load_mt_bench_prompt_text(ref_block_filename).rstrip("\n") + "\n\n" + ) + return render_mt_bench_prompt_text(base_filename, reference_block=reference_block) diff --git a/judgearena/openrouter_reference_pricing.py b/judgearena/openrouter_reference_pricing.py new file mode 100644 index 0000000..a1bf052 --- /dev/null +++ b/judgearena/openrouter_reference_pricing.py @@ -0,0 +1,624 @@ +"""Reference pricing utilities for local JudgeArena runs. + +This module counts local prompt/completion tokens and, when an exact +OpenRouter model match exists, attaches a comparable public-price estimate. +Refresh the cached catalog on a machine with internet access via +`uv run python -m judgearena.openrouter_reference_pricing --refresh`. +By default the cache lives under +`$JUDGEARENA_DATA/reference_pricing/openrouter_models.json`, unless +`JUDGEARENA_OPENROUTER_PRICE_CACHE` overrides it. +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import urllib.error +import urllib.request +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models" +OPENROUTER_PRICE_CACHE_ENV = "JUDGEARENA_OPENROUTER_PRICE_CACHE" +DEFAULT_CACHE_RELATIVE_PATH = Path("reference_pricing") / "openrouter_models.json" +_KNOWN_PROVIDER_PREFIXES = frozenset( + { + "ChatOpenAI", + "Dummy", + "LlamaCpp", + "OpenAI", + "OpenRouter", + "Together", + "VLLM", + } +) +_UNAPPLIED_PRICE_COMPONENTS = ( + "image", + "input_cache_read", + "input_cache_write", + "internal_reasoning", + "web_search", +) +_LOCAL_VARIANT_SUFFIX_RE = re.compile( + r"(?i)(?:[-_](?:fp8|fp16|bf16|int8|int4|int3|awq|gptq(?:[-_][a-z0-9]+)*))+$" +) + + +def _data_root_path() -> Path: + raw = os.environ.get("JUDGEARENA_DATA") or os.environ.get("OPENJURY_DATA") + if raw: + return Path(raw).expanduser() + return Path("~/judgearena-data/").expanduser() + + +def get_openrouter_price_cache_path() -> Path: + raw = os.environ.get(OPENROUTER_PRICE_CACHE_ENV) + if raw: + return Path(raw).expanduser() + return _data_root_path() / DEFAULT_CACHE_RELATIVE_PATH + + +def _utc_now_iso() -> str: + return datetime.now(UTC).isoformat() + + +def _as_price_float(raw_value: object) -> float: + if raw_value in (None, ""): + return 0.0 + return float(raw_value) + + +@dataclass(frozen=True) +class OpenRouterModelPricing: + prompt: float + completion: float + request: float = 0.0 + image: float = 0.0 + web_search: float = 0.0 + internal_reasoning: float = 0.0 + input_cache_read: float = 0.0 + input_cache_write: float = 0.0 + + +@dataclass(frozen=True) +class OpenRouterModelEntry: + model_id: str + canonical_slug: str | None + hugging_face_id: str | None + name: str + pricing: OpenRouterModelPricing + + def exact_match_candidates(self) -> tuple[str, ...]: + candidates = [self.model_id] + if self.canonical_slug: + candidates.append(self.canonical_slug) + if self.hugging_face_id: + candidates.append(self.hugging_face_id) + return tuple(candidates) + + 
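+# Illustrative sketch only: how the per-token prices above are expected to
+# combine with locally counted usage into a reference estimate. The names and
+# numbers below are hypothetical placeholders, not part of this module's API;
+# the actual aggregation is done by the summary builders further down.
+#
+#     pricing = OpenRouterModelPricing(prompt=1e-6, completion=4e-6)
+#     prompt_tokens, completion_tokens = 1_200, 350
+#     estimated_usd = (
+#         prompt_tokens * pricing.prompt
+#         + completion_tokens * pricing.completion
+#     )
+#
+# Other price components (e.g. image or web_search) are deliberately omitted
+# from this sketch; see `_UNAPPLIED_PRICE_COMPONENTS` for the components the
+# reference estimate does not apply.
+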
+@dataclass(frozen=True) +class OpenRouterPriceCatalog: + source_url: str + fetched_at_utc: str | None + cache_path: str + models: tuple[OpenRouterModelEntry, ...] + + +@dataclass(frozen=True) +class TokenUsageRecord: + phase: str + model_spec: str + prompt_tokens: int + completion_tokens: int + requests: int = 1 + + +class OpenRouterReferencePricingTracker: + def __init__(self) -> None: + self._records: list[TokenUsageRecord] = [] + + @property + def records(self) -> list[TokenUsageRecord]: + return list(self._records) + + def has_records(self) -> bool: + return bool(self._records) + + def record_batch_from_model( + self, + *, + phase: str, + model_spec: str, + chat_model: object, + inputs: list, + outputs: list[str], + ) -> bool: + if not hasattr(chat_model, "count_prompt_tokens_batch") or not hasattr( + chat_model, "count_completion_tokens_batch" + ): + return False + + prompt_tokens = chat_model.count_prompt_tokens_batch(inputs) + completion_tokens = chat_model.count_completion_tokens_batch(outputs) + if len(prompt_tokens) != len(completion_tokens) or len(prompt_tokens) != len( + outputs + ): + raise ValueError("Prompt/completion token counts must align with outputs.") + + for prompt_count, completion_count in zip( + prompt_tokens, completion_tokens, strict=True + ): + self._records.append( + TokenUsageRecord( + phase=phase, + model_spec=model_spec, + prompt_tokens=int(prompt_count), + completion_tokens=int(completion_count), + ) + ) + return True + + def record_batch_from_usage_metadata( + self, + *, + phase: str, + model_spec: str, + usages: list[dict[str, Any] | None], + ) -> bool: + """Record API-reported token usage extracted from LangChain AIMessages. + + Each entry in ``usages`` is either ``None`` (no usage data) or a dict + carrying ``input_tokens``/``output_tokens`` (langchain-core shape) or + ``prompt_tokens``/``completion_tokens`` (OpenAI-shape). This is the + path used for OpenRouter / ChatOpenAI-backed models, which do not + expose ``count_*_batch`` helpers but do return per-call usage from + the upstream API. + + Returns ``True`` if at least one entry produced a record, else + ``False`` so callers can fall back to ``record_batch_from_model``. 
+ """ + recorded = False + for usage in usages: + if not isinstance(usage, dict): + continue + prompt = usage.get("input_tokens") + if prompt is None: + prompt = usage.get("prompt_tokens") + completion = usage.get("output_tokens") + if completion is None: + completion = usage.get("completion_tokens") + if prompt is None and completion is None: + continue + self._records.append( + TokenUsageRecord( + phase=phase, + model_spec=model_spec, + prompt_tokens=int(prompt or 0), + completion_tokens=int(completion or 0), + ) + ) + recorded = True + return recorded + + +def _parse_catalog_model(raw_model: dict[str, Any]) -> OpenRouterModelEntry: + raw_pricing = raw_model.get("pricing") or {} + pricing = OpenRouterModelPricing( + prompt=_as_price_float(raw_pricing.get("prompt")), + completion=_as_price_float(raw_pricing.get("completion")), + request=_as_price_float(raw_pricing.get("request")), + image=_as_price_float(raw_pricing.get("image")), + web_search=_as_price_float(raw_pricing.get("web_search")), + internal_reasoning=_as_price_float(raw_pricing.get("internal_reasoning")), + input_cache_read=_as_price_float(raw_pricing.get("input_cache_read")), + input_cache_write=_as_price_float(raw_pricing.get("input_cache_write")), + ) + return OpenRouterModelEntry( + model_id=str(raw_model["id"]), + canonical_slug=( + str(raw_model["canonical_slug"]) + if raw_model.get("canonical_slug") is not None + else None + ), + hugging_face_id=( + str(raw_model["hugging_face_id"]) + if raw_model.get("hugging_face_id") is not None + else None + ), + name=str(raw_model.get("name") or raw_model["id"]), + pricing=pricing, + ) + + +def parse_openrouter_catalog_payload( + payload: dict[str, Any], + *, + fetched_at_utc: str | None = None, + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + raw_models = payload.get("models") + if raw_models is None: + raw_models = payload.get("data") + if not isinstance(raw_models, list): + raise ValueError("OpenRouter models payload is missing a `data` list.") + return OpenRouterPriceCatalog( + source_url=str(payload.get("source_url") or OPENROUTER_MODELS_URL), + fetched_at_utc=( + str(payload["fetched_at_utc"]) + if payload.get("fetched_at_utc") is not None + else fetched_at_utc + ), + cache_path=str(cache_path or payload.get("cache_path") or ""), + models=tuple(_parse_catalog_model(model) for model in raw_models), + ) + + +def _cache_payload_from_raw_response(raw_payload: dict[str, Any]) -> dict[str, Any]: + return { + "source_url": OPENROUTER_MODELS_URL, + "fetched_at_utc": _utc_now_iso(), + "models": raw_payload.get("data", []), + } + + +def _fetch_openrouter_catalog_payload(timeout_seconds: float = 30.0) -> dict[str, Any]: + request = urllib.request.Request(OPENROUTER_MODELS_URL) + api_key = os.environ.get("OPENROUTER_API_KEY") + if api_key: + request.add_header("Authorization", f"Bearer {api_key}") + with urllib.request.urlopen(request, timeout=timeout_seconds) as response: + return json.loads(response.read().decode("utf-8")) + + +def load_openrouter_price_catalog( + *, + refresh: bool = False, + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + resolved_cache_path = ( + Path(cache_path) + if cache_path is not None + else get_openrouter_price_cache_path() + ) + resolved_cache_path.parent.mkdir(parents=True, exist_ok=True) + if refresh or not resolved_cache_path.is_file(): + fetched_payload = _fetch_openrouter_catalog_payload() + cache_payload = _cache_payload_from_raw_response(fetched_payload) + with open(resolved_cache_path, "w", 
encoding="utf-8") as handle: + json.dump(cache_payload, handle, indent=2, sort_keys=True) + with open(resolved_cache_path, encoding="utf-8") as handle: + cached_payload = json.load(handle) + return parse_openrouter_catalog_payload( + cached_payload, + cache_path=resolved_cache_path, + ) + + +def load_openrouter_price_catalog_with_fallback( + *, + refresh: bool = False, + cache_path: str | Path | None = None, +) -> tuple[OpenRouterPriceCatalog | None, str | None]: + resolved_cache_path = ( + Path(cache_path) + if cache_path is not None + else get_openrouter_price_cache_path() + ) + try: + catalog = load_openrouter_price_catalog( + refresh=refresh, + cache_path=resolved_cache_path, + ) + except (OSError, ValueError, json.JSONDecodeError, urllib.error.URLError) as exc: + if resolved_cache_path.is_file(): + try: + catalog = load_openrouter_price_catalog( + refresh=False, + cache_path=resolved_cache_path, + ) + return ( + catalog, + f"Using cached OpenRouter price catalog after refresh failed: {exc}", + ) + except (OSError, ValueError, json.JSONDecodeError) as cached_exc: + return None, ( + "OpenRouter price catalog refresh and cache load failed: " + f"{cached_exc}" + ) + return None, f"OpenRouter price catalog unavailable: {exc}" + return catalog, None + + +def _strip_provider_prefix(model_spec: str) -> str | None: + if not model_spec: + return None + if "/" not in model_spec: + return model_spec + provider, remainder = model_spec.split("/", 1) + if provider in _KNOWN_PROVIDER_PREFIXES: + return remainder + return model_spec + + +def _candidate_match_variants(candidate: str) -> tuple[str, ...]: + variants = [candidate] + owner_prefix = "" + model_name = candidate + if "/" in candidate: + owner, model_name = candidate.rsplit("/", 1) + owner_prefix = f"{owner}/" + normalized_model_name = _LOCAL_VARIANT_SUFFIX_RE.sub("", model_name) + if normalized_model_name and normalized_model_name != model_name: + variants.append(f"{owner_prefix}{normalized_model_name}") + return tuple(dict.fromkeys(variants)) + + +def find_openrouter_match( + catalog: OpenRouterPriceCatalog, + model_spec: str, +) -> tuple[OpenRouterModelEntry | None, str | None]: + candidate = _strip_provider_prefix(model_spec) + if not candidate: + return None, None + candidate_variants = _candidate_match_variants(candidate) + lowered_exact = candidate_variants[0].casefold() + lowered_normalized = {variant.casefold() for variant in candidate_variants[1:]} + for model in catalog.models: + lowered_candidates = { + match_candidate.casefold() + for match_candidate in model.exact_match_candidates() + } + if lowered_exact in lowered_candidates: + return model, "exact_case_insensitive" + if lowered_normalized.intersection(lowered_candidates): + return model, "local_variant_suffix_stripped" + return None, None + + +def find_exact_openrouter_match( + catalog: OpenRouterPriceCatalog, + model_spec: str, +) -> OpenRouterModelEntry | None: + matched_model, match_strategy = find_openrouter_match(catalog, model_spec) + if match_strategy == "exact_case_insensitive": + return matched_model + return None + + +def _sum_phase_records(records: list[TokenUsageRecord]) -> dict[str, int]: + prompt_tokens = sum(record.prompt_tokens for record in records) + completion_tokens = sum(record.completion_tokens for record in records) + requests = sum(record.requests for record in records) + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + "request_count": requests, + } + + +def 
_ignored_pricing_components(pricing: OpenRouterModelPricing) -> list[str]: + ignored: list[str] = [] + for field_name in _UNAPPLIED_PRICE_COMPONENTS: + if getattr(pricing, field_name) != 0.0: + ignored.append(field_name) + return ignored + + +def _phase_summary_for_unmatched( + *, + model_spec: str, + usage_totals: dict[str, int], + pricing_status: str, +) -> dict[str, Any]: + return { + "model_spec": model_spec, + "pricing_status": pricing_status, + **usage_totals, + "openrouter_model_id": None, + "openrouter_canonical_slug": None, + "openrouter_hugging_face_id": None, + "openrouter_reference_cost_usd": None, + "applied_pricing_usd": None, + "ignored_pricing_components": [], + } + + +def build_openrouter_reference_pricing_summary( + *, + tracker: OpenRouterReferencePricingTracker, + phase_model_specs: dict[str, str], + refresh_catalog: bool = False, + cache_path: str | Path | None = None, +) -> dict[str, Any]: + phase_records: dict[str, list[TokenUsageRecord]] = { + phase: [record for record in tracker.records if record.phase == phase] + for phase in phase_model_specs + } + should_load_catalog = any(phase_records.values()) + catalog: OpenRouterPriceCatalog | None = None + catalog_warning: str | None = None + if should_load_catalog: + catalog, catalog_warning = load_openrouter_price_catalog_with_fallback( + refresh=refresh_catalog, + cache_path=cache_path, + ) + + phase_summaries: dict[str, dict[str, Any]] = {} + priced_costs: list[float] = [] + for phase, model_spec in phase_model_specs.items(): + records = phase_records[phase] + usage_totals = _sum_phase_records(records) + if not records: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="no_runtime_token_data", + ) + continue + if catalog is None: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="price_catalog_unavailable", + ) + continue + + matched_model, match_strategy = find_openrouter_match(catalog, model_spec) + if matched_model is None: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="no_exact_openrouter_match", + ) + continue + + ignored_components = _ignored_pricing_components(matched_model.pricing) + phase_cost = ( + usage_totals["prompt_tokens"] * matched_model.pricing.prompt + + usage_totals["completion_tokens"] * matched_model.pricing.completion + + usage_totals["request_count"] * matched_model.pricing.request + ) + priced_costs.append(phase_cost) + if match_strategy == "local_variant_suffix_stripped": + base_status = "matched_openrouter_model_after_variant_normalization" + else: + base_status = "matched_exact_openrouter_model" + phase_summaries[phase] = { + "model_spec": model_spec, + "pricing_status": ( + base_status if not ignored_components else f"{base_status}_partial" + ), + **usage_totals, + "openrouter_model_id": matched_model.model_id, + "openrouter_canonical_slug": matched_model.canonical_slug, + "openrouter_hugging_face_id": matched_model.hugging_face_id, + "openrouter_reference_cost_usd": phase_cost, + "applied_pricing_usd": { + "prompt": matched_model.pricing.prompt, + "completion": matched_model.pricing.completion, + "request": matched_model.pricing.request, + }, + "ignored_pricing_components": ignored_components, + } + + total_prompt_tokens = sum( + phase_summary["prompt_tokens"] for phase_summary in phase_summaries.values() + ) + total_completion_tokens = sum( + 
phase_summary["completion_tokens"] for phase_summary in phase_summaries.values() + ) + total_request_count = sum( + phase_summary["request_count"] for phase_summary in phase_summaries.values() + ) + total_reference_cost = sum(priced_costs) if priced_costs else None + + return { + "pricing_model": "openrouter_reference", + "pricing_currency": "USD", + "catalog_source_url": OPENROUTER_MODELS_URL, + "catalog_cache_path": str( + cache_path if cache_path is not None else get_openrouter_price_cache_path() + ), + "catalog_fetched_at_utc": catalog.fetched_at_utc if catalog else None, + "catalog_warning": catalog_warning, + "exact_match_policy": { + "strategy": "exact_case_insensitive", + "match_fields": ["id", "canonical_slug", "hugging_face_id"], + "fallback_normalizations": ["strip_common_local_quantization_suffixes"], + }, + "phases": phase_summaries, + "total": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + "request_count": total_request_count, + "openrouter_reference_cost_usd": total_reference_cost, + }, + } + + +def format_openrouter_reference_pricing_summary(summary: dict[str, Any]) -> str: + lines = ["OpenRouter reference pricing:"] + for phase, phase_summary in summary["phases"].items(): + phase_cost = phase_summary.get("openrouter_reference_cost_usd") + cost_str = f" | usd={phase_cost:.6f}" if phase_cost is not None else "" + lines.append( + " " + + f"{phase}: status={phase_summary['pricing_status']}" + + f" | prompt={phase_summary['prompt_tokens']}" + + f" | completion={phase_summary['completion_tokens']}" + + f" | total={phase_summary['total_tokens']}" + + cost_str + ) + total = summary["total"] + total_cost = total.get("openrouter_reference_cost_usd") + total_cost_str = f" | usd={total_cost:.6f}" if total_cost is not None else "" + lines.append( + " total:" + + f" prompt={total['prompt_tokens']}" + + f" | completion={total['completion_tokens']}" + + f" | total={total['total_tokens']}" + + total_cost_str + ) + warning = summary.get("catalog_warning") + if warning: + lines.append(f" warning: {warning}") + return "\n".join(lines) + + +def refresh_openrouter_price_catalog( + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + return load_openrouter_price_catalog(refresh=True, cache_path=cache_path) + + +def _build_cli_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m judgearena.openrouter_reference_pricing", + description="Refresh or inspect the cached OpenRouter model pricing catalog.", + ) + parser.add_argument( + "--refresh", + action="store_true", + help="Force-refresh the cached OpenRouter models catalog.", + ) + parser.add_argument( + "--model", + default=None, + help="Optional local model spec to resolve against the cached catalog.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + args = _build_cli_parser().parse_args(argv) + catalog = load_openrouter_price_catalog(refresh=args.refresh) + print( + json.dumps( + { + "catalog_source_url": catalog.source_url, + "catalog_fetched_at_utc": catalog.fetched_at_utc, + "catalog_cache_path": catalog.cache_path, + "model_count": len(catalog.models), + "matched_model": ( + asdict(find_exact_openrouter_match(catalog, args.model)) + if args.model + and find_exact_openrouter_match(catalog, args.model) is not None + else None + ), + }, + indent=2, + sort_keys=True, + ) + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git 
a/judgearena/prompts/prompt-with-explanation.txt b/judgearena/prompts/prompt-with-explanation.txt index 6600f51..3d9eb41 100644 --- a/judgearena/prompts/prompt-with-explanation.txt +++ b/judgearena/prompts/prompt-with-explanation.txt @@ -1,13 +1,13 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's Answer|> +<|The Start of Assistant A's {completion_label}|> {completion_A} -<|The End of Assistant A's Answer|> +<|The End of Assistant A's {completion_label}|> -<|The Start of Assistant B's Answer|> +<|The Start of Assistant B's {completion_label}|> {completion_B} -<|The End of Assistant B's Answer|> +<|The End of Assistant B's {completion_label}|> # Your output diff --git a/judgearena/prompts/skywork-prompt-with-explanation.txt b/judgearena/prompts/skywork-prompt-with-explanation.txt new file mode 100644 index 0000000..e1f9250 --- /dev/null +++ b/judgearena/prompts/skywork-prompt-with-explanation.txt @@ -0,0 +1,14 @@ +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. +Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +Please briefly explain your reasoning first, then directly output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better. + +[User Question] +{user_prompt} + +[The Start of Assistant A's Answer] +{completion_A} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{completion_B} +[The End of Assistant B's Answer] diff --git a/judgearena/prompts/skywork-prompt.txt b/judgearena/prompts/skywork-prompt.txt new file mode 100644 index 0000000..97ad3c8 --- /dev/null +++ b/judgearena/prompts/skywork-prompt.txt @@ -0,0 +1,14 @@ +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. +Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +Please directly output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better. 
+ +[User Question] +{user_prompt} + +[The Start of Assistant A's Answer] +{completion_A} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{completion_B} +[The End of Assistant B's Answer] diff --git a/judgearena/repro.py b/judgearena/repro.py index 9468c14..72b0059 100644 --- a/judgearena/repro.py +++ b/judgearena/repro.py @@ -231,6 +231,7 @@ def write_run_metadata( judge_system_prompt: str | None = None, judge_user_prompt_template: str | None = None, started_at_utc: datetime | None = None, + pricing_reference: dict[str, Any] | None = None, metadata_filename: str = METADATA_FILENAME, ) -> Path: """Write run metadata JSON and return the output path.""" @@ -282,6 +283,8 @@ def write_run_metadata( judge_user_prompt_template_hash = _hash_string_sha256(judge_user_prompt_template) if judge_user_prompt_template_hash: metadata["judge_user_prompt_template_sha256"] = judge_user_prompt_template_hash + if pricing_reference is not None: + metadata["pricing_reference"] = _to_jsonable(pricing_reference) metadata["artifacts"] = _collect_artifacts( output_path, metadata_filename=metadata_filename diff --git a/judgearena/utils.py b/judgearena/utils.py index 993ef01..09253db 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -1,9 +1,13 @@ import asyncio import os +import re import time import warnings +from collections import Counter from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path +from typing import Any import pandas as pd from huggingface_hub import snapshot_download @@ -14,11 +18,16 @@ from tqdm.asyncio import tqdm from tqdm.contrib.logging import logging_redirect_tqdm +from judgearena.chat_models import ( + OpenRouterGeminiSafetyTolerantChatOpenAI, + is_openrouter_gemini_model, +) from judgearena.instruction_dataset.arena_hard import ( download_arena_hard, is_arena_hard_dataset, ) from judgearena.log import get_logger +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker logger = get_logger(__name__) @@ -32,6 +41,153 @@ def _data_root_path() -> Path: data_root = _data_root_path() +DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET = 512 +VLLM_REASONING_START_STR = "<think>" +VLLM_REASONING_END_STR = ( + "I have to give the solution based on the thinking directly now." +) +_THINKING_MODEL_SUBSTRINGS = ("qwen3", "smollm3") + + +def _split_model_spec(model_spec: str) -> tuple[str, str]: + provider, sep, model_name = model_spec.partition("/") + if not sep: + return model_spec, "" + return provider, model_name + + +def is_thinking_model(model_name: str) -> bool: + """Return True for reasoning models that emit `<think>...</think>` traces. + + Covers the Qwen3 family (e.g. `Qwen/Qwen3.5-9B`) and SmolLM3 (e.g. + `HuggingFaceTB/SmolLM3-3B`); both share the same ``<think>/</think>`` tag + convention so vLLM's budget enforcement and our tag-stripping apply + uniformly. Matching is case-insensitive to tolerate mixed-case HF repo + ids like `HuggingFaceTB/SmolLM3-3B`. + """ + lowered = model_name.lower() + return any(token in lowered for token in _THINKING_MODEL_SUBSTRINGS) + + +def build_default_judge_model_kwargs( + judge_model: str, + engine_kwargs: dict[str, object], + *, + judge_engine_kwargs_override: dict[str, object] | None = None, +) -> dict[str, object]: + """Copy judge engine kwargs and add supported built-in defaults. + + ``judge_engine_kwargs_override`` is layered on top of ``engine_kwargs`` + so callers can pin judge-only tweaks (e.g.
a higher tensor-parallel size + for a 70B judge) without poisoning the battle-model engine config, which + must often stay on TP=1 to dodge compile-time deadlocks on hybrid models + such as Qwen3.5. + """ + judge_model_kwargs = dict(engine_kwargs) + if judge_engine_kwargs_override: + judge_model_kwargs.update(judge_engine_kwargs_override) + provider, model_name = _split_model_spec(judge_model) + if provider == "VLLM": + if "thinking_token_budget" not in judge_model_kwargs and is_thinking_model( + model_name + ): + judge_model_kwargs["thinking_token_budget"] = ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) + # FP8 weights leave little KV headroom on consumer-class GPUs; default + # to FP8 KV cache so judges like Skywork-70B-FP8 fit comfortably on + # 2x L40S at 32k context. Explicit caller overrides still win. + if "kv_cache_dtype" not in judge_model_kwargs and "fp8" in model_name.lower(): + judge_model_kwargs["kv_cache_dtype"] = "fp8" + return judge_model_kwargs + + +def _resolve_chat_template_kwargs( + *, + explicit_chat_template_kwargs: dict[str, object] | None, + disable_thinking: bool, +) -> dict[str, object] | None: + chat_template_kwargs = dict(explicit_chat_template_kwargs or {}) + if disable_thinking and "enable_thinking" not in chat_template_kwargs: + chat_template_kwargs["enable_thinking"] = False + + return chat_template_kwargs or None + + +@dataclass(frozen=True) +class LimitEvent: + kind: str + stage: str + field: str | None = None + case_id: str | None = None + model_spec: str | None = None + original_length: int | None = None + final_length: int | None = None + note: str | None = None + + +class LimitEventTracker: + def __init__(self) -> None: + self.events: list[LimitEvent] = [] + + def record( + self, + kind: str, + *, + stage: str, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, + original_length: int | None = None, + final_length: int | None = None, + note: str | None = None, + ) -> None: + self.events.append( + LimitEvent( + kind=kind, + stage=stage, + field=field, + case_id=None if case_id is None else str(case_id), + model_spec=model_spec, + original_length=original_length, + final_length=final_length, + note=note, + ) + ) + + def build_summary(self) -> dict[str, Any]: + counts_by_kind: Counter[str] = Counter() + counts_by_stage: Counter[str] = Counter() + counts_by_kind_and_field: dict[str, Counter[str]] = {} + affected_cases_total: set[str] = set() + affected_cases_by_kind: dict[str, set[str]] = {} + + for event in self.events: + counts_by_kind[event.kind] += 1 + counts_by_stage[event.stage] += 1 + field_key = event.field or "_all" + counts_by_kind_and_field.setdefault(event.kind, Counter())[field_key] += 1 + if event.case_id is None: + continue + case_key = f"{event.stage}:{event.case_id}" + affected_cases_total.add(case_key) + affected_cases_by_kind.setdefault(event.kind, set()).add(case_key) + + return { + "total_events": len(self.events), + "counts_by_kind": dict(sorted(counts_by_kind.items())), + "counts_by_stage": dict(sorted(counts_by_stage.items())), + "counts_by_kind_and_field": { + kind: dict(sorted(counter.items())) + for kind, counter in sorted(counts_by_kind_and_field.items()) + }, + "affected_cases_total": len(affected_cases_total), + "affected_cases_by_kind": { + kind: len(case_ids) + for kind, case_ids in sorted(affected_cases_by_kind.items()) + }, + } + def set_langchain_cache(): set_llm_cache(SQLiteCache(database_path=str(data_root / ".langchain.db"))) @@ -41,7 +197,7 @@ def download_hf(name: str, 
local_path: Path): local_path.mkdir(exist_ok=True, parents=True) # downloads the model from huggingface into `local_path` folder snapshot_download( - repo_id="geoalgo/llmjudge", + repo_id="judge-arena/judge-arena-dataset", repo_type="dataset", allow_patterns=f"*{name}*", local_dir=local_path, @@ -113,6 +269,33 @@ def truncate(s: str, max_len: int | None = None) -> str: return s +def truncate_with_metadata( + s: str | None, + max_len: int | None = None, + *, + tracker: LimitEventTracker | None = None, + kind: str | None = None, + stage: str | None = None, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, +) -> tuple[str, bool]: + original = s if isinstance(s, str) else "" + truncated = truncate(original, max_len=max_len) + was_truncated = truncated != original + if was_truncated and tracker is not None and kind is not None and stage is not None: + tracker.record( + kind, + stage=stage, + field=field, + case_id=case_id, + model_spec=model_spec, + original_length=len(original), + final_length=len(truncated), + ) + return truncated, was_truncated + + def safe_text(value: object, truncate_chars: int | None) -> str: """Coerce *value* to a string and optionally truncate. @@ -127,17 +310,144 @@ def safe_text(value: object, truncate_chars: int | None) -> str: return truncate(str(value), max_len=truncate_chars) -def do_inference(chat_model, inputs, use_tqdm: bool = False): +def safe_text_with_metadata( + value: object, + truncate_chars: int | None, + *, + tracker: LimitEventTracker | None = None, + kind: str | None = None, + stage: str | None = None, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, +) -> tuple[str, bool]: + if value is None: + return "", False + is_missing = pd.isna(value) + if isinstance(is_missing, bool) and is_missing: + return "", False + return truncate_with_metadata( + str(value), + max_len=truncate_chars, + tracker=tracker, + kind=kind, + stage=stage, + field=field, + case_id=case_id, + model_spec=model_spec, + ) + + +_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL) + + +def strip_thinking_tags(text: str | None) -> str: + """Remove full `<think>...</think>` blocks from raw model output.""" + return strip_thinking_tags_with_metadata(text)[0] + + +def strip_thinking_tags_with_metadata(text: str | None) -> tuple[str, bool]: + """Remove visible reasoning spans from raw model output.""" + if not isinstance(text, str): + return "", False + + cleaned = _THINK_BLOCK_RE.sub("", text) + if cleaned != text: + return cleaned.lstrip(), True + + lowered = text.lower() + closing_tag = "</think>" + closing_idx = lowered.find(closing_tag) + if closing_idx != -1 and "<think>" not in lowered[:closing_idx]: + return text[closing_idx + len(closing_tag) :].lstrip(), True + + forced_end_idx = text.find(VLLM_REASONING_END_STR) + if forced_end_idx != -1: + return ( + text[forced_end_idx + len(VLLM_REASONING_END_STR) :].lstrip(), + True, + ) + + return text, False + + +def _extract_ai_message_metadata(result: object) -> dict[str, Any]: + """Extract finish_reason/stop_reason from a LangChain AIMessage result. + + LangChain chat models (ChatOpenAI for OpenRouter, Anthropic, etc.) return + AIMessage objects with a ``response_metadata`` dict. We propagate the + subset that downstream code consumes (finish_reason is critical: it gates + truncation detection in _record_generation_output_limit_events).
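+
+    For example, a normal completion typically yields
+    ``{"finish_reason": "stop", "stop_reason": None}``.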
+ """ + response_metadata = getattr(result, "response_metadata", None) or {} + finish_reason = response_metadata.get("finish_reason") + stop_reason = response_metadata.get("stop_reason") + if finish_reason is None and isinstance(result, dict): + finish_reason = result.get("finish_reason") + stop_reason = result.get("stop_reason", stop_reason) + return {"finish_reason": finish_reason, "stop_reason": stop_reason} + + +def _extract_token_usage(result: object) -> dict[str, int] | None: + """Pull API-reported token usage from a LangChain AIMessage-like result. + + Two shapes coexist depending on langchain-openai version: + - langchain-core AIMessage.usage_metadata: ``{"input_tokens", "output_tokens", "total_tokens"}`` + - response_metadata.token_usage (OpenAI-shape): ``{"prompt_tokens", "completion_tokens", "total_tokens"}`` + + Returns the first shape that carries non-null counts, or ``None`` if neither + is present (e.g. provider that does not surface usage). Used by + ``OpenRouterReferencePricingTracker.record_batch_from_usage_metadata`` to + capture per-call billing tokens for OpenRouter / ChatOpenAI runs, which + cannot be tokenised via ``count_*_batch`` helpers. + """ + usage_metadata = getattr(result, "usage_metadata", None) + if isinstance(usage_metadata, dict) and ( + usage_metadata.get("input_tokens") is not None + or usage_metadata.get("output_tokens") is not None + ): + return dict(usage_metadata) + response_metadata = getattr(result, "response_metadata", None) or {} + token_usage = ( + response_metadata.get("token_usage") + if isinstance(response_metadata, dict) + else None + ) + if isinstance(token_usage, dict) and ( + token_usage.get("prompt_tokens") is not None + or token_usage.get("completion_tokens") is not None + ): + return dict(token_usage) + return None + + +def do_inference( + chat_model, + inputs, + use_tqdm: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, + return_metadata: bool = False, +): # Retries on rate-limit/server errors with exponential backoff. # Async path retries individual calls; batch path splits into 4^attempt chunks on failure. invoke_kwargs = { # "stop": ["```"], # "max_tokens": 100, } + metadata: list[dict[str, Any]] | None = None if use_tqdm: # perform inference asynchronously to be able to update tqdm, chat_model.batch does not work as it blocks until # all requests are received + # JUDGEARENA_JUDGE_MAX_CONCURRENCY caps simultaneous in-flight ainvokes + # (e.g. against OpenRouter). Unset = unbounded, preserving prior behaviour. 
+ cap_raw = os.environ.get("JUDGEARENA_JUDGE_MAX_CONCURRENCY") + cap = int(cap_raw) if cap_raw and int(cap_raw) > 0 else None + async def process_with_real_progress(chat_model, inputs, pbar): + sem = asyncio.Semaphore(cap) if cap else None + async def process_single(input_item, max_retries=5, base_delay=1.0): for attempt in range(max_retries): try: @@ -157,8 +467,14 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): ) await asyncio.sleep(delay) + async def gated(inp): + if sem is None: + return await process_single(inp) + async with sem: + return await process_single(inp) + # asyncio.gather preserves order (unlike as_completed) - results = await asyncio.gather(*[process_single(inp) for inp in inputs]) + results = await asyncio.gather(*[gated(inp) for inp in inputs]) return results with logging_redirect_tqdm(), tqdm(total=len(inputs)) as pbar: @@ -167,6 +483,9 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): chat_model=chat_model, inputs=inputs, pbar=pbar ) ) + # Always materialize metadata; it is cheap and keeps return_metadata + # behavior consistent with the batch path. + metadata = [_extract_ai_message_metadata(r) for r in res] else: def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): @@ -179,9 +498,26 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): ] try: results = [] + results_metadata = [] for chunk in chunks: - results.extend(chat_model.batch(inputs=chunk, **invoke_kwargs)) - return results + if return_metadata and hasattr( + chat_model, "batch_with_metadata" + ): + chunk_results, chunk_metadata = ( + chat_model.batch_with_metadata( + inputs=chunk, **invoke_kwargs + ) + ) + else: + chunk_results = chat_model.batch( + inputs=chunk, **invoke_kwargs + ) + chunk_metadata = [ + _extract_ai_message_metadata(r) for r in chunk_results + ] + results.extend(chunk_results) + results_metadata.extend(chunk_metadata) + return results, results_metadata except Exception as e: if attempt == max_retries - 1 or not _is_retryable_error(e): raise @@ -197,12 +533,43 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): ) time.sleep(delay) - res = batch_with_retry(inputs) + res, metadata = batch_with_retry(inputs) + + # Pull per-call usage from AIMessage objects BEFORE flattening to .content; + # OpenRouter / ChatOpenAI surface API-billed token counts here but lose + # them after .content extraction. + per_call_usages = [_extract_token_usage(r) for r in res] # Not sure why the API of Langchain returns sometime a string and sometimes an AIMessage object # is it because of using Chat and barebones models? # when using OpenAI, the output is AIMessage not a string... 
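+    # In practice: ChatVLLM.batch() already returns plain strings, while LangChain
+    # chat models (ChatOpenAI for OpenRouter, OpenAI, Together) return AIMessage
+    # objects, hence the .content normalisation below.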
res = [x.content if hasattr(x, "content") else x for x in res] + if ( + usage_tracker is not None + and usage_phase is not None + and usage_model_spec is not None + ): + try: + recorded = usage_tracker.record_batch_from_usage_metadata( + phase=usage_phase, + model_spec=usage_model_spec, + usages=per_call_usages, + ) + if not recorded: + usage_tracker.record_batch_from_model( + phase=usage_phase, + model_spec=usage_model_spec, + chat_model=chat_model, + inputs=list(inputs), + outputs=res, + ) + except Exception as e: + print( + f"Warning: failed to record token usage for phase " + f"'{usage_phase}' ({usage_model_spec}): {e}" + ) + if return_metadata: + return res, (metadata or [{} for _ in res]) return res @@ -221,6 +588,52 @@ async def ainvoke(self, input, **invoke_kwargs): return self.message +_VLLM_INIT_RETRY_SIGNATURES = ( + "cudaErrorDevicesUnavailable", + "CUDA-capable device(s) is/are busy or unavailable", + "CUDA error: initialization error", +) +_VLLM_INIT_MAX_ATTEMPTS = int(os.getenv("JUDGEARENA_VLLM_INIT_MAX_ATTEMPTS", "4")) +_VLLM_INIT_BACKOFF_SECONDS = int( + os.getenv("JUDGEARENA_VLLM_INIT_BACKOFF_SECONDS", "20") +) + + +def _init_llm_with_retry(llm_cls, **kwargs): + """Instantiate ``vllm.LLM`` with retries on transient GPU-init races. + + On shared Slurm nodes with MaxGRESPerAccount throttling, freshly scheduled + jobs can hit ``cudaErrorDevicesUnavailable`` because the previous tenant's + driver cleanup has not finished when our process starts. This manifests as + an immediate engine-core init failure and is almost always resolved by a + 15-30 s sleep + retry on the same GPU. We retry up to + ``JUDGEARENA_VLLM_INIT_MAX_ATTEMPTS`` times with exponential backoff before + giving up, which keeps persistent configuration errors from looping. + """ + last_exc: Exception | None = None + for attempt in range(1, _VLLM_INIT_MAX_ATTEMPTS + 1): + try: + return llm_cls(**kwargs) + except Exception as exc: + message = f"{type(exc).__name__}: {exc}" + if not any(sig in message for sig in _VLLM_INIT_RETRY_SIGNATURES): + raise + last_exc = exc + if attempt == _VLLM_INIT_MAX_ATTEMPTS: + break + delay = _VLLM_INIT_BACKOFF_SECONDS * (2 ** (attempt - 1)) + warnings.warn( + f"vLLM init attempt {attempt}/{_VLLM_INIT_MAX_ATTEMPTS} failed " + f"with transient GPU-init signature ({message.splitlines()[0]}); " + f"sleeping {delay}s before retry.", + RuntimeWarning, + stacklevel=2, + ) + time.sleep(delay) + assert last_exc is not None + raise last_exc + + class ChatVLLM: """VLLM wrapper that auto-detects whether to use chat() or generate(). 
@@ -244,9 +657,27 @@ def __init__( **vllm_kwargs, ): from vllm import LLM, SamplingParams + from vllm.config.reasoning import ReasoningConfig self.model_path = model self.max_tokens = max_tokens + limit_event_tracker: LimitEventTracker | None = vllm_kwargs.pop( + "limit_event_tracker", None + ) + limit_event_stage = str(vllm_kwargs.pop("limit_event_stage", "model_init")) + limit_event_model_spec = str( + vllm_kwargs.pop("limit_event_model_spec", f"VLLM/{model}") + ) + disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) + thinking_token_budget = vllm_kwargs.pop("thinking_token_budget", None) + explicit_chat_template_kwargs = vllm_kwargs.pop("chat_template_kwargs", None) + explicit_reasoning_settings = ( + "reasoning_parser" in vllm_kwargs or "reasoning_config" in vllm_kwargs + ) + self._chat_template_kwargs = _resolve_chat_template_kwargs( + explicit_chat_template_kwargs=explicit_chat_template_kwargs, + disable_thinking=disable_thinking, + ) # Cap max_model_len to the model's max_position_embeddings so that # vLLM doesn't reject an overly large context window. @@ -258,6 +689,15 @@ def __init__( config = AutoConfig.from_pretrained(model, trust_remote_code=True) model_max_pos = getattr(config, "max_position_embeddings", None) if model_max_pos is not None and max_model_len > model_max_pos: + if limit_event_tracker is not None: + limit_event_tracker.record( + "max_model_len_clamped", + stage=limit_event_stage, + field="max_model_len", + model_spec=limit_event_model_spec, + original_length=int(max_model_len), + final_length=int(model_max_pos), + ) warnings.warn( f"Capping max_model_len from {max_model_len} to " f"{model_max_pos} (max_position_embeddings) for '{model}'.", @@ -272,13 +712,54 @@ def __init__( RuntimeWarning, stacklevel=2, ) + self._sampling_params_kwargs = { + "max_tokens": max_tokens, + "temperature": float(vllm_kwargs.pop("temperature", 0.6)), + "top_p": float(vllm_kwargs.pop("top_p", 0.95)), + } + self._thinking_budget_marker: str | None = None + self._thinking_budget_value: int | None = None + if thinking_token_budget is not None: + if max_tokens is not None: + thinking_token_budget = min(int(thinking_token_budget), int(max_tokens)) + if explicit_reasoning_settings: + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) + self._thinking_budget_marker = VLLM_REASONING_END_STR + self._thinking_budget_value = int(thinking_token_budget) + elif is_thinking_model(model): + vllm_kwargs.setdefault( + "reasoning_config", + ReasoningConfig( + reasoning_start_str=VLLM_REASONING_START_STR, + reasoning_end_str=VLLM_REASONING_END_STR, + ), + ) + # The `qwen3` reasoning_parser only runs inside vLLM's + # OpenAI-compatible server for `reasoning_content` extraction. + # For offline batch inference via LLM.chat() it is inert, so + # it is safe to reuse for any ``/`` model + # (Qwen3 + SmolLM3). + vllm_kwargs.setdefault("reasoning_parser", "qwen3") + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) + self._thinking_budget_marker = VLLM_REASONING_END_STR + self._thinking_budget_value = int(thinking_token_budget) + else: + warnings.warn( + f"Model '{model}' is not in JudgeArena's built-in thinking-model " + "defaults (Qwen3/SmolLM3). 
Ignoring thinking_token_budget unless " + "reasoning_parser or reasoning_config is provided explicitly.", + stacklevel=2, + ) + self.sampling_params = SamplingParams(**self._sampling_params_kwargs) - self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) - self.sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=0.6, - top_p=0.95, + self.llm = _init_llm_with_retry( + LLM, model=model, trust_remote_code=True, **vllm_kwargs ) + self.tokenizer = self.llm.get_tokenizer() # Resolve chat template: # 1. Explicit override always wins → use chat() with that template @@ -289,8 +770,7 @@ def __init__( self._use_generate = False logger.info("ChatVLLM: using explicit chat template for '%s'", model) else: - tokenizer = self.llm.get_tokenizer() - if not getattr(tokenizer, "chat_template", None): + if not getattr(self.tokenizer, "chat_template", None): warnings.warn( f"Model '{model}' tokenizer does not define a chat template. " f"Falling back to llm.generate() (no chat formatting). " @@ -299,11 +779,23 @@ def __init__( ) self.chat_template = None self._use_generate = True + if disable_thinking: + warnings.warn( + f"Model '{model}' has no chat template, so disable_thinking " + "cannot be applied when falling back to llm.generate().", + stacklevel=2, + ) else: self.chat_template = None # let vLLM use the tokenizer's own self._use_generate = False logger.info("ChatVLLM: using tokenizer's chat template for '%s'", model) + def set_temperature(self, temperature: float) -> None: + from vllm import SamplingParams + + self._sampling_params_kwargs["temperature"] = float(temperature) + self.sampling_params = SamplingParams(**self._sampling_params_kwargs) + def _to_messages(self, input_item) -> list[dict]: """Convert LangChain prompt input to OpenAI-style messages.""" # Map LangChain message types to OpenAI roles @@ -355,7 +847,7 @@ def _to_raw_text(self, input_item) -> str: return "\n".join(msg["content"] for msg in input_item) raise ValueError(f"Cannot extract raw text from: {type(input_item)}") - def batch(self, inputs: list, **invoke_kwargs) -> list[str]: + def _run_raw_batch(self, inputs: list): """Process a batch of inputs using vllm.LLM.chat() or llm.generate(). Uses ``llm.chat()`` when a chat template is available (instruct models), @@ -371,8 +863,72 @@ def batch(self, inputs: list, **invoke_kwargs) -> list[str]: self.sampling_params, add_generation_prompt=True, chat_template=self.chat_template, + chat_template_kwargs=self._chat_template_kwargs, ) - return [out.outputs[0].text for out in outputs] + return outputs + + def batch_with_metadata( + self, inputs: list, **invoke_kwargs + ) -> tuple[list[str], list[dict[str, Any]]]: + outputs = self._run_raw_batch(inputs) + texts: list[str] = [] + metadata: list[dict[str, Any]] = [] + marker = self._thinking_budget_marker + for out in outputs: + first_output = out.outputs[0] + text = first_output.text + texts.append(text) + row: dict[str, Any] = { + "finish_reason": getattr(first_output, "finish_reason", None), + "stop_reason": getattr(first_output, "stop_reason", None), + } + if marker is not None: + # vLLM emits the forced reasoning-end marker verbatim when the + # per-request thinking-token budget is exhausted; the marker is + # absent otherwise. Detecting it here gives + # `_record_generation_output_limit_events` a deterministic + # signal to log a `generation_thinking_token_budget` event. 
+ row["thinking_budget_exhausted"] = marker in text + row["thinking_token_budget"] = self._thinking_budget_value + metadata.append(row) + return texts, metadata + + def batch(self, inputs: list, **invoke_kwargs) -> list[str]: + texts, _metadata = self.batch_with_metadata(inputs, **invoke_kwargs) + return texts + + def _count_chat_prompt_tokens(self, messages: list[dict]) -> int: + tokenizer_kwargs: dict[str, object] = { + "tokenize": True, + "add_generation_prompt": True, + } + if self.chat_template is not None: + tokenizer_kwargs["chat_template"] = self.chat_template + if self._chat_template_kwargs is not None: + tokenizer_kwargs["chat_template_kwargs"] = self._chat_template_kwargs + try: + token_ids = self.tokenizer.apply_chat_template(messages, **tokenizer_kwargs) + except TypeError: + tokenizer_kwargs.pop("chat_template_kwargs", None) + token_ids = self.tokenizer.apply_chat_template(messages, **tokenizer_kwargs) + return len(token_ids) + + def count_prompt_tokens_batch(self, inputs: list) -> list[int]: + counts: list[int] = [] + for input_item in inputs: + if self._use_generate: + counts.append(len(self.tokenizer.encode(self._to_raw_text(input_item)))) + else: + counts.append( + self._count_chat_prompt_tokens(self._to_messages(input_item)) + ) + return counts + + def count_completion_tokens_batch(self, outputs: list[str]) -> list[int]: + return [ + len(self.tokenizer.encode(output, add_special_tokens=False)) + for output in outputs + ] def invoke(self, input_item, **invoke_kwargs) -> str: """Process a single input.""" @@ -402,11 +958,14 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): # NOTE: this is a shallow copy since we are not modifying any # mutable objects in the dictionary. engine_kwargs = engine_kwargs.copy() + limit_event_tracker = engine_kwargs.pop("limit_event_tracker", None) + limit_event_stage = engine_kwargs.pop("limit_event_stage", None) + limit_event_model_spec = engine_kwargs.pop("limit_event_model_spec", None) # Dedicated arguments like max_tokens always win over engine_kwargs. 
engine_kwargs["max_tokens"] = max_tokens or 8192 - model_provider = model.split("/")[0] + model_provider, model_name = _split_model_spec(model) # vLLM-engine-only kwargs must not leak to remote-API providers # (OpenRouter, OpenAI, Together): langchain-openai forwards unknown @@ -418,13 +977,18 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): if model_provider == "Dummy": return DummyModel(model) - model_name = "/".join(model.split("/")[1:]) logger.info("Loading %s(model=%s)", model_provider, model_name) # Use our custom ChatVLLM wrapper which properly applies chat templates if model_provider == "VLLM": engine_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} engine_kwargs["chat_template"] = engine_kwargs.get("chat_template", None) + if limit_event_tracker is not None: + engine_kwargs["limit_event_tracker"] = limit_event_tracker + if limit_event_stage is not None: + engine_kwargs["limit_event_stage"] = limit_event_stage + if limit_event_model_spec is not None: + engine_kwargs["limit_event_model_spec"] = limit_event_model_spec return ChatVLLM( model=model_name, @@ -432,8 +996,16 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): ) if model_provider == "OpenRouter": - # Special case we need to override API url and key - return ChatOpenAI( + # Gemini's core policy filter rejects a small fraction of prompts with + # a hard PROHIBITED_CONTENT error that safety_settings cannot override; + # the subclass converts those into stub refusals so batch generation + # (e.g. benchmark baselines) completes instead of crashing. + chat_model_cls = ( + OpenRouterGeminiSafetyTolerantChatOpenAI + if is_openrouter_gemini_model(model) + else ChatOpenAI + ) + return chat_model_cls( api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1", model=model_name, @@ -468,15 +1040,33 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): return model_cls_dict[model_provider](**engine_kwargs) +def infer_model_spec_from_instance(model: object) -> str | None: + if isinstance(model, DummyModel): + return model.name + if isinstance(model, ChatVLLM): + return f"VLLM/{model.model_path}" + if isinstance(model, LlamaCpp): + model_path = getattr(model, "model_path", None) + if isinstance(model_path, str): + return f"LlamaCpp/{model_path}" + model_name = getattr(model, "model_name", None) or getattr(model, "model", None) + if isinstance(model_name, str): + return f"{model.__class__.__name__}/{model_name}" + return None + + def download_all(): + from judgearena.instruction_dataset.m_arenahard import M_ARENA_HARD_BASELINES + logger.info("Downloading all datasets in %s", data_root) local_path_tables = data_root / "tables" - for dataset in [ + datasets = [ "alpaca-eval", "arena-hard-v0.1", "arena-hard-v2.0", - "m-arena-hard", - ]: + *M_ARENA_HARD_BASELINES.keys(), + ] + for dataset in datasets: if is_arena_hard_dataset(dataset): download_arena_hard(dataset=dataset, local_tables_path=local_path_tables) else: diff --git a/pyproject.toml b/pyproject.toml index 2d421a9..18012cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,23 @@ dev = [ "ruff>=0.11.0", "slurmpilot @ git+https://github.com/geoalgo/slurmpilot.git@main", ] +# `llmcompressor` pins older `compressed-tensors` / `transformers` that clash +# with the `vllm` extra (vLLM 0.19.1 pulls `compressed-tensors==0.15.0.1` and +# `transformers>=5.5.1` for Gemma-4). 
Keep it available under its own group so +# quantization workflows can still install it, but mark it mutually exclusive +# with the `vllm` extra via `[tool.uv] conflicts` so the universal lock can +# still resolve both sides. +quantize = [ + "llmcompressor>=0.4.0", +] + +[tool.uv] +conflicts = [ + [ + { extra = "vllm" }, + { group = "quantize" }, + ], +] [tool.ruff] target-version = "py312" @@ -80,5 +97,14 @@ quote-style = "double" indent-style = "space" [project.optional-dependencies] -vllm = ["vllm==0.10.2", "transformers>=4.55.2,<5.0.0"] +# vLLM 0.19.1 is the pinned judge/battle runtime on this branch. Its own +# Requires-Dist is `transformers!=5.0.*,...,!=5.5.0,>=4.56.0`, so the +# only sub-range >=5.x that vLLM 0.19.1 accepts is >=5.5.1 — hence the +# lower bound here. Qwen3.5 loading works under either 4.56.x or 5.5.x +# because vLLM ships its own `qwen3_5` config shim and model impl and +# registers them into `AutoConfig` at runtime. +vllm = [ + "vllm>=0.19.1,<1.0.0", + "transformers>=5.5.1,<6.0.0", +] llamacpp = ["llama-cpp-python>=0.3.0"] diff --git a/slurmpilot_scripts/launch_generation_and_evaluation.py b/slurmpilot_scripts/launch_generation_and_evaluation.py index 6668a33..5782977 100644 --- a/slurmpilot_scripts/launch_generation_and_evaluation.py +++ b/slurmpilot_scripts/launch_generation_and_evaluation.py @@ -73,7 +73,7 @@ "dataset": f"{language}-contexts", "model_A": baseline, "model_B": model, - "judge_model": "VLLM/Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8", + "judge_model": "VLLM/Qwen/Qwen3.5-27B-FP8", "n_instructions": 100, # "ignore_cache": None, } diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py new file mode 100644 index 0000000..0cf34b1 --- /dev/null +++ b/tests/test_chat_vllm.py @@ -0,0 +1,329 @@ +import sys +from types import SimpleNamespace + +import pytest + +import judgearena.utils as utils + + +def _install_fake_vllm(monkeypatch): + captured = {} + + class FakeSamplingParams: + def __init__(self, **kwargs): + captured["sampling_kwargs"] = kwargs + + class FakeReasoningConfig: + def __init__(self, **kwargs): + captured["reasoning_config_kwargs"] = kwargs + + class FakeLLM: + def __init__(self, *, model, trust_remote_code, **kwargs): + captured["llm_init"] = { + "model": model, + "trust_remote_code": trust_remote_code, + "kwargs": kwargs, + } + + def get_tokenizer(self): + return SimpleNamespace(chat_template="{{ messages }}") + + def chat(self, messages, sampling_params, **kwargs): + captured["chat_call"] = { + "messages": messages, + "sampling_params": sampling_params, + "kwargs": kwargs, + } + return [SimpleNamespace(outputs=[SimpleNamespace(text="ok")])] + + monkeypatch.setitem( + sys.modules, + "vllm", + SimpleNamespace(LLM=FakeLLM, SamplingParams=FakeSamplingParams), + ) + monkeypatch.setitem( + sys.modules, + "vllm.config.reasoning", + SimpleNamespace(ReasoningConfig=FakeReasoningConfig), + ) + return captured, FakeReasoningConfig + + +def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatch): + captured, fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=128, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 + assert "structured_outputs" not in captured["sampling_kwargs"] + assert captured["reasoning_config_kwargs"] == { + "reasoning_start_str": utils.VLLM_REASONING_START_STR, + "reasoning_end_str": utils.VLLM_REASONING_END_STR, + } + llm_kwargs = captured["llm_init"]["kwargs"] 
+ assert llm_kwargs["reasoning_parser"] == "qwen3" + assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) + + +def test_chat_vllm_enables_reasoning_support_for_smollm3_thinking_budget(monkeypatch): + captured, fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="HuggingFaceTB/SmolLM3-3B", + max_tokens=128, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 + assert captured["reasoning_config_kwargs"] == { + "reasoning_start_str": utils.VLLM_REASONING_START_STR, + "reasoning_end_str": utils.VLLM_REASONING_END_STR, + } + llm_kwargs = captured["llm_init"]["kwargs"] + assert llm_kwargs["reasoning_parser"] == "qwen3" + assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) + + +def test_chat_vllm_clamps_thinking_budget_to_total_max_tokens(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=32, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 32 + + +def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=16, + disable_thinking=True, + gpu_memory_utilization=0.7, + ) + + outputs = chat_model.batch(["hello"]) + + assert outputs == ["ok"] + assert captured["chat_call"]["kwargs"]["chat_template_kwargs"] == { + "enable_thinking": False + } + + +def test_build_default_judge_model_kwargs_only_defaults_qwen_judges(): + assert utils.build_default_judge_model_kwargs( + "VLLM/Qwen/Qwen3.5-9B", + {"gpu_memory_utilization": 0.7}, + ) == { + "gpu_memory_utilization": 0.7, + "thinking_token_budget": 512, + } + assert utils.build_default_judge_model_kwargs( + "VLLM/meta-llama/Llama-3.3-70B-Instruct", + {"gpu_memory_utilization": 0.7}, + ) == {"gpu_memory_utilization": 0.7} + assert ( + utils.build_default_judge_model_kwargs( + "OpenRouter/qwen/qwen3-32b", + {}, + ) + == {} + ) + + +def test_build_default_judge_model_kwargs_sets_fp8_kv_cache_for_fp8_judges(): + fp8_defaults = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9}, + ) + assert fp8_defaults["kv_cache_dtype"] == "fp8" + # FP8 Skywork judge is not Qwen3/SmolLM3 so no thinking-token default. + assert "thinking_token_budget" not in fp8_defaults + + bf16_defaults = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-8B", + {"gpu_memory_utilization": 0.9}, + ) + assert "kv_cache_dtype" not in bf16_defaults + + explicit_override = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9, "kv_cache_dtype": "bfloat16"}, + ) + assert explicit_override["kv_cache_dtype"] == "bfloat16" + + # Non-VLLM providers never receive the FP8 KV default even if the name + # happens to contain "fp8". 
+ non_vllm = utils.build_default_judge_model_kwargs("OpenRouter/some/Model-fp8", {}) + assert "kv_cache_dtype" not in non_vllm + + +def test_build_default_judge_model_kwargs_overlays_judge_override(): + """Judge-scoped overrides must win over shared ``engine_kwargs`` so the + battle engine can stay on TP=1 while the 70B FP8 judge pins TP=2.""" + merged = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9}, + judge_engine_kwargs_override={"tensor_parallel_size": 2}, + ) + assert merged["tensor_parallel_size"] == 2 + assert merged["gpu_memory_utilization"] == 0.9 + assert merged["kv_cache_dtype"] == "fp8" + + overridden = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"tensor_parallel_size": 1, "gpu_memory_utilization": 0.9}, + judge_engine_kwargs_override={"tensor_parallel_size": 4}, + ) + assert overridden["tensor_parallel_size"] == 4 + # FP8 weights + FP8 KV cache are a name-driven invariant; the TP override + # must not silently drop `kv_cache_dtype=fp8` because we run the Skywork + # 70B FP8 judge on TP=2 and TP=4 interchangeably depending on the cell. + assert overridden["kv_cache_dtype"] == "fp8" + + empty_override = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"tensor_parallel_size": 1}, + judge_engine_kwargs_override={}, + ) + assert empty_override["tensor_parallel_size"] == 1 + + +def test_is_thinking_model_matches_qwen3_and_smollm3_repo_ids(): + assert utils.is_thinking_model("Qwen/Qwen3.5-9B") + assert utils.is_thinking_model("HuggingFaceTB/SmolLM3-3B") + assert utils.is_thinking_model("Qwen/Qwen3-7B") + assert not utils.is_thinking_model("Qwen/Qwen2.5-7B") + assert not utils.is_thinking_model("utter-project/EuroLLM-9B-Instruct") + assert not utils.is_thinking_model("meta-llama/Llama-3.1-8B") + + +def test_chat_vllm_preserves_explicit_reasoning_settings_for_non_qwen(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + explicit_reasoning_config = object() + + utils.ChatVLLM( + model="meta-llama/Llama-3.3-70B-Instruct", + max_tokens=16, + thinking_token_budget=32, + reasoning_parser="custom-parser", + reasoning_config=explicit_reasoning_config, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 16 + assert captured["llm_init"]["kwargs"]["reasoning_parser"] == "custom-parser" + assert ( + captured["llm_init"]["kwargs"]["reasoning_config"] is explicit_reasoning_config + ) + + +def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + with pytest.warns(UserWarning, match="built-in thinking-model"): + utils.ChatVLLM( + model="meta-llama/Llama-3.3-70B-Instruct", + max_tokens=32, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert "thinking_token_budget" not in captured["sampling_kwargs"] + assert "reasoning_parser" not in captured["llm_init"]["kwargs"] + assert "reasoning_config" not in captured["llm_init"]["kwargs"] + + +def test_chat_vllm_records_thinking_budget_exhaustion_metadata(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + class FakeLLMWithMarker: + def __init__(self, *, model, trust_remote_code, **kwargs): + captured["llm_init"] = {"model": model, "kwargs": kwargs} + + def get_tokenizer(self): + return SimpleNamespace(chat_template="{{ messages }}") + + def chat(self, messages, 
sampling_params, **kwargs): + return [ + SimpleNamespace( + outputs=[ + SimpleNamespace( + text=f"pre {utils.VLLM_REASONING_END_STR} answer", + finish_reason="stop", + stop_reason=None, + ) + ] + ), + SimpleNamespace( + outputs=[ + SimpleNamespace( + text="clean answer", + finish_reason="stop", + stop_reason=None, + ) + ] + ), + ] + + monkeypatch.setitem( + sys.modules, + "vllm", + SimpleNamespace( + LLM=FakeLLMWithMarker, + SamplingParams=sys.modules["vllm"].SamplingParams, + ), + ) + + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=64, + thinking_token_budget=32, + gpu_memory_utilization=0.7, + ) + _texts, metadata = chat_model.batch_with_metadata(["a", "b"]) + + assert metadata[0]["thinking_budget_exhausted"] is True + assert metadata[0]["thinking_token_budget"] == 32 + assert metadata[1]["thinking_budget_exhausted"] is False + assert metadata[1]["thinking_token_budget"] == 32 + + +def test_chat_vllm_omits_thinking_budget_metadata_without_budget(monkeypatch): + _captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=64, + gpu_memory_utilization=0.7, + ) + assert chat_model._thinking_budget_marker is None + assert chat_model._thinking_budget_value is None + + +def test_infer_model_spec_uses_type_based_vllm_fallback(): + model = object.__new__(utils.ChatVLLM) + model.model_path = "Qwen/Qwen3.5-9B" + + assert utils.infer_model_spec_from_instance(model) == "VLLM/Qwen/Qwen3.5-9B" + + +def test_infer_model_spec_uses_type_based_llamacpp_fallback(monkeypatch): + class FakeLlamaCpp: + def __init__(self, model_path: str): + self.model_path = model_path + + monkeypatch.setattr(utils, "LlamaCpp", FakeLlamaCpp) + model = FakeLlamaCpp("./models/model.gguf") + + assert utils.infer_model_spec_from_instance(model) == "LlamaCpp/./models/model.gguf" diff --git a/tests/test_cli.py b/tests/test_cli.py index 30be4fa..e19f0d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -248,12 +248,12 @@ def test_unknown_elo_task_raises(capture_mains): ) -def test_generate_and_evaluate_requires_model_a_and_b(capture_mains): +def test_pairwise_task_without_native_baseline_requires_model_a_and_b(capture_mains): with pytest.raises(SystemExit, match="--model_A and --model_B are required"): cli_module.cli( [ "--task", - "alpaca-eval", + "fluency-french", "--model_A", "Dummy/A", "--judge", @@ -262,6 +262,33 @@ def test_generate_and_evaluate_requires_model_a_and_b(capture_mains): ) +@pytest.mark.parametrize( + "task", + [ + "alpaca-eval", + "m-arena-hard-v2.0-EU", + ], +) +def test_pairwise_task_allows_missing_model_b_when_native_baseline_exists( + capture_mains, task: str +): + cli_module.cli( + [ + "--task", + task, + "--model_A", + "Dummy/A", + "--judge", + "Dummy/J", + ] + ) + assert capture_mains["module"] == "generate_and_evaluate" + ge_args: CliArgs = capture_mains["args"] + assert ge_args.task == task + assert ge_args.model_A == "Dummy/A" + assert ge_args.model_B is None + + def test_deprecated_model_flag_routes_into_pairwise_task(capture_mains): """`--model` is a deprecated alias for `--model_A` even on pairwise tasks.""" with pytest.warns(DeprecationWarning, match="--model is deprecated"): @@ -329,3 +356,39 @@ def test_engine_kwargs_parsed_as_json(capture_mains): ) ge_args: CliArgs = capture_mains["args"] assert ge_args.engine_kwargs == {"tensor_parallel_size": 4} + + +def test_mt_bench_defaults_to_default_judge_mode(capture_mains): + cli_module.cli( + [ + "--task", + "mt-bench", + "--model_A", + 
"Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.mt_bench_judge_mode == "default" + + +def test_mt_bench_forwards_fastchat_original_mode(capture_mains): + cli_module.cli( + [ + "--task", + "mt-bench", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--mt_bench_judge_mode", + "fastchat_original", + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.mt_bench_judge_mode == "fastchat_original" diff --git a/tests/test_generate_and_evaluate.py b/tests/test_generate_and_evaluate.py index 03a278f..ed81cb4 100644 --- a/tests/test_generate_and_evaluate.py +++ b/tests/test_generate_and_evaluate.py @@ -55,7 +55,8 @@ def _run_without_cache(fun, **_kwargs): "arena-hard-v2.0", "arena-hard-v0.1", "fluency-french", - "m-arena-hard-EU", + "m-arena-hard-v0.1-EU", + "m-arena-hard-v2.0-EU", ], ) def test_generate_and_evaluate_context_completion(task: str, tmp_path): @@ -96,3 +97,25 @@ def test_generate_and_evaluate_correct_order_bias(tmp_path): avg_pref = sum(prefs) / len(prefs) assert avg_pref == 0.5 + + +def test_cli_args_parse_optional_boolean_flags(monkeypatch): + monkeypatch.setattr( + "sys.argv", + [ + "generate_and_evaluate.py", + "--dataset=alpaca-eval", + "--model_A=Dummy/A", + "--model_B=Dummy/B", + "--judge_model=Dummy/Judge", + "--use_tqdm=True", + "--ignore_cache=True", + "--strip_thinking_before_judging=False", + ], + ) + + args = CliArgs.parse_args() + + assert args.use_tqdm is True + assert args.ignore_cache is True + assert args.strip_thinking_before_judging is False diff --git a/tests/test_generate_and_evaluate_arena_hard.py b/tests/test_generate_and_evaluate_arena_hard.py new file mode 100644 index 0000000..aa184a9 --- /dev/null +++ b/tests/test_generate_and_evaluate_arena_hard.py @@ -0,0 +1,129 @@ +import pandas as pd +import pytest + +from judgearena.generate_and_evaluate import ( + ALPACA_EVAL_BASELINES, + PAIRWISE_BASELINES, + BaselinePlan, + CliArgs, + _resolve_baseline_plan, + native_pairwise_baseline, +) + + +def _make_args(dataset, model_b=None): + return CliArgs( + task=dataset, + model_A="A", + model_B=model_b, + judge_model="J", + ) + + +def _instructions(ids, categories=None): + data = {"instruction": list(ids)} + if categories is not None: + data["category"] = list(categories) + return pd.DataFrame(data, index=pd.Index(ids, name="instruction_index")) + + +def test_resolve_plan_v01_flat_default(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v0.1"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "gpt-4-0314" + + +def test_resolve_plan_v20_routes_per_category(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions( + ["qh", "qc"], + categories=["hard_prompt", "creative_writing"], + ), + ) + assert not plan.is_flat + assert plan.baseline_by_index.loc["qh"] == "o3-mini-2025-01-31" + assert plan.baseline_by_index.loc["qc"] == "gemini-2.0-flash-001" + + +def test_resolve_plan_alpaca_eval_uses_native_baseline(): + plan = _resolve_baseline_plan( + args=_make_args("alpaca-eval"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "gpt4_1106_preview" + + +def test_native_pairwise_baseline_mapping_covers_flat_tasks(): + assert ALPACA_EVAL_BASELINES == {"alpaca-eval": "gpt4_1106_preview"} + assert PAIRWISE_BASELINES["alpaca-eval"] == "gpt4_1106_preview" + assert 
native_pairwise_baseline("alpaca-eval") == "gpt4_1106_preview" + assert native_pairwise_baseline("mt-bench") == "gpt-4" + + +def test_resolve_plan_m_arena_hard_uses_native_baseline(): + plan = _resolve_baseline_plan( + args=_make_args("m-arena-hard-v2.0-EU"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "google/gemini-2.5-flash" + + +def test_native_pairwise_baseline_resolves_m_arena_hard_variants(): + assert ( + native_pairwise_baseline("m-arena-hard-v0.1-uk") == "CohereLabs/aya-expanse-8b" + ) + assert native_pairwise_baseline("m-arena-hard-v2.0-EU") == "google/gemini-2.5-flash" + + +def test_resolve_plan_explicit_model_b_overrides_native(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0", model_b="override"), + instructions_df=_instructions( + ["q1", "q2"], categories=["hard_prompt", "creative_writing"] + ), + ) + assert plan.is_flat + assert plan.single_model == "override" + + +def test_resolve_plan_task_without_native_baseline_requires_model_b(): + with pytest.raises(ValueError, match="model_B"): + _resolve_baseline_plan( + args=_make_args("fluency-french"), + instructions_df=_instructions(["q1"]), + ) + + +def test_resolve_plan_v20_missing_category_raises(): + with pytest.raises(ValueError, match="category"): + _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions(["q1"]), + ) + + +def test_resolve_plan_v20_unknown_category_raises(): + with pytest.raises(ValueError, match="brand_new"): + _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions(["q1"], categories=["brand_new"]), + ) + + +def test_baseline_plan_flat_repeats_model(): + plan = BaselinePlan.flat("b", index=pd.Index(["a", "b"])) + assert plan.is_flat + assert plan.baseline_by_index.tolist() == ["b", "b"] + + +def test_baseline_plan_per_row_preserves_order(): + series = pd.Series(["m1", "m2"], index=["a", "b"], name="model_B") + plan = BaselinePlan.per_row(series) + assert not plan.is_flat + assert plan.unique_models == ["m1", "m2"] diff --git a/tests/test_instruction_dataset.py b/tests/test_instruction_dataset.py index 4daa144..849fd94 100644 --- a/tests/test_instruction_dataset.py +++ b/tests/test_instruction_dataset.py @@ -1,21 +1,65 @@ from pathlib import Path import pandas as pd +import pytest import judgearena.generate_and_evaluate as generate_and_evaluate import judgearena.instruction_dataset as instruction_dataset +import judgearena.utils as judgearena_utils from judgearena.instruction_dataset.arena_hard import ( - arena_hard_baseline_model, + ARENA_HARD_BASELINES, + _build_instructions, + _build_model_outputs, + _extract_assistant_output, + arena_hard_native_baseline, normalize_official_arena_hard, ) -def test_arena_hard_baseline_resolution(): - assert arena_hard_baseline_model("arena-hard-v0.1") == "gpt-4-0314" - assert arena_hard_baseline_model("arena-hard-v2.0") == "o3-mini-2025-01-31" +def test_arena_hard_native_baseline_v01_is_flat_string(): + assert arena_hard_native_baseline("arena-hard-v0.1") == "gpt-4-0314" -def test_normalize_official_arena_hard_v01_shape(): +def test_arena_hard_native_baseline_v20_is_per_category_mapping(): + native = arena_hard_native_baseline("arena-hard-v2.0") + assert isinstance(native, dict) + assert native["hard_prompt"] == "o3-mini-2025-01-31" + assert native["coding"] == "o3-mini-2025-01-31" + assert native["math"] == "o3-mini-2025-01-31" + assert native["creative_writing"] == "gemini-2.0-flash-001" + + +def 
test_arena_hard_baselines_mapping_matches_upstream(): + """Pin the exact baseline assignment so a silent edit to + ARENA_HARD_BASELINES can't drift away from upstream + (arena-hard-auto/utils/judge_utils.py::JUDGE_SETTINGS). + """ + assert ARENA_HARD_BASELINES == { + "arena-hard-v0.1": "gpt-4-0314", + "arena-hard-v2.0": { + "hard_prompt": "o3-mini-2025-01-31", + "coding": "o3-mini-2025-01-31", + "math": "o3-mini-2025-01-31", + "creative_writing": "gemini-2.0-flash-001", + }, + } + + +def test_mt_bench_native_baseline_is_flat_string(): + from judgearena.instruction_dataset.mt_bench import ( + MT_BENCH_BASELINES, + is_mt_bench_dataset, + mt_bench_native_baseline, + ) + + assert is_mt_bench_dataset("mt-bench") is True + assert is_mt_bench_dataset("alpaca-eval") is False + assert mt_bench_native_baseline("mt-bench") == "gpt-4" + assert mt_bench_native_baseline("alpaca-eval") is None + assert MT_BENCH_BASELINES == {"mt-bench": "gpt-4"} + + +def test_normalize_official_arena_hard_v01_drops_no_category(): raw_df = pd.DataFrame( { "question_id": ["q1", "q2"], @@ -35,6 +79,165 @@ def test_normalize_official_arena_hard_v01_shape(): assert set(df_outputs.columns) == {"instruction_index", "model", "output"} +def test_normalize_official_arena_hard_v20_preserves_category(): + raw_df = pd.DataFrame( + { + "question_id": ["q1", "q2", "q1"], + "prompt": ["First prompt", "Second prompt", None], + "category": ["hard_prompt", "creative_writing", None], + "model": [None, None, "o3-mini-2025-01-31"], + "output": [None, None, "answer text"], + } + ) + df_instructions, df_outputs = normalize_official_arena_hard( + raw_df=raw_df, dataset="arena-hard-v2.0" + ) + + assert "category" in df_instructions.columns + assert df_instructions.set_index("instruction_index")["category"].to_dict() == { + "q1": "hard_prompt", + "q2": "creative_writing", + } + assert df_outputs is not None + assert df_outputs["model"].tolist() == ["o3-mini-2025-01-31"] + assert df_outputs["output"].tolist() == ["answer text"] + + +def test_build_model_outputs_extracts_upstream_messages_shape(): + """Upstream's `model_answer/*.jsonl` rows keep the assistant response in + `messages[-1].content.answer` rather than a flat `output` column. Without + this extractor, a fresh `download_arena_hard` clone would silently drop + every baseline answer. 
+ """ + raw_df = pd.DataFrame( + [ + { + "uid": "q1", + "model": "o3-mini-2025-01-31", + "messages": [ + {"role": "user", "content": "Prompt"}, + { + "role": "assistant", + "content": {"answer": "nested answer", "reasoning": "..."}, + }, + ], + }, + { + "uid": "q2", + "model": "gemini-2.0-flash-001", + "messages": [ + {"role": "user", "content": "Prompt"}, + {"role": "assistant", "content": "plain string answer"}, + ], + }, + { + "uid": "q3", + "model": "baseline", + "output": "flat output column", + }, + { + "uid": "q4", + "model": "no-output-model", + "messages": [{"role": "assistant", "content": {"reasoning": "..."}}], + }, + ] + ) + + df_outputs = _build_model_outputs(raw_df) + + assert df_outputs is not None + outputs_by_model = dict(zip(df_outputs["model"], df_outputs["output"], strict=True)) + assert outputs_by_model == { + "o3-mini-2025-01-31": "nested answer", + "gemini-2.0-flash-001": "plain string answer", + "baseline": "flat output column", + } + assert "no-output-model" not in outputs_by_model + + +@pytest.mark.parametrize( + "row, expected", + [ + ({"output": "flat"}, "flat"), + ( + { + "messages": [ + {"role": "user", "content": "p"}, + {"role": "assistant", "content": {"answer": "nested"}}, + ] + }, + "nested", + ), + ( + { + "messages": [ + {"role": "user", "content": "p"}, + {"role": "assistant", "content": "plain"}, + ] + }, + "plain", + ), + ({"output": None, "messages": None}, None), + ( + {"messages": [{"role": "assistant", "content": {"reasoning": "only"}}]}, + None, + ), + ], +) +def test_extract_assistant_output_covers_known_shapes(row, expected): + assert _extract_assistant_output(pd.Series(row)) == expected + + +def test_build_model_outputs_returns_multi_model_rows_per_upstream_zip(): + """The fresh-clone loader must produce one row per (model, uid) so the + flat zip consumed by `try_load_dataset_completions` pivots cleanly. + """ + raw_df = pd.DataFrame( + [ + { + "uid": "q1", + "model": "o3-mini-2025-01-31", + "messages": [{"role": "assistant", "content": {"answer": "o3 q1"}}], + }, + { + "uid": "q2", + "model": "o3-mini-2025-01-31", + "messages": [{"role": "assistant", "content": {"answer": "o3 q2"}}], + }, + { + "uid": "q1", + "model": "gemini-2.0-flash-001", + "messages": [{"role": "assistant", "content": {"answer": "gemini q1"}}], + }, + ] + ) + + df_outputs = _build_model_outputs(raw_df) + + assert df_outputs is not None + assert sorted(df_outputs["model"].unique().tolist()) == [ + "gemini-2.0-flash-001", + "o3-mini-2025-01-31", + ] + assert df_outputs.shape[0] == 3 + + +def test_build_instructions_drops_model_answer_rows(): + """Question rows and model-answer rows share a dataframe on fresh clone; + `_build_instructions` has to keep only the prompt rows so the instruction + table doesn't leak rows with no prompt text. 
+ """ + raw_df = pd.DataFrame( + [ + {"uid": "q1", "prompt": "real prompt", "category": "hard_prompt"}, + {"uid": "q1", "model": "baseline", "output": "answer"}, + ] + ) + df = _build_instructions(raw_df) + assert df["instruction_index"].tolist() == ["q1"] + assert df["instruction"].tolist() == ["real prompt"] + + def test_load_instructions_uses_explicit_version_filename(monkeypatch): captured = {} @@ -52,7 +255,7 @@ def _fake_read_df(path: Path): ) monkeypatch.setattr(instruction_dataset, "download_arena_hard", _fake_ensure) - monkeypatch.setattr(instruction_dataset, "read_df", _fake_read_df) + monkeypatch.setattr(judgearena_utils, "read_df", _fake_read_df) df = instruction_dataset.load_instructions(dataset="arena-hard-v2.0") assert captured["dataset"] == "arena-hard-v2.0" @@ -60,6 +263,35 @@ def _fake_read_df(path: Path): assert df.index.tolist() == ["0", "1"] +def test_load_instructions_surfaces_category_for_v20(monkeypatch): + """The per-category baseline plan in `generate_and_evaluate` keys off + the `category` column, so `load_instructions` must keep it round-tripping + from the cached CSV. + """ + monkeypatch.setattr( + instruction_dataset, + "download_arena_hard", + lambda dataset, local_tables_path: None, + ) + monkeypatch.setattr( + judgearena_utils, + "read_df", + lambda path: pd.DataFrame( + { + "instruction_index": ["q1", "q2"], + "instruction": ["a", "b"], + "category": ["hard_prompt", "creative_writing"], + } + ), + ) + + df = instruction_dataset.load_instructions(dataset="arena-hard-v2.0") + + assert "category" in df.columns + assert df.loc["q1", "category"] == "hard_prompt" + assert df.loc["q2", "category"] == "creative_writing" + + def test_try_load_dataset_completions_uses_dataset_output_file(monkeypatch, tmp_path): tables_dir = tmp_path / "tables" / "model_outputs" tables_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py new file mode 100644 index 0000000..eef3916 --- /dev/null +++ b/tests/test_local_completion_loading.py @@ -0,0 +1,474 @@ +import pandas as pd + +import judgearena.evaluate as evaluate +import judgearena.generate_and_evaluate as generate_and_evaluate +from judgearena.cli_common import parse_optional_bool +from judgearena.generate_and_evaluate import CliArgs +from judgearena.generate_and_evaluate import main as main_generate_and_eval +from judgearena.judge_prompt_presets import SKYWORK_JUDGE_PROMPT_PRESET + + +def test_load_judge_prompt_without_explanation_uses_freeform_scores(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=False + ) + + assert "valid JSON" not in user_prompt + assert "score_A:" in user_prompt + assert "score_B:" in user_prompt + assert "Assistant A's Answer" in user_prompt + + +def test_load_judge_prompt_with_explanation_uses_freeform_scores(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=True + ) + + assert "valid JSON" not in user_prompt + assert "first starts with an explanation of your judgement" in user_prompt + assert "score_A:" in user_prompt + assert "score_B:" in user_prompt + assert "Assistant B's Answer" in user_prompt + + +def test_load_judge_prompt_multi_turn_uses_conversation_label(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=False, + multi_turn=True, + ) + + assert "Assistant A's Conversation with User" in user_prompt + assert "Assistant B's Conversation with User" in user_prompt + 
assert "Assistant A's Answer" not in user_prompt + + +def test_parse_optional_bool_accepts_explicit_true_false_values(): + assert parse_optional_bool(None) is True + assert parse_optional_bool("true") is True + assert parse_optional_bool("False") is False + + +def test_main_passes_qwen_defaults_and_aligns_dataset_completions( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction B", "Instruction A"]}, + index=pd.Index(["b", "a"], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + + def fake_try_load_dataset_completions(dataset, model, n_instructions): + if model == "Dummy/model-a": + return pd.DataFrame( + { + "instruction_index": ["a", "b"], + "completion": ["Answer A", "no answer"], + } + ) + return pd.DataFrame( + { + "instruction_index": ["a", "b"], + "completion": ["Answer B", "Answer C"], + } + ) + + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + fake_try_load_dataset_completions, + ) + + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = { + "model": model, + "max_tokens": max_tokens, + "max_model_len": max_model_len, + "chat_template": chat_template, + "kwargs": kwargs, + } + return object() + + def fake_judge_and_parse_prefs(**kwargs): + captured["judge_kwargs"] = kwargs + annotations = [{"judge_completion": "score_A: 0\nscore_B: 10"}] * len( + kwargs["instructions"] + ) + return annotations, None, pd.Series([1.0] * len(kwargs["instructions"])) + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + fake_judge_and_parse_prefs, + ) + + prefs = main_generate_and_eval( + CliArgs( + task="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=2, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [1.0, 1.0] + assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 + assert captured["make_model"]["kwargs"]["limit_event_stage"] == "judge_model_init" + assert captured["make_model"]["kwargs"]["limit_event_model_spec"] == ( + "VLLM/Qwen/Qwen3.5-27B-FP8" + ) + assert captured["judge_kwargs"]["instructions"] == [ + "Instruction B", + "Instruction A", + ] + assert captured["judge_kwargs"]["completions_A"] == ["no answer", "Answer A"] + assert captured["judge_kwargs"]["completions_B"] == ["Answer C", "Answer B"] + assert captured["judge_kwargs"]["case_ids"] == ["b", "a"] + assert captured["judge_kwargs"]["prompt_preset"] == "default" + assert captured["judge_kwargs"]["parser_mode"] == "score" + assert captured["judge_kwargs"]["strip_thinking_before_judging"] is False + + +def test_main_does_not_pass_thinking_budget_to_non_reasoning_vllm_judge( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, + ) + + def fake_make_model(*, model, max_tokens, 
max_model_len, chat_template, **kwargs): + captured["make_model"] = kwargs + return object() + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + lambda **kwargs: ( + [{"judge_completion": "score_A: 1\nscore_B: 2"}], + None, + pd.Series([1.0]), + ), + ) + + prefs = main_generate_and_eval( + CliArgs( + task="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", + n_instructions=1, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [1.0] + assert "thinking_token_budget" not in captured["make_model"] + assert captured["make_model"]["limit_event_stage"] == "judge_model_init" + + +def test_main_preserves_explicit_reasoning_engine_kwargs_for_non_qwen_vllm_judge( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, + ) + + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = kwargs + return object() + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + lambda **kwargs: ( + [{"judge_completion": "score_A: 1\nscore_B: 2"}], + None, + pd.Series([1.0]), + ), + ) + + prefs = main_generate_and_eval( + CliArgs( + task="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", + n_instructions=1, + result_folder=str(tmp_path / "results"), + engine_kwargs={ + "reasoning_parser": "custom-parser", + "thinking_token_budget": 2048, + }, + ) + ) + + assert prefs.tolist() == [1.0] + assert captured["make_model"]["reasoning_parser"] == "custom-parser" + assert captured["make_model"]["thinking_token_budget"] == 2048 + + +def test_annotate_battles_warns_when_judge_inputs_are_truncated(monkeypatch, caplog): + captured = {} + + def fake_do_inference( + *, + chat_model, + inputs, + use_tqdm, + usage_tracker, + usage_phase, + usage_model_spec, + ): + captured["judge_prompt"] = inputs[0].to_messages()[1].content + return ["score_A: 0\nscore_B: 10"] + + monkeypatch.setattr(evaluate, "do_inference", fake_do_inference) + caplog.set_level("WARNING", logger=evaluate.__name__) + + annotations = evaluate.annotate_battles( + judge_chat_model=object(), + instructions=["Instruction"], + completions_A=["Answer A"], + completions_B=["Answer B"], + truncate_input_chars=3, + ) + + assert ( + "Warning: truncated 2 judge inputs to 3 characters before evaluation." 
+        in caplog.text
+    )
+    assert "Ans" in captured["judge_prompt"]
+    assert "Answer A" not in captured["judge_prompt"]
+    assert "Answer B" not in captured["judge_prompt"]
+    assert "valid JSON" not in captured["judge_prompt"]
+    assert "score_A:" in captured["judge_prompt"]
+    assert annotations[0].completion_A == "Answer A"
+    assert annotations[0].completion_B == "Answer B"
+
+
+def test_resolve_judge_prompts_supports_optional_skywork_preset():
+    resolved = evaluate.resolve_judge_prompts(
+        provide_explanation=False,
+        prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET,
+    )
+
+    assert resolved.preset_name == SKYWORK_JUDGE_PROMPT_PRESET
+    assert resolved.parser_mode == "verdict"
+    assert resolved.system_prompt is None
+    assert "[[A]]" in resolved.user_prompt_template
+    assert "score_A:" not in resolved.user_prompt_template
+    assert "[User Question]" in resolved.user_prompt_template
+    assert "Assistant A's Answer" in resolved.user_prompt_template
+
+
+def test_resolve_judge_prompts_skywork_explanation_prompt_has_fixed_answer_labels():
+    resolved = evaluate.resolve_judge_prompts(
+        provide_explanation=True,
+        prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET,
+    )
+
+    assert (
+        "Please briefly explain your reasoning first" in resolved.user_prompt_template
+    )
+    assert "Assistant B's Answer" in resolved.user_prompt_template
+
+
+def test_annotate_battles_records_limit_events_for_stripping_and_truncation(
+    monkeypatch,
+):
+    tracker = evaluate.LimitEventTracker()
+
+    monkeypatch.setattr(
+        evaluate,
+        "do_inference",
+        lambda **kwargs: ["[[A]]"],
+    )
+
+    evaluate.annotate_battles(
+        judge_chat_model=object(),
+        instructions=["Instruction"],
+        completions_A=["<think>hidden</think>Visible answer"],
+        completions_B=["Short"],
+        case_ids=["case-1"],
+        truncate_input_chars=5,
+        strip_thinking_before_judging=True,
+        limit_event_tracker=tracker,
+        prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET,
+    )
+
+    summary = tracker.build_summary()
+
+    assert summary["counts_by_kind"]["thinking_trace_stripped_before_judging"] == 1
+    assert summary["counts_by_kind"]["judge_input_char_truncation"] == 1
+
+
+def test_main_passes_qwen_only_battle_budget_and_prompt_preset(tmp_path, monkeypatch):
+    instructions = pd.DataFrame(
+        {"instruction": ["Instruction A"]},
+        index=pd.Index([1], name="instruction_index"),
+    )
+    captured = {}
+
+    monkeypatch.setattr(
+        generate_and_evaluate,
+        "load_instructions",
+        lambda dataset, n_instructions=None: instructions,
+    )
+    monkeypatch.setattr(
+        generate_and_evaluate,
+        "try_load_dataset_completions",
+        lambda dataset, model, n_instructions: None,
+    )
+    monkeypatch.setattr(
+        generate_and_evaluate,
+        "cache_function_dataframe",
+        lambda fun, **_kwargs: fun(),
+    )
+
+    def fake_generate_instructions(
+        *,
+        instructions,
+        model,
+        truncate_input_chars,
+        max_tokens,
+        max_model_len,
+        chat_template,
+        use_tqdm,
+        usage_tracker,
+        usage_phase,
+        limit_event_tracker,
+        **engine_kwargs,
+    ):
+        captured.setdefault("generation_calls", []).append(
+            {
+                "model": model,
+                "max_tokens": max_tokens,
+                "engine_kwargs": engine_kwargs,
+            }
+        )
+        return pd.DataFrame(
+            {
+                "instruction_index": [1],
+                "completion": [f"{model}-answer"],
+                "generation_prompt_truncated": [False],
+                "generation_output_finish_reason": [None],
+                "generation_output_hit_token_limit": [False],
+            }
+        )
+
+    monkeypatch.setattr(
+        generate_and_evaluate, "generate_instructions", fake_generate_instructions
+    )
+
+    def fake_judge_and_parse_prefs(**kwargs):
+        captured["judge_kwargs"] = kwargs
+        return [{"judge_completion": "[[A]]"}], None, pd.Series([0.0])
+
+    
monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + fake_judge_and_parse_prefs, + ) + monkeypatch.setattr( + generate_and_evaluate, + "make_model", + lambda **kwargs: object(), + ) + + prefs = main_generate_and_eval( + CliArgs( + task="alpaca-eval", + model_A="VLLM/Qwen/Qwen3.5-27B-FP8", + model_B="VLLM/allenai/Olmo-3-7B-Instruct", + judge_model="Dummy/judge", + n_instructions=1, + judge_prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + battle_thinking_token_budget=512, + strip_thinking_before_judging=True, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [0.0] + assert len(captured["generation_calls"]) == 2 + assert ( + captured["generation_calls"][0]["engine_kwargs"]["thinking_token_budget"] == 512 + ) + assert ( + "thinking_token_budget" not in captured["generation_calls"][1]["engine_kwargs"] + ) + assert captured["judge_kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET + assert captured["judge_kwargs"]["parser_mode"] == "verdict" + assert captured["judge_kwargs"]["strip_thinking_before_judging"] is True + + +def test_generation_cache_name_changes_with_generation_settings(): + args = CliArgs( + task="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="Dummy/judge", + n_instructions=1, + max_out_tokens_models=1024, + battle_thinking_token_budget=256, + ) + changed_args = CliArgs( + task="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="Dummy/judge", + n_instructions=1, + max_out_tokens_models=4096, + battle_thinking_token_budget=512, + ) + + cache_name = generate_and_evaluate._generation_cache_name( + args, model_spec="VLLM/Qwen/Qwen3.5-27B-FP8" + ) + changed_cache_name = generate_and_evaluate._generation_cache_name( + changed_args, model_spec="VLLM/Qwen/Qwen3.5-27B-FP8" + ) + + assert cache_name != changed_cache_name diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index a91f785..368cdab 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -1,5 +1,36 @@ +import importlib +from datetime import UTC, datetime +from types import SimpleNamespace + +import pandas as pd + import judgearena.instruction_dataset.mt_bench as mt_bench +import judgearena.mt_bench.fastchat_compat as fastchat_compat +import judgearena.mt_bench.mt_bench_utils as mt_bench_utils import judgearena.utils as utils +from judgearena.cli_common import BaseCliArgs +from judgearena.judge_prompt_presets import SKYWORK_JUDGE_PROMPT_PRESET + + +def _mt_bench_args( + *, + dataset: str, + model_A: str, + model_B: str, + use_tqdm: bool = False, + **base_overrides, +) -> BaseCliArgs: + """Construct a ``BaseCliArgs`` with MT-Bench CLI-style extras attached. + + Using the real dataclass here keeps tests close to the production CLI + contract while attaching the task/model fields owned by ``CliArgs``. 
+ """ + args = BaseCliArgs(**base_overrides) + args.task = dataset + args.model_A = model_A + args.model_B = model_B + args.use_tqdm = use_tqdm + return args def test_download_mt_bench_skips_question_download_if_cached(tmp_path, monkeypatch): @@ -66,7 +97,8 @@ def _contexts_snapshot_stub(**_kwargs): tables_dir = tmp_path / "tables" assert [name for name, _ in hf_datasets] == [ "alpaca-eval", - "m-arena-hard", + "m-arena-hard-v0.1", + "m-arena-hard-v2.0", ] assert arena_hard_datasets == [ ("arena-hard-v0.1", tables_dir), @@ -74,3 +106,704 @@ def _contexts_snapshot_stub(**_kwargs): ] assert calls["contexts"] == 1 assert calls["mt_bench"] == 1 + + +def test_load_mt_bench_model_answers_reads_cached_baseline_file(tmp_path): + answer_path = tmp_path / "data" / "mt_bench" / "model_answer" / "gpt-4.jsonl" + answer_path.parent.mkdir(parents=True, exist_ok=True) + answer_path.write_text( + '{"question_id": 2, "choices": [{"turns": ["A2", "B2"]}]}\n' + '{"question_id": 1, "choices": [{"turns": ["A1"]}]}\n' + ) + + df_answers = mt_bench.load_mt_bench_model_answers("gpt-4", local_dir=tmp_path) + + assert df_answers.to_dict(orient="records") == [ + { + "instruction_index": 1, + "completion_turn_1": "A1", + "completion_turn_2": "", + }, + { + "instruction_index": 2, + "completion_turn_1": "A2", + "completion_turn_2": "B2", + }, + ] + + +def test_generate_mt_bench_completions_uses_pregenerated_baseline(monkeypatch): + questions_df = pd.DataFrame( + {"turn_1": ["Q1", "Q2"], "turn_2": ["Q1b", "Q2b"]}, + index=pd.Index([1, 2], name="instruction_index"), + ) + generated_models = [] + generation_kwargs = [] + generation_strip_flags = [] + + monkeypatch.setattr( + mt_bench_utils, "cache_function_dataframe", lambda fun, **_kwargs: fun() + ) + + def fake_generate_multiturn( + *, + questions, + model, + truncate_input_chars, + max_tokens, + use_tqdm, + max_model_len, + chat_template, + temperature_config, + usage_tracker, + usage_phase, + limit_event_tracker, + strip_thinking_before_turn_2_prompt, + **engine_kwargs, + ): + generated_models.append(model) + generation_kwargs.append(engine_kwargs) + generation_strip_flags.append(strip_thinking_before_turn_2_prompt) + return pd.DataFrame( + { + "instruction_index": [1, 2], + "completion_turn_1": ["Gen A1", "Gen A2"], + "completion_turn_2": ["Gen B1", "Gen B2"], + } + ) + + monkeypatch.setattr(mt_bench_utils, "generate_multiturn", fake_generate_multiturn) + monkeypatch.setattr( + mt_bench_utils, + "load_mt_bench_model_answers", + lambda model, n_instructions=None: ( + pd.DataFrame( + { + "instruction_index": [2, 1], + "completion_turn_1": ["Base A2", "Base A1"], + "completion_turn_2": ["Base B2", "Base B1"], + } + ) + if model == "gpt-4" + else None + ), + ) + + args = SimpleNamespace( + model_A="VLLM/example/model-a", + model_B="gpt-4", + n_instructions=2, + truncate_all_input_chars=8192, + max_out_tokens_models=1024, + use_tqdm=False, + max_model_len=16384, + chat_template=None, + battle_thinking_token_budget=None, + strip_thinking_before_judging=True, + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + completions_a, completions_b = mt_bench_utils._generate_mt_bench_completions( + args=args, + questions_df=questions_df, + ignore_cache=False, + usage_tracker=object(), + limit_event_tracker=None, + ) + + assert generated_models == ["VLLM/example/model-a"] + assert generation_kwargs == [ + {"gpu_memory_utilization": 0.7, "language_model_only": True} + ] + assert generation_strip_flags == [True] + assert completions_a.loc[1, 
"completion_turn_1"] == "Gen A1" + assert completions_b.loc[1, "completion_turn_1"] == "Base A1" + assert completions_b.loc[2, "completion_turn_2"] == "Base B2" + + +def test_parse_fastchat_verdict_accepts_bracketed_verdicts_after_thinking(): + assert ( + fastchat_compat._parse_fastchat_verdict( + "Need a longer chain.[[A]]" + ) + == "A" + ) + assert fastchat_compat._parse_fastchat_verdict("[[B]]") == "B" + assert fastchat_compat._parse_fastchat_verdict("[[C]]") == "tie" + + +def test_parse_fastchat_verdict_marks_non_bracketed_outputs_as_error(): + assert fastchat_compat._parse_fastchat_verdict("A") == "error" + assert fastchat_compat._parse_fastchat_verdict('{"verdict":"B"}') == "error" + + +def test_pair_v2_system_prompt_matches_original_fastchat_contract(): + rendered = fastchat_compat._PAIR_V2.system_prompt + + assert "provide a short explanation" in rendered + assert "valid JSON" not in rendered + assert '"[[A]]"' in rendered + assert '"[[B]]"' in rendered + assert '"[[C]]"' in rendered + assert rendered.endswith( + 'After providing your explanation, output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie.\n' + ) + + +def test_mt_bench_prompt_templates_preserve_multi_turn_reference_blocks(): + prompt_templates = importlib.import_module("judgearena.mt_bench.prompt_templates") + + rendered = prompt_templates.build_mt_bench_user_prompt_template( + multi_turn=True, + ref_based=True, + ) + + assert "<|The Start of Reference Answer|>" in rendered + assert "### User:\n{question_1}" in rendered + assert "### Reference answer:\n{ref_answer_2}" in rendered + assert "### Assistant A:\n{answer_a_2}" in rendered + assert "### Assistant B:\n{answer_b_2}" in rendered + + +def test_select_preset_prompt_uses_default_score_mode_with_mt_bench_template(): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + + prompt = preset_judging._select_preset_prompt( + "math", + multi_turn=True, + prompt_preset="default", + provide_explanation=False, + ) + + assert prompt.parser_mode == "score" + assert prompt.ref_based is True + assert prompt.multi_turn is True + assert prompt.system_prompt + assert "<|The Start of Reference Answer|>" in prompt.user_prompt_template + assert "### Assistant A:\n{answer_a_2}" in prompt.user_prompt_template + + +def test_select_preset_prompt_uses_skywork_verdict_mode_with_mt_bench_template(): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + + prompt = preset_judging._select_preset_prompt( + "writing", + multi_turn=True, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + provide_explanation=True, + ) + + assert prompt.parser_mode == "verdict" + assert prompt.ref_based is False + assert prompt.multi_turn is True + assert prompt.system_prompt is None + assert "Please briefly explain your reasoning first" in prompt.user_prompt_template + assert "### Assistant B:\n{answer_b_2}" in prompt.user_prompt_template + + +def test_preset_judging_uses_shared_swap_mode_both_semantics(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + questions_df = pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Q1"], + "turn_2": ["Q2"], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ) + completions_a = pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ) + completions_b = pd.DataFrame( + 
{ + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ) + call_count = {"count": 0} + + def fake_do_inference(**kwargs): + call_count["count"] += 1 + if call_count["count"] <= 2: + return ["score_A: 9\nscore_B: 3"] + return ["score_A: 2\nscore_B: 7"] + + monkeypatch.setattr(preset_judging, "do_inference", fake_do_inference) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=questions_df, + completions_a=completions_a, + completions_b=completions_b, + model_a="Model/A", + model_b="Model/B", + turns_mode="both", + swap_mode="both", + truncate_input_chars=None, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + ) + ) + + assert len(prefs) == 4 + assert len(annotations) == 4 + assert len(metadata) == 4 + assert num_inconsistent == 0 + assert [row["swapped"] for row in annotations] == [False, False, True, True] + + +def test_preset_judging_uses_shared_char_truncation_event_kind(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + tracker = utils.LimitEventTracker() + + monkeypatch.setattr( + preset_judging, + "do_inference", + lambda **kwargs: ["score_A: 8\nscore_B: 4"], + ) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Q" * 20], + "turn_2": [""], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ), + completions_a=pd.DataFrame( + {"completion_turn_1": ["A" * 30], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + completions_b=pd.DataFrame( + {"completion_turn_1": ["B" * 30], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + model_a="Model/A", + model_b="Model/B", + turns_mode="single", + swap_mode="fixed", + truncate_input_chars=10, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + limit_event_tracker=tracker, + ) + ) + + assert len(prefs) == 1 + assert len(annotations) == 1 + assert len(metadata) == 1 + assert num_inconsistent == 0 + assert tracker.build_summary()["counts_by_kind"]["judge_input_char_truncation"] == 3 + + +def test_preset_judging_preflights_token_budget_for_default_mode(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + tracker = utils.LimitEventTracker() + captured = {} + + class FakeTokenizer: + def apply_chat_template(self, messages, tokenize=True): + text = "".join( + str( + message["content"] if isinstance(message, dict) else message.content + ) + for message in messages + ) + return [0] * len(text) + + def encode(self, text): + return [0] * len(text) + + def fake_do_inference(*, inputs, **kwargs): + captured["inputs"] = inputs + return ["score_A: 8\nscore_B: 4"] + + monkeypatch.setattr(preset_judging, "do_inference", fake_do_inference) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Question"], + "turn_2": [""], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ), + completions_a=pd.DataFrame( + {"completion_turn_1": ["A" * 1200], "completion_turn_2": [""]}, + 
index=pd.Index([1], name="question_id"), + ), + completions_b=pd.DataFrame( + {"completion_turn_1": ["B" * 1200], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + model_a="Model/A", + model_b="Model/B", + turns_mode="single", + swap_mode="fixed", + truncate_input_chars=None, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + limit_event_tracker=tracker, + judge_tokenizer=FakeTokenizer(), + max_judge_model_len=2300, + max_out_tokens_judge=32, + ) + ) + + prompt_value = captured["inputs"][0] + assert len(prefs) == 1 + assert len(annotations) == 1 + assert len(metadata) == 1 + assert num_inconsistent == 0 + assert len(FakeTokenizer().apply_chat_template(prompt_value.to_messages())) <= 2012 + assert annotations[0]["answer_a_1_truncated"] is True + assert annotations[0]["answer_b_1_truncated"] is True + assert ( + tracker.build_summary()["counts_by_kind"]["judge_input_token_truncation"] >= 1 + ) + + +def test_conservative_winner_marks_one_sided_parse_failures_as_error(): + assert fastchat_compat._conservative_winner("model_A", "error") == ( + "error", + False, + ) + assert fastchat_compat._conservative_winner("error", "model_B") == ( + "error", + False, + ) + assert fastchat_compat._conservative_winner("error", "error") == ("error", False) + assert fastchat_compat._conservative_winner("model_A", "model_B") == ("tie", True) + + +def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch, caplog): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + + def fake_make_model( + *, + model, + max_tokens, + temperature=None, + max_model_len, + chat_template, + **kwargs, + ): + captured["make_model"] = { + "model": model, + "max_tokens": max_tokens, + "temperature": temperature, + "max_model_len": max_model_len, + "chat_template": chat_template, + "kwargs": kwargs, + } + return object() + + monkeypatch.setattr(mt_bench_utils, "make_model", fake_make_model) + + def fake_run_mt_bench_preset(**kwargs): + captured["run_mt_bench_preset"] = kwargs + return pd.Series( + kwargs["questions_df"].index.to_list(), + dtype=float, + ) + + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_preset", + fake_run_mt_bench_preset, + ) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + truncate_all_input_chars=8192, + max_out_tokens_models=1024, + max_out_tokens_judge=256, + max_model_len=16384, + chat_template=None, + provide_explanation=False, + swap_mode="fixed", + judge_prompt_preset="default", + battle_thinking_token_budget=None, + strip_thinking_before_judging=False, + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + caplog.set_level("WARNING", logger=mt_bench_utils.__name__) + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert args.swap_mode == "fixed" + assert 
args.max_out_tokens_judge == 256 + assert args.max_model_len == 16384 + assert args.max_judge_model_len is None + assert captured["make_model"]["max_tokens"] == 256 + assert captured["make_model"]["max_model_len"] is None + assert captured["make_model"]["kwargs"] == { + "gpu_memory_utilization": 0.7, + "language_model_only": True, + "thinking_token_budget": 512, + "kv_cache_dtype": "fp8", + "limit_event_stage": "judge_model_init", + "limit_event_model_spec": "VLLM/Qwen/Qwen3.5-27B-FP8", + "limit_event_tracker": captured["make_model"]["kwargs"]["limit_event_tracker"], + } + assert captured["make_model"]["temperature"] is None + assert captured["run_mt_bench_preset"]["args"].swap_mode == "fixed" + assert captured["run_mt_bench_preset"]["prompt_preset"] == "default" + assert ( + captured["run_mt_bench_preset"]["args"].strip_thinking_before_judging is False + ) + assert "MT-Bench ignores provide_explanation=False" not in caplog.text + + +def test_select_prompt_supports_optional_skywork_mt_bench_preset(): + prompt = fastchat_compat._select_prompt( + "writing", + multi_turn=False, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + ) + + assert prompt.name == "skywork-pair-v2" + assert prompt.ref_based is False + + +def test_run_mt_bench_keeps_skywork_prompt_preset(monkeypatch): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + monkeypatch.setattr(mt_bench_utils, "make_model", lambda **kwargs: object()) + + def fake_run_mt_bench_preset(**kwargs): + captured["kwargs"] = kwargs + return pd.Series([0.0], dtype=float) + + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_preset", + fake_run_mt_bench_preset, + ) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Skywork/Skywork-Critic-Llama-3.1-8B", + n_instructions=1, + truncate_all_input_chars=8192, + truncate_judge_input_chars=80000, + max_out_tokens_models=1024, + max_out_tokens_judge=256, + max_model_len=16384, + max_judge_model_len=65536, + chat_template=None, + provide_explanation=False, + swap_mode="both", + judge_prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + battle_thinking_token_budget=512, + strip_thinking_before_judging=True, + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert captured["kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET + assert captured["kwargs"]["args"].strip_thinking_before_judging is True + assert args.max_judge_model_len == 65536 + assert args.truncate_judge_input_chars == 80000 + assert captured["kwargs"]["args"].truncate_judge_input_chars == 80000 + assert captured["kwargs"]["args"].max_judge_model_len == 65536 + + +def test_run_mt_bench_default_respects_judge_temperature_from_engine_kwargs( + monkeypatch, +): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + 
) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + + def fake_make_model( + *, + model, + max_tokens, + temperature=None, + max_model_len, + chat_template, + **kwargs, + ): + captured["temperature"] = temperature + return object() + + monkeypatch.setattr(mt_bench_utils, "make_model", fake_make_model) + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_preset", + lambda **kwargs: pd.Series([0.0], dtype=float), + ) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + max_out_tokens_judge=256, + engine_kwargs={"temperature": 0.8}, + judge_engine_kwargs={"temperature": 0.4}, + ) + + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert captured["temperature"] == 0.4 + + +def test_save_mt_bench_results_uses_explicit_res_folder(tmp_path, monkeypatch): + captured = {} + + def fake_write_run_metadata(**kwargs): + captured["output_dir"] = kwargs["output_dir"] + return kwargs["output_dir"] / "run-metadata.v1.json" + + monkeypatch.setattr(mt_bench_utils, "write_run_metadata", fake_write_run_metadata) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + result_folder=str(tmp_path / "results-root"), + ) + explicit_res_folder = tmp_path / "explicit-run" + result_name = "mt-bench-run" + + mt_bench_utils._save_mt_bench_results( + args=args, + res_folder=explicit_res_folder, + result_name=result_name, + results={"task": "mt-bench"}, + annotations_df=pd.DataFrame([{"question_id": 1, "turn": 1}]), + questions_df=pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q2"]}, + index=pd.Index([1], name="question_id"), + ), + pricing_reference=None, + started_at_utc=datetime.now(UTC), + ) + + assert captured["output_dir"] == explicit_res_folder + assert (explicit_res_folder / f"args-{result_name}.json").exists() + assert (explicit_res_folder / f"results-{result_name}.json").exists() + assert (explicit_res_folder / f"{result_name}-annotations.csv").exists() + assert not (tmp_path / "results-root" / result_name).exists() diff --git a/tests/test_openrouter_reference_pricing.py b/tests/test_openrouter_reference_pricing.py new file mode 100644 index 0000000..c26ce48 --- /dev/null +++ b/tests/test_openrouter_reference_pricing.py @@ -0,0 +1,330 @@ +import json + +import judgearena.openrouter_reference_pricing as pricing +from judgearena.repro import write_run_metadata +from judgearena.utils import do_inference + + +class CountingModel: + def batch(self, inputs, **invoke_kwargs): + return [f"output-{idx}" for idx, _input in enumerate(inputs)] + + def count_prompt_tokens_batch(self, inputs): + return [len(str(input_item)) for input_item in inputs] + + def count_completion_tokens_batch(self, outputs): + return [len(output) for output in outputs] + + +def test_do_inference_records_token_usage(): + tracker = pricing.OpenRouterReferencePricingTracker() + model = CountingModel() + + outputs = do_inference( + 
chat_model=model, + inputs=["abc", "de"], + usage_tracker=tracker, + usage_phase="judge", + usage_model_spec="VLLM/org/model", + ) + + assert outputs == ["output-0", "output-1"] + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="VLLM/org/model", + prompt_tokens=3, + completion_tokens=8, + requests=1, + ), + pricing.TokenUsageRecord( + phase="judge", + model_spec="VLLM/org/model", + prompt_tokens=2, + completion_tokens=8, + requests=1, + ), + ] + + +class _FakeAIMessage: + """Minimal AIMessage stand-in: .content + langchain-core .usage_metadata.""" + + def __init__(self, content: str, usage_metadata: dict[str, int] | None) -> None: + self.content = content + self.usage_metadata = usage_metadata + self.response_metadata: dict[str, object] = {} + + +class _OpenRouterShapeModel: + """Mimics ChatOpenAI: returns AIMessage objects with usage_metadata, no + count_*_batch helpers (so the fallback tokeniser path is unavailable).""" + + def __init__(self, usages: list[dict[str, int] | None]) -> None: + self._usages = usages + + def batch(self, inputs, **invoke_kwargs): + return [ + _FakeAIMessage(content=f"output-{idx}", usage_metadata=self._usages[idx]) + for idx, _ in enumerate(inputs) + ] + + +def test_do_inference_records_usage_metadata_for_openrouter_shape_models(): + """For OpenRouter ChatOpenAI calls, the tracker must record API-reported + token counts pulled from AIMessage.usage_metadata; ``record_batch_from_model`` + is a no-op because ChatOpenAI lacks ``count_*_batch`` helpers.""" + tracker = pricing.OpenRouterReferencePricingTracker() + model = _OpenRouterShapeModel( + usages=[ + {"input_tokens": 1234, "output_tokens": 17, "total_tokens": 1251}, + {"input_tokens": 800, "output_tokens": 42, "total_tokens": 842}, + ] + ) + + outputs = do_inference( + chat_model=model, + inputs=["prompt-1", "prompt-2"], + usage_tracker=tracker, + usage_phase="judge", + usage_model_spec="OpenRouter/google/gemma-4-31b-it", + ) + + assert outputs == ["output-0", "output-1"] + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=1234, + completion_tokens=17, + requests=1, + ), + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=800, + completion_tokens=42, + requests=1, + ), + ] + + +def test_do_inference_falls_back_to_count_batch_when_usage_metadata_missing(): + """When AIMessage results carry no ``usage_metadata`` (e.g. 
local vLLM path), + ``do_inference`` must fall back to ``record_batch_from_model`` and use the + chat_model's ``count_*_batch`` helpers.""" + tracker = pricing.OpenRouterReferencePricingTracker() + model = CountingModel() + + do_inference( + chat_model=model, + inputs=["abc", "de"], + usage_tracker=tracker, + usage_phase="generation_model_A", + usage_model_spec="VLLM/org/model", + ) + + assert [(r.prompt_tokens, r.completion_tokens) for r in tracker.records] == [ + (3, 8), + (2, 8), + ] + + +def test_record_batch_from_usage_metadata_returns_false_on_all_none(): + """A batch where every entry is ``None`` must signal "no records added" so + the caller can fall through to the tokeniser-based path.""" + tracker = pricing.OpenRouterReferencePricingTracker() + + recorded = tracker.record_batch_from_usage_metadata( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + usages=[None, None, None], + ) + + assert recorded is False + assert tracker.records == [] + + +def test_record_batch_from_usage_metadata_accepts_openai_shape_keys(): + """OpenAI-shape keys (``prompt_tokens``/``completion_tokens``) appear on + ``response_metadata.token_usage`` for older langchain-openai versions.""" + tracker = pricing.OpenRouterReferencePricingTracker() + + recorded = tracker.record_batch_from_usage_metadata( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + usages=[{"prompt_tokens": 100, "completion_tokens": 25, "total_tokens": 125}], + ) + + assert recorded is True + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=100, + completion_tokens=25, + requests=1, + ) + ] + + +def test_build_reference_pricing_summary_uses_exact_match_and_reports_partial_cost( + monkeypatch, +): + catalog = pricing.parse_openrouter_catalog_payload( + { + "data": [ + { + "id": "openrouter/example-model", + "canonical_slug": "openrouter/example-model", + "hugging_face_id": "Org/Example-Model", + "name": "Example Model", + "pricing": { + "prompt": "0.001", + "completion": "0.002", + "request": "0.01", + "internal_reasoning": "0.5", + }, + } + ], + "fetched_at_utc": "2026-04-07T00:00:00+00:00", + } + ) + monkeypatch.setattr( + pricing, + "load_openrouter_price_catalog_with_fallback", + lambda **kwargs: (catalog, None), + ) + + tracker = pricing.OpenRouterReferencePricingTracker() + tracker._records.extend( + [ + pricing.TokenUsageRecord( + phase="generation_model_A", + model_spec="VLLM/Org/Example-Model", + prompt_tokens=100, + completion_tokens=20, + ), + pricing.TokenUsageRecord( + phase="generation_model_A", + model_spec="VLLM/Org/Example-Model", + prompt_tokens=50, + completion_tokens=5, + ), + pricing.TokenUsageRecord( + phase="generation_model_B", + model_spec="VLLM/No/Match", + prompt_tokens=10, + completion_tokens=2, + ), + ] + ) + + summary = pricing.build_openrouter_reference_pricing_summary( + tracker=tracker, + phase_model_specs={ + "generation_model_A": "VLLM/Org/Example-Model", + "generation_model_B": "VLLM/No/Match", + "judge": "VLLM/No/Runtime", + }, + ) + + matched = summary["phases"]["generation_model_A"] + assert matched["openrouter_model_id"] == "openrouter/example-model" + assert matched["pricing_status"] == "matched_exact_openrouter_model_partial" + assert matched["prompt_tokens"] == 150 + assert matched["completion_tokens"] == 25 + assert matched["request_count"] == 2 + assert matched["openrouter_reference_cost_usd"] == 0.22 + assert matched["ignored_pricing_components"] == ["internal_reasoning"] + + 
unmatched = summary["phases"]["generation_model_B"]
+    assert unmatched["pricing_status"] == "no_exact_openrouter_match"
+    assert unmatched["openrouter_reference_cost_usd"] is None
+
+    no_runtime = summary["phases"]["judge"]
+    assert no_runtime["pricing_status"] == "no_runtime_token_data"
+    assert no_runtime["total_tokens"] == 0
+
+    assert summary["total"]["openrouter_reference_cost_usd"] == 0.22
+
+
+def test_build_reference_pricing_summary_matches_quantized_local_variant(
+    monkeypatch,
+):
+    catalog = pricing.parse_openrouter_catalog_payload(
+        {
+            "data": [
+                {
+                    "id": "qwen/qwen3.5-27b",
+                    "canonical_slug": "qwen/qwen3.5-27b",
+                    "hugging_face_id": "Qwen/Qwen3.5-27B",
+                    "name": "Qwen: Qwen3.5-27B",
+                    "pricing": {
+                        "prompt": "0.001",
+                        "completion": "0.002",
+                        "request": "0.01",
+                    },
+                }
+            ],
+            "fetched_at_utc": "2026-04-15T00:00:00+00:00",
+        }
+    )
+    monkeypatch.setattr(
+        pricing,
+        "load_openrouter_price_catalog_with_fallback",
+        lambda **kwargs: (catalog, None),
+    )
+
+    tracker = pricing.OpenRouterReferencePricingTracker()
+    tracker._records.append(
+        pricing.TokenUsageRecord(
+            phase="generation_model_A",
+            model_spec="VLLM/Qwen/Qwen3.5-27B-FP8",
+            prompt_tokens=10,
+            completion_tokens=5,
+        )
+    )
+
+    summary = pricing.build_openrouter_reference_pricing_summary(
+        tracker=tracker,
+        phase_model_specs={
+            "generation_model_A": "VLLM/Qwen/Qwen3.5-27B-FP8",
+            "judge": "VLLM/No/Runtime",
+        },
+    )
+
+    matched = summary["phases"]["generation_model_A"]
+    assert (
+        matched["pricing_status"]
+        == "matched_openrouter_model_after_variant_normalization"
+    )
+    assert matched["openrouter_model_id"] == "qwen/qwen3.5-27b"
+    assert matched["openrouter_reference_cost_usd"] == 0.03
+    assert summary["exact_match_policy"]["fallback_normalizations"] == [
+        "strip_common_local_quantization_suffixes"
+    ]
+
+
+def test_write_run_metadata_includes_pricing_reference(tmp_path, monkeypatch):
+    monkeypatch.setattr(
+        "judgearena.repro._get_dependency_versions",
+        lambda *args, **kwargs: {},
+    )
+    monkeypatch.setattr("judgearena.repro._get_git_hash", lambda *args, **kwargs: None)
+
+    metadata_path = write_run_metadata(
+        output_dir=tmp_path,
+        entrypoint="judgearena.test",
+        run={"dataset": "alpaca-eval"},
+        pricing_reference={
+            "pricing_model": "openrouter_reference",
+            "total": {"openrouter_reference_cost_usd": 1.23},
+        },
+    )
+
+    metadata = json.loads(metadata_path.read_text())
+    assert metadata["pricing_reference"]["pricing_model"] == "openrouter_reference"
+    assert (
+        metadata["pricing_reference"]["total"]["openrouter_reference_cost_usd"] == 1.23
+    )
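One way to read the cost assertions in the two summary tests above: the stub catalog's `prompt` and `completion` rates are charged per token, the `request` rate per call, and the `internal_reasoning` component, for which the tracker records no runtime counts, is reported under `ignored_pricing_components` rather than priced. The arithmetic below only re-derives the asserted totals; it is illustrative, not the library's implementation:

```python
# Rates from the stub catalogs above, read as USD per token and per request.
prompt_rate, completion_rate, request_rate = 0.001, 0.002, 0.01

# Exact-match phase: 150 prompt tokens, 25 completion tokens, 2 requests.
cost_exact = 150 * prompt_rate + 25 * completion_rate + 2 * request_rate
assert abs(cost_exact - 0.22) < 1e-9

# Quantized-variant phase: 10 prompt tokens, 5 completion tokens, 1 request.
cost_variant = 10 * prompt_rate + 5 * completion_rate + 1 * request_rate
assert abs(cost_variant - 0.03) < 1e-9
```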
diff --git a/tests/test_regexp.py b/tests/test_regexp.py
index 39efa41..f11c1cd 100644
--- a/tests/test_regexp.py
+++ b/tests/test_regexp.py
@@ -1,4 +1,5 @@
 from judgearena.evaluate import PairScore
+from judgearena.utils import strip_thinking_tags
 
 
 def test_pair_score():
@@ -38,3 +39,72 @@ def test_regexp():
 
     assert pref == 0.5744425168116589
     print(pref)
+
+
+def test_pair_score_ignores_scores_inside_thinking_tags():
+    raw_text = """
+    <think>
+    Early draft:
+    score_A: 2
+    score_B: 1
+    </think>
+    Explanation: Assistant B is clearly better overall.
+    score_A: 0
+    score_B: 10
+    """
+
+    scorer = PairScore()
+    pref = scorer.parse_model_raw(raw_text)
+
+    assert pref is not None
+    assert pref == 0.9525741268224333
+
+
+def test_pair_score_score_mode_does_not_parse_bracketed_verdicts():
+    scorer = PairScore()
+
+    assert scorer.parse_model_raw("Explanation: ok\n[[A]]") is None
+    assert scorer.parse_model_raw("Explanation: ok\n[[B]]") is None
+    assert scorer.parse_model_raw("Explanation: ok\n[[C]]") is None
+
+
+def test_pair_score_score_mode_ignores_bracketed_verdict_after_thinking():
+    raw_text = """
+    <think>
+    score_A: 0
+    score_B: 10
+    </think>
+    Concise verdict only.
+    [[B]]
+    """
+
+    scorer = PairScore()
+
+    assert scorer.parse_model_raw(raw_text) is None
+
+
+def test_strip_thinking_tags_handles_closing_tag_without_opening_tag():
+    raw_text = (
+        "Reasoning that started implicitly and kept going.\n"
+        "Still reasoning.\n"
+        "</think>\n"
+        "Final answer."
+    )
+
+    assert strip_thinking_tags(raw_text) == "Final answer."
+
+
+def test_pair_score_verdict_mode_uses_bracketed_verdicts():
+    raw_text = "score_A: 10\nscore_B: 0\n[[B]]"
+
+    scorer = PairScore(parser_mode="verdict")
+
+    assert scorer.parse_model_raw(raw_text) == 1.0
+
+
+def test_pair_score_verdict_mode_does_not_parse_score_only_outputs():
+    raw_text = "score_A: 10\nscore_B: 0"
+
+    scorer = PairScore(parser_mode="verdict")
+
+    assert scorer.parse_model_raw(raw_text) is None
diff --git a/tests/test_strip_thinking_carryover.py b/tests/test_strip_thinking_carryover.py
new file mode 100644
index 0000000..7506313
--- /dev/null
+++ b/tests/test_strip_thinking_carryover.py
@@ -0,0 +1,173 @@
+"""Tests for the strip-thinking-before-char-cap fix applied to turn-1
+answers when constructing the turn-2 prompt in MT-Bench generation.
+
+Background: ``truncate_all_input_chars`` fires before the chat template
+renders the turn-2 prompt. If the turn-1 answer contains a
+``<think>...</think>`` block (Qwen3.5, SmolLM3 thinking mode) or a vLLM
+forced-budget closer, a char cap landing inside the reasoning span would
+destroy the ``</think>`` tag and force the full reasoning fragment into the
+turn-2 context. Stripping the reasoning span first mirrors what the
+Qwen3 chat template does natively for historical assistant turns and
+keeps the cap on the visible answer.
+
+These tests pin down the composition that
+``judgearena.generate.generate_multiturn`` performs:
+``strip_thinking_tags_with_metadata`` -> ``truncate_with_metadata``.
+"""
+
+from __future__ import annotations
+
+from dataclasses import replace
+
+from judgearena.generate_and_evaluate import CliArgs
+from judgearena.mt_bench.mt_bench_utils import _mt_bench_generation_cache_name
+from judgearena.utils import (
+    VLLM_REASONING_END_STR,
+    strip_thinking_tags_with_metadata,
+    truncate_with_metadata,
+)
+
+
+def _strip_then_cap(
+    answer: str, cap: int, *, strip: bool = True
+) -> tuple[str, bool, bool]:
+    """Reproduce the exact sequence inside ``generate_multiturn``'s turn-2 loop."""
+    if strip:
+        stripped_text, thinking_stripped = strip_thinking_tags_with_metadata(answer)
+    else:
+        stripped_text, thinking_stripped = answer, False
+    truncated, was_truncated = truncate_with_metadata(stripped_text, max_len=cap)
+    return truncated, was_truncated, thinking_stripped
+
+
+def test_well_formed_think_block_is_stripped_before_cap():
+    """Nominal Qwen3.5 case: a complete ``<think>...</think>`` wrapper sits
+    in front of the visible answer. Stripping removes the whole span; the
+    char cap then applies to the visible answer only."""
+    reasoning = "so let me think through this... " * 400  # ~12K chars
+    visible = "The capital of France is Paris."
+    answer = f"<think>{reasoning}</think>\n\n{visible}"
+    # Cap below the reasoning length but above the visible answer length
+    # so the old behaviour would have clipped inside the <think> span.
+    cap = 1024
+
+    truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap)
+
+    assert thinking_stripped is True
+    assert was_truncated is False
+    assert truncated == visible
+    assert "<think>" not in truncated
+    assert "</think>" not in truncated
+
+
+def test_vllm_forced_thinking_budget_closer_is_stripped():
+    """When the thinking budget is exhausted, vLLM inserts a forced closer
+    (``VLLM_REASONING_END_STR``) without a paired ``<think>`` opener. The
+    strip helper treats everything up to and including that marker as
+    reasoning and drops it before the cap fires."""
+    forced_reasoning = "step 1... " * 500  # ~5K chars of runaway thought
+    visible = "Final answer: 42."
+    answer = f"{forced_reasoning}{VLLM_REASONING_END_STR}{visible}"
+
+    truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=256)
+
+    assert thinking_stripped is True
+    assert was_truncated is False
+    assert truncated == visible
+
+
+def test_dangling_closing_tag_is_stripped():
+    """Qwen3.5 sometimes emits ``</think>`` without a preceding ``<think>``
+    opener (e.g. when the opener was chopped off during generation rollover).
+    The strip helper drops the preamble up to ``</think>`` and keeps the
+    postamble. Without this, the cap would land inside the dangling
+    preamble and the ``</think>`` closer would survive in the turn-2
+    context, confusing the chat template."""
+    preamble = "leftover reasoning fragment " * 100
+    visible = "Answer: yes."
+    answer = f"{preamble}</think>\n{visible}"
+
+    truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=512)
+
+    assert thinking_stripped is True
+    assert was_truncated is False
+    assert truncated == visible
+
+
+def test_no_thinking_tags_passthrough():
+    """Non-thinking models (e.g. EuroLLM, Apertus) produce answers without
+    any ``<think>`` markers. Strip is a no-op; the cap behaves exactly as
+    before the fix."""
+    visible = "Paris is the capital of France. " * 50  # ~1.6K chars
+    cap = 512
+
+    truncated, was_truncated, thinking_stripped = _strip_then_cap(visible, cap)
+
+    assert thinking_stripped is False
+    assert was_truncated is True
+    assert truncated == visible[:cap]
+
+
+def test_unclosed_think_block_is_unfixable_by_stripping():
+    """Pathological case: the model writes ``<think>`` and hits the
+    generation limit before emitting ``</think>``. No ``</think>`` tag or
+    vLLM closer appears anywhere in the output, so the strip helper
+    returns the text unchanged and the cap still clips inside the
+    reasoning span. Stripping cannot fix this; the escape hatch is a
+    larger ``battle_thinking_token_budget``."""
+    reasoning = "still reasoning " * 1000
+    answer = f"<think>{reasoning}"
+
+    truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=256)
+
+    assert thinking_stripped is False
+    assert was_truncated is True
+    assert truncated.startswith("<think>")
+ answer = f"{reasoning}\n{visible}" + + truncated, was_truncated, thinking_stripped = _strip_then_cap( + answer, cap=1024, strip=False + ) + + assert thinking_stripped is False + assert was_truncated is True + assert truncated.startswith("") + assert "" not in truncated + + +def _make_mt_bench_cli_args(**overrides) -> CliArgs: + args = CliArgs( + judge_model="OpenRouter/google/gemma-4-31b-it", + task="mt-bench", + model_A="VLLM/Qwen/Qwen3.5-9B", + model_B="VLLM/Qwen/Qwen3.5-9B", + n_instructions=3, + truncate_all_input_chars=30000, + max_out_tokens_models=49152, + max_model_len=57344, + battle_thinking_token_budget=32768, + ) + return replace(args, **overrides) if overrides else args + + +def test_mt_bench_cache_key_changes_when_flag_flipped(): + """Judge-side thinking stripping now also controls MT-Bench turn-1 + carryover stripping, so it must participate in the generation cache key.""" + args_on = _make_mt_bench_cli_args(strip_thinking_before_judging=True) + args_off = _make_mt_bench_cli_args(strip_thinking_before_judging=False) + + key_on = _mt_bench_generation_cache_name(args_on, model_name="VLLM/Qwen/Qwen3.5-9B") + key_off = _mt_bench_generation_cache_name( + args_off, model_name="VLLM/Qwen/Qwen3.5-9B" + ) + + assert key_on != key_off + assert key_on.startswith("mt-bench_VLLM/Qwen/Qwen3.5-9B_3_") + assert key_off.startswith("mt-bench_VLLM/Qwen/Qwen3.5-9B_3_") diff --git a/tests/test_utils.py b/tests/test_utils.py index 9ca02f4..aea4990 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,164 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + import judgearena.utils as utils from judgearena.utils import make_model +def test_extract_ai_message_metadata_reads_finish_reason(): + ai_message = SimpleNamespace( + content="hi", + response_metadata={"finish_reason": "length", "stop_reason": None}, + ) + md = utils._extract_ai_message_metadata(ai_message) + assert md == {"finish_reason": "length", "stop_reason": None} + + +def test_extract_ai_message_metadata_handles_missing_response_metadata(): + bare_ai_message = SimpleNamespace(content="hello") + md = utils._extract_ai_message_metadata(bare_ai_message) + assert md == {"finish_reason": None, "stop_reason": None} + + +def test_extract_ai_message_metadata_handles_plain_dict_fallback(): + md = utils._extract_ai_message_metadata( + {"finish_reason": "stop", "stop_reason": "eos"} + ) + assert md == {"finish_reason": "stop", "stop_reason": "eos"} + + +def test_do_inference_async_path_propagates_finish_reason(monkeypatch): + async_results = [ + SimpleNamespace( + content="out1", + response_metadata={"finish_reason": "stop"}, + ), + SimpleNamespace( + content="out2", + response_metadata={"finish_reason": "length"}, + ), + ] + + async def fake_ainvoke(_input, **_kwargs): + return async_results.pop(0) + + chat_model = SimpleNamespace(ainvoke=fake_ainvoke) + texts, metadata = utils.do_inference( + chat_model=chat_model, + inputs=["prompt1", "prompt2"], + use_tqdm=True, + return_metadata=True, + ) + assert texts == ["out1", "out2"] + assert metadata == [ + {"finish_reason": "stop", "stop_reason": None}, + {"finish_reason": "length", "stop_reason": None}, + ] + + +def _build_inflight_tracking_chat_model(*, hold_seconds: float = 0.05): + """Helper: mock chat model whose `ainvoke` records peak concurrent in-flight calls.""" + + state = {"in_flight": 0, "peak": 0} + + async def fake_ainvoke(input_item, **_kwargs): + state["in_flight"] += 1 + state["peak"] = max(state["peak"], 
state["in_flight"]) + try: + await asyncio.sleep(hold_seconds) + return SimpleNamespace( + content=f"out-{input_item}", + response_metadata={"finish_reason": "stop"}, + ) + finally: + state["in_flight"] -= 1 + + return SimpleNamespace(ainvoke=fake_ainvoke), state + + +def test_do_inference_async_path_respects_concurrency_cap(monkeypatch): + """With JUDGEARENA_JUDGE_MAX_CONCURRENCY=4 and 16 inputs, peak in-flight must stay <= 4.""" + monkeypatch.setenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", "4") + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] <= 4, ( + f"Concurrency cap violated: peak in-flight={state['peak']}, expected <= 4" + ) + assert state["peak"] >= 1 + + +def test_do_inference_async_path_unbounded_when_env_unset(monkeypatch): + """Without JUDGEARENA_JUDGE_MAX_CONCURRENCY set, all 16 calls fire concurrently.""" + monkeypatch.delenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", raising=False) + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] > 4, ( + f"Expected unbounded concurrency to overshoot the capped variant; got peak={state['peak']}" + ) + + +def test_do_inference_async_path_zero_cap_is_unbounded(monkeypatch): + """JUDGEARENA_JUDGE_MAX_CONCURRENCY=0 falls back to unbounded (defensive default).""" + monkeypatch.setenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", "0") + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] > 4 + + +def test_do_inference_batch_path_propagates_finish_reason_without_batch_with_metadata(): + batch_results = [ + SimpleNamespace( + content="a", + response_metadata={"finish_reason": "stop"}, + ), + SimpleNamespace( + content="b", + response_metadata={"finish_reason": "length"}, + ), + ] + chat_model = MagicMock() + chat_model.batch = MagicMock(return_value=batch_results) + # Ensure no batch_with_metadata attr so the else branch runs + if hasattr(chat_model, "batch_with_metadata"): + del chat_model.batch_with_metadata + + texts, metadata = utils.do_inference( + chat_model=chat_model, + inputs=["p1", "p2"], + use_tqdm=False, + return_metadata=True, + ) + assert [m["finish_reason"] for m in metadata] == ["stop", "length"] + assert texts == ["a", "b"] + + def test_download_all_dispatches_arena_hard_versions(monkeypatch, tmp_path): calls: list[tuple[str, str, object]] = [] @@ -29,13 +186,14 @@ def test_download_all_dispatches_arena_hard_versions(monkeypatch, tmp_path): utils.download_all() tables_dir = tmp_path / "tables" - assert calls[:4] == [ + assert calls[:5] == [ ("hf", "alpaca-eval", tables_dir), ("arena", "arena-hard-v0.1", tables_dir), ("arena", "arena-hard-v2.0", tables_dir), - ("hf", "m-arena-hard", tables_dir), + ("hf", "m-arena-hard-v0.1", tables_dir), + ("hf", "m-arena-hard-v2.0", tables_dir), ] - assert calls[4] == ( + assert calls[5] == ( "snapshot", "geoalgo/multilingual-contexts-to-be-completed", tmp_path / "contexts", @@ -64,3 +222,75 @@ def test_make_model_openrouter_strips_vllm_only_kwargs(monkeypatch): assert "chat_template" not in model.model_kwargs assert 
model.max_tokens == 16 assert model.temperature == 0.5 + + +def test_init_llm_with_retry_recovers_from_transient_cuda_error(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 3) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + monkeypatch.setattr(utils.time, "sleep", lambda *_a, **_k: None) + + calls: list[dict] = [] + + def fake_llm(**kwargs): + calls.append(kwargs) + if len(calls) < 3: + raise RuntimeError( + "CUDA error: CUDA-capable device(s) is/are busy or unavailable\n" + "Search for 'cudaErrorDevicesUnavailable' ..." + ) + return "llm" + + result = utils._init_llm_with_retry(fake_llm, model="m", trust_remote_code=True) + assert result == "llm" + assert len(calls) == 3 + + +def test_init_llm_with_retry_gives_up_after_max_attempts(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 2) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + monkeypatch.setattr(utils.time, "sleep", lambda *_a, **_k: None) + + def always_fails(**_kwargs): + raise RuntimeError("cudaErrorDevicesUnavailable") + + with pytest.raises(RuntimeError, match="cudaErrorDevicesUnavailable"): + utils._init_llm_with_retry(always_fails, model="m") + + +def test_init_llm_with_retry_reraises_non_matching_errors_immediately(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 4) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + + call_count = 0 + + def fails_once(**_kwargs): + nonlocal call_count + call_count += 1 + raise ValueError("bad config") + + with pytest.raises(ValueError, match="bad config"): + utils._init_llm_with_retry(fails_once, model="m") + assert call_count == 1 + + +@pytest.mark.parametrize( + "message", + [ + "CUDA error: unknown error", + "NCCL error", + ], +) +def test_init_llm_with_retry_does_not_retry_broad_runtime_errors(monkeypatch, message): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 4) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + + call_count = 0 + + def fails_once(**_kwargs): + nonlocal call_count + call_count += 1 + raise RuntimeError(message) + + with pytest.raises(RuntimeError, match=message): + utils._init_llm_with_retry(fails_once, model="m") + assert call_count == 1
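The retry tests above constrain `_init_llm_with_retry` to a deliberately narrow contract: only the transient "device busy" CUDA failure is retried, with bounded attempts and a growing back-off, while configuration errors and broad CUDA/NCCL failures propagate immediately. A sketch consistent with those constraints; the helper name, marker tuple, and keyword defaults below are illustrative rather than the module's actual constants (`_VLLM_INIT_MAX_ATTEMPTS`, `_VLLM_INIT_BACKOFF_SECONDS`):

```python
import time

# Only the narrow "device busy" failure is treated as transient.
TRANSIENT_MARKERS = (
    "cudaErrorDevicesUnavailable",
    "CUDA-capable device(s) is/are busy or unavailable",
)


def init_llm_with_retry_sketch(
    factory, *, max_attempts: int = 3, backoff_seconds: float = 5.0, **kwargs
):
    """Build the engine, retrying only on the transient device-busy CUDA error."""
    for attempt in range(1, max_attempts + 1):
        try:
            return factory(**kwargs)
        except RuntimeError as err:
            transient = any(marker in str(err) for marker in TRANSIENT_MARKERS)
            if not transient or attempt == max_attempts:
                # ValueError is never caught; broad CUDA/NCCL errors re-raise here.
                raise
            time.sleep(backoff_seconds * attempt)
```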