From c6b2b0a9e8fdeef03a531e3c45752f024cb5e175 Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Wed, 1 Apr 2026 01:15:30 +0200 Subject: [PATCH 01/28] update dependencies to support Qwen 3.5 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c4b20bc..037522c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,5 +81,5 @@ quote-style = "double" indent-style = "space" [project.optional-dependencies] -vllm = ["vllm==0.10.2", "transformers>=4.55.2,<5.0.0"] +vllm = ["vllm>=0.11.0,<1.0.0", "transformers>=5.2.0,<6.0.0"] llamacpp = ["llama-cpp-python>=0.3.0"] From 1f4bae81fdbc8118b32dc298ef04f361c720a771 Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Wed, 1 Apr 2026 02:01:32 +0200 Subject: [PATCH 02/28] slurmpilot scripts --- .../launch_generation_and_evaluation.py | 2 +- .../launch_kislurm_qwen35_smoke.py | 65 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 slurmpilot_scripts/launch_kislurm_qwen35_smoke.py diff --git a/slurmpilot_scripts/launch_generation_and_evaluation.py b/slurmpilot_scripts/launch_generation_and_evaluation.py index 6668a33..5782977 100644 --- a/slurmpilot_scripts/launch_generation_and_evaluation.py +++ b/slurmpilot_scripts/launch_generation_and_evaluation.py @@ -73,7 +73,7 @@ "dataset": f"{language}-contexts", "model_A": baseline, "model_B": model, - "judge_model": "VLLM/Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8", + "judge_model": "VLLM/Qwen/Qwen3.5-27B-FP8", "n_instructions": 100, # "ignore_cache": None, } diff --git a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py new file mode 100644 index 0000000..86704c4 --- /dev/null +++ b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py @@ -0,0 +1,65 @@ +from pathlib import Path + +from slurmpilot import JobCreationInfo, SlurmPilot, unify + +CLUSTER = "kislurm" +REMOTE_PROJECT_ROOT = Path("/work/dlclarge1/lushtake-hiwi/JudgeArena") +LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent +PYTHON_BINARY = REMOTE_PROJECT_ROOT / ".venv" / "bin" / "python" +ENTRYPOINT = "generate_and_evaluate.py" +SRC_DIR = str(LOCAL_PROJECT_ROOT / "judgearena") + +# Use L40S partitions from the all_dlc / ml_dlc families. +PARTITION_ALL_DLC_L40S = "testdlc2_gpu-l40s" +PARTITION_ML_DLC_L40S = "mldlc2_gpu-l40s" + +# Same weights as `VLLM/Qwen/Qwen3.5-27B-FP8`; repo-id loading fails offline in vLLM +# without a resolved revision — point at the HF hub snapshot dir under `HF_HOME`. +QWEN35_27B_FP8_SNAPSHOT = ( + "/work/dlclarge1/lushtake-hiwi/.cache/huggingface/hub/" + "models--Qwen--Qwen3.5-27B-FP8/snapshots/" + "2e1b21350ce589fcaafbb3c7d7eac526a7aed582" +) +JUDGE_MODEL = f"VLLM//{QWEN35_27B_FP8_SNAPSHOT.lstrip('/')}" + + +def submit_smoke_job(partition: str = PARTITION_ALL_DLC_L40S) -> tuple[str, str, int]: + slurm = SlurmPilot(clusters=[CLUSTER]) + dataset = "alpaca-eval" + jobname = unify("qwen3.5-smoke/judgearena-canonical", method="date") + + job_info = JobCreationInfo( + cluster=CLUSTER, + partition=partition, + jobname=jobname, + entrypoint=ENTRYPOINT, + python_binary=str(PYTHON_BINARY), + python_args={ + "dataset": dataset, + "model_A": "Dummy/no_answer", + "model_B": "Dummy/open_answer", + "judge_model": JUDGE_MODEL, + "n_instructions": 1, + "max_out_tokens_judge": 64, + }, + src_dir=SRC_DIR, + n_cpus=1, + max_runtime_minutes=20, + env={ + "HF_HUB_OFFLINE": "1", + # Ensure Hugging Face uses the shared cache location that + # already contains the Qwen3.5 FP8 checkpoint. + "HF_HOME": "/work/dlclarge1/lushtake-hiwi", + "JUDGEARENA_DATA": "/work/dlclarge1/lushtake-hiwi/judgearena-data", + }, + ) + job_id = slurm.schedule_job(job_info) + print(f"Submitted {dataset}: jobname={job_info.jobname}, job_id={job_id}") + return dataset, job_info.jobname, job_id + + +if __name__ == "__main__": + print(f"Using LOCAL_PROJECT_ROOT={LOCAL_PROJECT_ROOT}") + print(f"Using REMOTE_PROJECT_ROOT={REMOTE_PROJECT_ROOT}") + print(f"Using PYTHON_BINARY={PYTHON_BINARY}") + submit_smoke_job(partition=PARTITION_ALL_DLC_L40S) From 25b0355d127c81bb63b34b0e24586a2bc1cc2530 Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Wed, 1 Apr 2026 02:02:36 +0200 Subject: [PATCH 03/28] update dep versions --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 037522c..cc1c113 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,5 +81,6 @@ quote-style = "double" indent-style = "space" [project.optional-dependencies] -vllm = ["vllm>=0.11.0,<1.0.0", "transformers>=5.2.0,<6.0.0"] +# vLLM on PyPI pins transformers<5; optional extra matches that so `uv lock` can resolve. +vllm = ["vllm>=0.17.0,<1.0.0", "transformers>=4.56.0,<5.0.0"] llamacpp = ["llama-cpp-python>=0.3.0"] From ab065fd25a2dafeb58db06af617687f839398346 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Mon, 6 Apr 2026 22:43:53 +0200 Subject: [PATCH 04/28] fix support for VLLM - fix dependencies - add structured output to prevent judge from not respecting the prompt --- judgearena/evaluate.py | 52 +++++++---- judgearena/generate_and_evaluate.py | 44 ++++++++- judgearena/utils.py | 23 ++++- tests/test_local_completion_loading.py | 121 +++++++++++++++++++++++++ 4 files changed, 212 insertions(+), 28 deletions(-) create mode 100644 tests/test_local_completion_loading.py diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index c86d123..0733db2 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -17,6 +17,7 @@ do_inference, download_hf, read_df, + truncate, ) @@ -51,6 +52,18 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): return float(m.group(group_index).strip(" ")) +_PAIR_SCORE_MIN = 0 +_PAIR_SCORE_MAX = 10 + + +def build_pair_score_output_choices() -> list[str]: + return [ + f"score_A: {a}\nscore_B: {b}" + for a in range(_PAIR_SCORE_MIN, _PAIR_SCORE_MAX + 1) + for b in range(_PAIR_SCORE_MIN, _PAIR_SCORE_MAX + 1) + ] + + _COMPLETION_LABEL_SINGLE = "Answer" _COMPLETION_LABEL_MULTI_TURN = "Conversation with User" _EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" @@ -302,27 +315,30 @@ def annotate_battles( prompt_template = ChatPromptTemplate.from_messages( [("system", system_prompt), ("user", user_prompt_template)] ) - - def truncate(s: str, max_len: int | None = None): - if not isinstance(s, str): - return "" - if max_len is not None: - return s[:max_len] - else: - return s - - inputs = prompt_template.batch( - [ + truncated_completion_count = 0 + input_payloads = [] + for user_prompt, completion_A, completion_B in zip( + instructions, completions_A, completions_B, strict=True + ): + truncated_completion_A = truncate(completion_A, max_len=truncate_input_chars) + truncated_completion_B = truncate(completion_B, max_len=truncate_input_chars) + truncated_completion_count += int(truncated_completion_A != completion_A) + truncated_completion_count += int(truncated_completion_B != completion_B) + input_payloads.append( { "user_prompt": user_prompt, - "completion_A": truncate(completion_A, max_len=truncate_input_chars), - "completion_B": truncate(completion_B, max_len=truncate_input_chars), + "completion_A": truncated_completion_A, + "completion_B": truncated_completion_B, } - for user_prompt, completion_A, completion_B in zip( - instructions, completions_A, completions_B, strict=True - ) - ] - ) + ) + if truncated_completion_count: + print( + "Warning: truncated " + f"{truncated_completion_count} judge completions to " + f"{truncate_input_chars} characters before evaluation." + ) + inputs = prompt_template.batch(input_payloads) + print(f"Start LLM judge annotation ({len(inputs)} annotations).") judge_completions = do_inference( chat_model=judge_chat_model, diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 8502201..def7402 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -12,7 +12,11 @@ import pandas as pd -from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts +from judgearena.evaluate import ( + build_pair_score_output_choices, + judge_and_parse_prefs, + resolve_judge_prompts, +) from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions from judgearena.mt_bench.mt_bench_utils import run_mt_bench @@ -30,16 +34,40 @@ def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None ) -> pd.DataFrame | None: - """Try loading pre-existing completions from the dataset. + """Try loading pre-existing completions from the dataset or a local file. Some datasets (e.g. alpaca-eval) ship with completions for well-known models such as ``gpt4_1106_preview``. When ``model`` matches a column in ``model_outputs/{dataset}.csv.zip``, those completions are returned directly so that no model instantiation / generation is needed. + ``model`` may also be a local dataframe path. Local files must contain + ``instruction_index`` and ``output`` columns. + Returns a DataFrame with columns ``completion`` and ``instruction_index``, or ``None`` when no pre-existing completions are found. """ + local_path = Path(model) + if local_path.exists(): + print(f"Loading completions from local path '{local_path}'.") + df_outputs = read_df(local_path) + required_columns = {"instruction_index", "output"} + missing_columns = required_columns.difference(df_outputs.columns) + if missing_columns: + missing_columns_list = ", ".join(sorted(missing_columns)) + raise ValueError( + f"Local completion file '{local_path}' is missing required columns: " + f"{missing_columns_list}." + ) + + df_outputs = df_outputs.loc[:, ["instruction_index", "output"]].rename( + columns={"output": "completion"} + ) + df_outputs.loc[:, "completion"] = df_outputs.loc[:, "completion"].fillna("") + if n_instructions is not None: + df_outputs = df_outputs.head(n_instructions) + return df_outputs + local_path_tables = data_root / "tables" download_hf(name=dataset, local_path=local_path_tables) output_path = local_path_tables / "model_outputs" / f"{dataset}.csv.zip" @@ -337,7 +365,7 @@ def main(args: CliArgs): ) if dataset_completions_A is not None: completions_A = dataset_completions_A.set_index("instruction_index").loc[ - :, "completion" + instructions.index, "completion" ] else: completions_A = cache_function_dataframe( @@ -356,7 +384,7 @@ def main(args: CliArgs): ) if dataset_completions_B is not None: completions_B = dataset_completions_B.set_index("instruction_index").loc[ - :, "completion" + instructions.index, "completion" ] else: completions_B = cache_function_dataframe( @@ -377,12 +405,18 @@ def main(args: CliArgs): print(completions_B.values[0]) print(f"Evaluating completions with judge {args.judge_model}.") + judge_model_kwargs = dict(args.engine_kwargs) + if not args.provide_explanation and args.judge_model.split("/")[0] == "VLLM": + judge_model_kwargs["structured_outputs_choice"] = ( + build_pair_score_output_choices() + ) + judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, max_model_len=args.max_model_len, chat_template=args.chat_template, - **args.engine_kwargs, + **judge_model_kwargs, ) name = f"{args.dataset}-{args.model_A}-{args.model_B}-{args.judge_model}" diff --git a/judgearena/utils.py b/judgearena/utils.py index 4ecd801..4b31ea0 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -202,6 +202,7 @@ def __init__( **vllm_kwargs, ): from vllm import LLM, SamplingParams + from vllm.sampling_params import StructuredOutputsParams self.model_path = model self.max_tokens = max_tokens @@ -230,13 +231,19 @@ def __init__( RuntimeWarning, stacklevel=2, ) + self._sampling_params_kwargs = { + "max_tokens": max_tokens, + "temperature": float(vllm_kwargs.pop("temperature", 0.6)), + "top_p": float(vllm_kwargs.pop("top_p", 0.95)), + } + structured_outputs_choice = vllm_kwargs.pop("structured_outputs_choice", None) + if structured_outputs_choice is not None: + self._sampling_params_kwargs["structured_outputs"] = ( + StructuredOutputsParams(choice=structured_outputs_choice) + ) + self.sampling_params = SamplingParams(**self._sampling_params_kwargs) self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) - self.sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=0.6, - top_p=0.95, - ) # Resolve chat template: # 1. Explicit override always wins → use chat() with that template @@ -262,6 +269,12 @@ def __init__( self._use_generate = False print(f"ChatVLLM: using tokenizer's chat template for '{model}'") + def set_temperature(self, temperature: float) -> None: + from vllm import SamplingParams + + self._sampling_params_kwargs["temperature"] = float(temperature) + self.sampling_params = SamplingParams(**self._sampling_params_kwargs) + def _to_messages(self, input_item) -> list[dict]: """Convert LangChain prompt input to OpenAI-style messages.""" # Map LangChain message types to OpenAI roles diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py new file mode 100644 index 0000000..531796b --- /dev/null +++ b/tests/test_local_completion_loading.py @@ -0,0 +1,121 @@ +import pandas as pd + +import judgearena.evaluate as evaluate +import judgearena.generate_and_evaluate as generate_and_evaluate +from judgearena.generate_and_evaluate import CliArgs +from judgearena.generate_and_evaluate import main as main_generate_and_eval + + +def test_build_pair_score_output_choices_covers_all_integer_pairs(): + choices = evaluate.build_pair_score_output_choices() + + assert len(choices) == 121 + assert len(set(choices)) == 121 + assert "score_A: 0\nscore_B: 0" in choices + assert "score_A: 10\nscore_B: 10" in choices + + +def test_main_aligns_local_reference_by_instruction_index(tmp_path, monkeypatch): + instructions = pd.DataFrame( + {"instruction": ["Instruction B", "Instruction A"]}, + index=pd.Index(["b", "a"], name="instruction_index"), + ) + reference_path = tmp_path / "m-arena-hard-en-reference.csv" + pd.DataFrame( + { + "instruction_index": ["a", "b"], + "output": ["Answer A", "Answer B"], + } + ).to_csv(reference_path, index=False) + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: ( + instructions.head(n_instructions) + if n_instructions is not None + else instructions + ), + ) + monkeypatch.setattr( + generate_and_evaluate, + "cache_function_dataframe", + lambda fun, **_kwargs: fun(), + ) + + captured = {} + + def fake_judge_and_parse_prefs( + *, + judge_chat_model, + instructions, + completions_A, + completions_B, + swap_mode, + provide_explanation, + system_prompt, + user_prompt_template, + truncate_input_chars, + use_tqdm, + ): + captured["instructions"] = instructions + captured["completions_A"] = completions_A + captured["completions_B"] = completions_B + annotations = [{"judge_completion": "score A: 0 score B: 10"}] * len( + instructions + ) + prefs = pd.Series([1.0] * len(instructions)) + return annotations, [], prefs + + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + fake_judge_and_parse_prefs, + ) + + prefs = main_generate_and_eval( + CliArgs( + dataset="m-arena-hard-en", + model_A="Dummy/no answer", + model_B=str(reference_path), + judge_model="Dummy/score A: 0 score B: 10", + n_instructions=2, + result_folder=str(tmp_path / "results"), + ) + ) + + assert captured["instructions"] == ["Instruction B", "Instruction A"] + assert captured["completions_A"] == ["no answer", "no answer"] + assert captured["completions_B"] == ["Answer B", "Answer A"] + assert prefs.tolist() == [1.0, 1.0] + + +def test_annotate_battles_warns_when_judge_completions_are_truncated( + monkeypatch, capsys +): + captured = {} + + def fake_do_inference(*, chat_model, inputs, use_tqdm): + captured["judge_prompt"] = inputs[0].to_messages()[1].content + return ["score_A: 0\nscore_B: 10"] + + monkeypatch.setattr(evaluate, "do_inference", fake_do_inference) + + annotations = evaluate.annotate_battles( + judge_chat_model=object(), + instructions=["Instruction"], + completions_A=["Answer A"], + completions_B=["Answer B"], + truncate_input_chars=3, + ) + + stdout = capsys.readouterr().out + assert ( + "Warning: truncated 2 judge completions to 3 characters before evaluation." + in stdout + ) + assert "Ans" in captured["judge_prompt"] + assert "Answer A" not in captured["judge_prompt"] + assert "Answer B" not in captured["judge_prompt"] + assert annotations[0].completion_A == "Answer A" + assert annotations[0].completion_B == "Answer B" From ef1c92ca192f10ac372bbbc59f388e83b8f83317 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Mon, 6 Apr 2026 23:08:13 +0200 Subject: [PATCH 05/28] remove qwen35 smoke launcher --- .../launch_kislurm_qwen35_smoke.py | 65 ------------------- 1 file changed, 65 deletions(-) delete mode 100644 slurmpilot_scripts/launch_kislurm_qwen35_smoke.py diff --git a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py deleted file mode 100644 index 86704c4..0000000 --- a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py +++ /dev/null @@ -1,65 +0,0 @@ -from pathlib import Path - -from slurmpilot import JobCreationInfo, SlurmPilot, unify - -CLUSTER = "kislurm" -REMOTE_PROJECT_ROOT = Path("/work/dlclarge1/lushtake-hiwi/JudgeArena") -LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent -PYTHON_BINARY = REMOTE_PROJECT_ROOT / ".venv" / "bin" / "python" -ENTRYPOINT = "generate_and_evaluate.py" -SRC_DIR = str(LOCAL_PROJECT_ROOT / "judgearena") - -# Use L40S partitions from the all_dlc / ml_dlc families. -PARTITION_ALL_DLC_L40S = "testdlc2_gpu-l40s" -PARTITION_ML_DLC_L40S = "mldlc2_gpu-l40s" - -# Same weights as `VLLM/Qwen/Qwen3.5-27B-FP8`; repo-id loading fails offline in vLLM -# without a resolved revision — point at the HF hub snapshot dir under `HF_HOME`. -QWEN35_27B_FP8_SNAPSHOT = ( - "/work/dlclarge1/lushtake-hiwi/.cache/huggingface/hub/" - "models--Qwen--Qwen3.5-27B-FP8/snapshots/" - "2e1b21350ce589fcaafbb3c7d7eac526a7aed582" -) -JUDGE_MODEL = f"VLLM//{QWEN35_27B_FP8_SNAPSHOT.lstrip('/')}" - - -def submit_smoke_job(partition: str = PARTITION_ALL_DLC_L40S) -> tuple[str, str, int]: - slurm = SlurmPilot(clusters=[CLUSTER]) - dataset = "alpaca-eval" - jobname = unify("qwen3.5-smoke/judgearena-canonical", method="date") - - job_info = JobCreationInfo( - cluster=CLUSTER, - partition=partition, - jobname=jobname, - entrypoint=ENTRYPOINT, - python_binary=str(PYTHON_BINARY), - python_args={ - "dataset": dataset, - "model_A": "Dummy/no_answer", - "model_B": "Dummy/open_answer", - "judge_model": JUDGE_MODEL, - "n_instructions": 1, - "max_out_tokens_judge": 64, - }, - src_dir=SRC_DIR, - n_cpus=1, - max_runtime_minutes=20, - env={ - "HF_HUB_OFFLINE": "1", - # Ensure Hugging Face uses the shared cache location that - # already contains the Qwen3.5 FP8 checkpoint. - "HF_HOME": "/work/dlclarge1/lushtake-hiwi", - "JUDGEARENA_DATA": "/work/dlclarge1/lushtake-hiwi/judgearena-data", - }, - ) - job_id = slurm.schedule_job(job_info) - print(f"Submitted {dataset}: jobname={job_info.jobname}, job_id={job_id}") - return dataset, job_info.jobname, job_id - - -if __name__ == "__main__": - print(f"Using LOCAL_PROJECT_ROOT={LOCAL_PROJECT_ROOT}") - print(f"Using REMOTE_PROJECT_ROOT={REMOTE_PROJECT_ROOT}") - print(f"Using PYTHON_BINARY={PYTHON_BINARY}") - submit_smoke_job(partition=PARTITION_ALL_DLC_L40S) From 32f2e7e39125bfd4aaceddebac9122bb6fa1904c Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 7 Apr 2026 16:23:24 +0200 Subject: [PATCH 06/28] use json schema structured outputs, tighten vllm range - Switch from choice-based structured outputs to JSON schema constraint - Tighten vllm version range from >=0.17.0,<1.0.0 to >=0.17.0,<0.19.0 --- judgearena/evaluate.py | 21 +++++++++++++++------ judgearena/generate_and_evaluate.py | 6 +++--- judgearena/utils.py | 6 +++--- pyproject.toml | 3 ++- tests/test_local_completion_loading.py | 17 ++++++++++------- 5 files changed, 33 insertions(+), 20 deletions(-) diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 0733db2..f65ee9c 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -56,12 +56,21 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): _PAIR_SCORE_MAX = 10 -def build_pair_score_output_choices() -> list[str]: - return [ - f"score_A: {a}\nscore_B: {b}" - for a in range(_PAIR_SCORE_MIN, _PAIR_SCORE_MAX + 1) - for b in range(_PAIR_SCORE_MIN, _PAIR_SCORE_MAX + 1) - ] +def build_pair_score_json_schema() -> dict: + score_field = { + "type": "integer", + "minimum": _PAIR_SCORE_MIN, + "maximum": _PAIR_SCORE_MAX, + } + return { + "type": "object", + "properties": { + "score_A": score_field, + "score_B": score_field, + }, + "required": ["score_A", "score_B"], + "additionalProperties": False, + } _COMPLETION_LABEL_SINGLE = "Answer" diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index def7402..57a9956 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -13,7 +13,7 @@ import pandas as pd from judgearena.evaluate import ( - build_pair_score_output_choices, + build_pair_score_json_schema, judge_and_parse_prefs, resolve_judge_prompts, ) @@ -407,8 +407,8 @@ def main(args: CliArgs): judge_model_kwargs = dict(args.engine_kwargs) if not args.provide_explanation and args.judge_model.split("/")[0] == "VLLM": - judge_model_kwargs["structured_outputs_choice"] = ( - build_pair_score_output_choices() + judge_model_kwargs["structured_outputs_json"] = ( + build_pair_score_json_schema() ) judge_chat_model = make_model( diff --git a/judgearena/utils.py b/judgearena/utils.py index 4b31ea0..5fdffde 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -236,10 +236,10 @@ def __init__( "temperature": float(vllm_kwargs.pop("temperature", 0.6)), "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } - structured_outputs_choice = vllm_kwargs.pop("structured_outputs_choice", None) - if structured_outputs_choice is not None: + structured_outputs_json = vllm_kwargs.pop("structured_outputs_json", None) + if structured_outputs_json is not None: self._sampling_params_kwargs["structured_outputs"] = ( - StructuredOutputsParams(choice=structured_outputs_choice) + StructuredOutputsParams(json=structured_outputs_json) ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) diff --git a/pyproject.toml b/pyproject.toml index cc1c113..887f124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,5 +82,6 @@ indent-style = "space" [project.optional-dependencies] # vLLM on PyPI pins transformers<5; optional extra matches that so `uv lock` can resolve. -vllm = ["vllm>=0.17.0,<1.0.0", "transformers>=4.56.0,<5.0.0"] +# Tested with vllm 0.18.1; StructuredOutputsParams(json=...) requires >= 0.17. +vllm = ["vllm>=0.17.0,<0.19.0", "transformers>=4.56.0,<5.0.0"] llamacpp = ["llama-cpp-python>=0.3.0"] diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index 531796b..8882906 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -6,13 +6,16 @@ from judgearena.generate_and_evaluate import main as main_generate_and_eval -def test_build_pair_score_output_choices_covers_all_integer_pairs(): - choices = evaluate.build_pair_score_output_choices() - - assert len(choices) == 121 - assert len(set(choices)) == 121 - assert "score_A: 0\nscore_B: 0" in choices - assert "score_A: 10\nscore_B: 10" in choices +def test_build_pair_score_json_schema_covers_valid_range(): + schema = evaluate.build_pair_score_json_schema() + + assert schema["type"] == "object" + assert set(schema["required"]) == {"score_A", "score_B"} + for key in ("score_A", "score_B"): + assert schema["properties"][key]["type"] == "integer" + assert schema["properties"][key]["minimum"] == 0 + assert schema["properties"][key]["maximum"] == 10 + assert schema["additionalProperties"] is False def test_main_aligns_local_reference_by_instruction_index(tmp_path, monkeypatch): From 5f2edf0df4b21263f1787a9bca120c7380c98e59 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 7 Apr 2026 16:32:36 +0200 Subject: [PATCH 07/28] fix formatting --- judgearena/generate_and_evaluate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 57a9956..a8412e2 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -407,9 +407,7 @@ def main(args: CliArgs): judge_model_kwargs = dict(args.engine_kwargs) if not args.provide_explanation and args.judge_model.split("/")[0] == "VLLM": - judge_model_kwargs["structured_outputs_json"] = ( - build_pair_score_json_schema() - ) + judge_model_kwargs["structured_outputs_json"] = build_pair_score_json_schema() judge_chat_model = make_model( model=args.judge_model, From cffb6ddc727a79a490655fd3493a8856f20d873e Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 8 Apr 2026 00:32:35 +0200 Subject: [PATCH 08/28] Fix Qwen3.5 with mt-bench --- judgearena/instruction_dataset/mt_bench.py | 157 +++++-- judgearena/mt_bench/fastchat_compat.py | 39 +- judgearena/mt_bench/mt_bench_utils.py | 74 +++- judgearena/slurm_costs.py | 481 +++++++++++++++++++++ judgearena/utils.py | 29 +- tests/test_mt_bench_downloads.py | 206 +++++++++ tests/test_slurm_costs.py | 77 ++++ 7 files changed, 1020 insertions(+), 43 deletions(-) create mode 100644 judgearena/slurm_costs.py create mode 100644 tests/test_slurm_costs.py diff --git a/judgearena/instruction_dataset/mt_bench.py b/judgearena/instruction_dataset/mt_bench.py index e2a4233..23aa0fe 100644 --- a/judgearena/instruction_dataset/mt_bench.py +++ b/judgearena/instruction_dataset/mt_bench.py @@ -7,12 +7,51 @@ from judgearena.utils import data_root +MT_BENCH_SPACE_ID = "lmsys/mt-bench" +MT_BENCH_QUESTION_PATTERN = "data/mt_bench/question.jsonl" +MT_BENCH_MODEL_ANSWER_DIR = Path("data") / "mt_bench" / "model_answer" FASTCHAT_GPT4_REFERENCE_URL = ( "https://raw.githubusercontent.com/lm-sys/FastChat/main/" "fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl" ) +def _normalize_question_id(question_id: object) -> object: + try: + return int(question_id) + except Exception: + return question_id + + +def _snapshot_mt_bench_files( + *, + local_dir: Path, + allow_patterns: list[str], + expected_path: Path, + description: str, +) -> None: + try: + snapshot_download( + repo_id=MT_BENCH_SPACE_ID, + repo_type="space", + allow_patterns=allow_patterns, + local_dir=local_dir, + force_download=False, + ) + except Exception as e: + raise RuntimeError( + f"Failed to download {description} from HuggingFace space " + f"'{MT_BENCH_SPACE_ID}'. If you're in an offline / restricted-network " + f"environment, pre-download the space snapshot and place the file at " + f"{expected_path}, or set OPENJURY_DATA to point to that directory." + ) from e + if not expected_path.exists(): + raise FileNotFoundError( + f"Could not locate {description} after download. " + f"Expected file at {expected_path}." + ) + + def _download_gpt4_references(local_dir: Path) -> Path | None: reference_dir = local_dir / "reference_answer" reference_dir.mkdir(parents=True, exist_ok=True) @@ -46,34 +85,103 @@ def download_mt_bench(local_dir: Path | None = None) -> tuple[Path, Path | None] question_path = local_dir / "data" / "mt_bench" / "question.jsonl" if not question_path.exists(): - try: - snapshot_download( - repo_id="lmsys/mt-bench", - repo_type="space", - allow_patterns=[ - "data/mt_bench/question.jsonl", - ], - local_dir=local_dir, - force_download=False, - ) - except Exception as e: - raise RuntimeError( - "Failed to download MT-Bench questions from HuggingFace space " - "'lmsys/mt-bench'. If you're in an offline / restricted-network " - "environment, pre-download the space snapshot and place the " - f"questions file at {question_path}, or set OPENJURY_DATA to " - "point to that directory." - ) from e - if not question_path.exists(): - raise FileNotFoundError( - "Could not locate MT-Bench questions after download. " - f"Expected file at {question_path}." + _snapshot_mt_bench_files( + local_dir=local_dir, + allow_patterns=[MT_BENCH_QUESTION_PATTERN], + expected_path=question_path, + description="MT-Bench questions", ) gpt4_reference_path = _download_gpt4_references(local_dir) return question_path, gpt4_reference_path +def download_mt_bench_model_answer( + model_id: str, local_dir: Path | None = None +) -> Path: + """Download a cached MT-Bench baseline answer file if missing.""" + if local_dir is None: + local_dir = data_root / "mt-bench" + answer_path = local_dir / MT_BENCH_MODEL_ANSWER_DIR / f"{model_id}.jsonl" + if answer_path.exists(): + return answer_path + answer_path.parent.mkdir(parents=True, exist_ok=True) + allow_pattern = (MT_BENCH_MODEL_ANSWER_DIR / f"{model_id}.jsonl").as_posix() + _snapshot_mt_bench_files( + local_dir=local_dir, + allow_patterns=[allow_pattern], + expected_path=answer_path, + description=f"MT-Bench model answers for '{model_id}'", + ) + return answer_path + + +def _extract_answer_turns(record: dict, source_name: str) -> tuple[object, list[str]]: + question_id = record.get("question_id", record.get("id")) + if question_id is None: + raise ValueError( + f"MT-Bench answer record from {source_name} is missing question_id/id." + ) + choices = record.get("choices") + if not (isinstance(choices, list) and choices): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} is " + "missing a non-empty choices list." + ) + first_choice = choices[0] + if not isinstance(first_choice, dict): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} has " + "a malformed first choice entry." + ) + turns = first_choice.get("turns") + if not isinstance(turns, list): + raise ValueError( + f"MT-Bench answer record for question {question_id} in {source_name} is " + "missing a turns list." + ) + return _normalize_question_id(question_id), turns + + +def load_mt_bench_model_answers( + model: str, + n_instructions: int | None = None, + local_dir: Path | None = None, +) -> pd.DataFrame | None: + """Load pre-generated MT-Bench answers from a local file or cached model id.""" + local_path = Path(model) + if local_path.exists(): + answer_path = local_path + elif "/" not in model: + answer_path = download_mt_bench_model_answer( + model_id=model, local_dir=local_dir + ) + else: + return None + + answer_records = pd.read_json(answer_path, lines=True).to_dict(orient="records") + rows = [] + for rec in answer_records: + question_id, turns = _extract_answer_turns(rec, str(answer_path)) + rows.append( + { + "instruction_index": question_id, + "completion_turn_1": turns[0] if len(turns) > 0 else "", + "completion_turn_2": turns[1] if len(turns) > 1 else "", + } + ) + + df_answers = pd.DataFrame(rows) + if df_answers.empty: + raise ValueError( + f"MT-Bench answer file {answer_path} did not contain any rows." + ) + df_answers.sort_values("instruction_index", inplace=True) + if n_instructions is not None: + df_answers = df_answers.head(n_instructions) + return df_answers + + def load_mt_bench() -> pd.DataFrame: """Load MT-Bench questions and reference answers. @@ -126,10 +234,7 @@ def load_mt_bench() -> pd.DataFrame: raise ValueError( f"MT-Bench question record missing question_id/id: keys={list(rec.keys())}" ) - try: - qid = int(qid_raw) - except Exception: - qid = qid_raw + qid = _normalize_question_id(qid_raw) category = rec.get("category") turns = rec.get("turns") diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 3b0e7ec..2920a57 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from dataclasses import dataclass +from dataclasses import dataclass, replace from pathlib import Path from typing import Any, Literal @@ -52,6 +52,16 @@ class FastChatPairwisePrompt: _USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" _USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" +_BRACKETED_VERDICT_INSTRUCTION = ( + "After providing your explanation, output your final verdict by strictly " + 'following this format: "[[A]]" if assistant A is better, "[[B]]" if ' + 'assistant B is better, and "[[C]]" for a tie.' +) +_PLAIN_VERDICT_INSTRUCTION = ( + 'Output only one final verdict token: "A" if assistant A is better, "B" ' + 'if assistant B is better, and "C" for a tie.' +) + def _load_prompt_text(filename: str) -> str: path = _PROMPTS_DIR / filename @@ -62,6 +72,20 @@ def _render_prompt_text(filename: str, **kwargs: str) -> str: return _load_prompt_text(filename).format(**kwargs) +def _structured_verdict_prompt( + prompt: FastChatPairwisePrompt, +) -> FastChatPairwisePrompt: + if _BRACKETED_VERDICT_INSTRUCTION not in prompt.system_prompt: + return prompt + return replace( + prompt, + system_prompt=prompt.system_prompt.replace( + _BRACKETED_VERDICT_INSTRUCTION, + _PLAIN_VERDICT_INSTRUCTION, + ), + ) + + def _build_system_prompt( *, user_subject: str, @@ -180,11 +204,12 @@ def _load_pairwise_prompt( def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: - if "[[A]]" in judgment: + stripped = judgment.strip() + if "[[A]]" in stripped or stripped == "A": return "A" - if "[[B]]" in judgment: + if "[[B]]" in stripped or stripped == "B": return "B" - if "[[C]]" in judgment: + if "[[C]]" in stripped or stripped == "C": return "tie" return "error" @@ -267,6 +292,7 @@ def _infer_by_prompt_groups( items: list[dict[str, Any]], use_tqdm: bool, swap_answers: bool, + constrained_plain_verdict: bool, ) -> list[str]: """Run judge inference, grouping by prompt variant for batching.""" grouped_indices = _group_indices_by_prompt(items) @@ -274,6 +300,8 @@ def _infer_by_prompt_groups( judgments: list[str] = [""] * len(items) for _prompt_name, idxs in grouped_indices.items(): prompt: FastChatPairwisePrompt = items[idxs[0]]["prompt"] + if constrained_plain_verdict: + prompt = _structured_verdict_prompt(prompt) prompt_template = ChatPromptTemplate.from_messages( [("system", prompt.system_prompt), ("user", prompt.user_prompt_template)] ) @@ -434,6 +462,7 @@ def judge_mt_bench_pairwise_fastchat( swap_mode: str, truncate_input_chars: int | None, use_tqdm: bool, + constrained_plain_verdict: bool = False, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: """Pairwise MT-Bench judging compatible with FastChat's `[[A]]/[[B]]/[[C]]` format.""" assert turns_mode in ("both", "single", "multi") @@ -456,6 +485,7 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=False, + constrained_plain_verdict=constrained_plain_verdict, ) g2_judgments: list[str] | None = None @@ -465,6 +495,7 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=True, + constrained_plain_verdict=constrained_plain_verdict, ) annotations: list[dict[str, Any]] = [] diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index b274f26..a012cf5 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -18,6 +18,7 @@ from judgearena.eval_utils import _compute_grouped_stats, print_results from judgearena.generate import generate_multiturn from judgearena.instruction_dataset import load_instructions +from judgearena.instruction_dataset.mt_bench import load_mt_bench_model_answers from judgearena.mt_bench.fastchat_compat import ( FASTCHAT_TEMPERATURE_CONFIG, judge_mt_bench_pairwise_fastchat, @@ -29,6 +30,25 @@ from judgearena.config import CliArgs +# Use distinct first tokens for constrained decoding. The shared `[[` prefix +# caused the MT-Bench judge to collapse to `[[A]]` on every comparison. +_MIN_MT_BENCH_JUDGE_TOKENS = 1024 + + +def _align_mt_bench_completions( + *, questions_df: pd.DataFrame, completions: pd.DataFrame, model_name: str +) -> pd.DataFrame: + indexed = completions.set_index("instruction_index") + missing_ids = questions_df.index.difference(indexed.index) + if not missing_ids.empty: + missing_ids_preview = ", ".join(str(x) for x in missing_ids[:5]) + raise ValueError( + f"MT-Bench completions for '{model_name}' are missing " + f"{len(missing_ids)} question(s). First missing ids: {missing_ids_preview}." + ) + return indexed.loc[questions_df.index] + + def _generate_mt_bench_completions( args: CliArgs, questions_df: pd.DataFrame, @@ -46,19 +66,33 @@ def _run_generation(model_name: str) -> pd.DataFrame: max_model_len=args.max_model_len, chat_template=args.chat_template, temperature_config=FASTCHAT_TEMPERATURE_CONFIG, + **args.engine_kwargs, ) - completions_a = cache_function_dataframe( - lambda: _run_generation(args.model_A), - ignore_cache=ignore_cache, - cache_name=f"{cache_prefix}_{args.model_A}_{args.n_instructions}", - ).set_index("instruction_index") + def _load_or_generate(model_name: str) -> pd.DataFrame: + loaded_answers = load_mt_bench_model_answers( + model_name, n_instructions=args.n_instructions + ) + if loaded_answers is not None: + print(f"Using pre-generated MT-Bench answers for '{model_name}'.") + return _align_mt_bench_completions( + questions_df=questions_df, + completions=loaded_answers, + model_name=model_name, + ) + generated_answers = cache_function_dataframe( + lambda: _run_generation(model_name), + ignore_cache=ignore_cache, + cache_name=f"{cache_prefix}_{model_name}_{args.n_instructions}", + ) + return _align_mt_bench_completions( + questions_df=questions_df, + completions=generated_answers, + model_name=model_name, + ) - completions_b = cache_function_dataframe( - lambda: _run_generation(args.model_B), - ignore_cache=ignore_cache, - cache_name=f"{cache_prefix}_{args.model_B}_{args.n_instructions}", - ).set_index("instruction_index") + completions_a = _load_or_generate(args.model_A) + completions_b = _load_or_generate(args.model_B) return completions_a, completions_b @@ -97,6 +131,7 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, + constrained_plain_verdict: bool, ) -> pd.Series: prefs, annotations, combined_metadata, num_inconsistent = ( judge_mt_bench_pairwise_fastchat( @@ -111,6 +146,7 @@ def _run_mt_bench_fastchat( swap_mode=args.swap_mode, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, + constrained_plain_verdict=constrained_plain_verdict, ) ) @@ -140,6 +176,19 @@ def _run_mt_bench_fastchat( def run_mt_bench(args: CliArgs, ignore_cache: bool): """MT-Bench pipeline with FastChat-compatible pairwise judging.""" + if args.swap_mode != "both": + print( + "MT-Bench requires swap_mode='both' to match FastChat and correct " + f"for position bias; overriding requested swap_mode='{args.swap_mode}'." + ) + args.swap_mode = "both" + if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: + print( + "MT-Bench judge prompts require room for explanation plus verdict; " + f"overriding max_out_tokens_judge from {args.max_out_tokens_judge} " + f"to {_MIN_MT_BENCH_JUDGE_TOKENS}." + ) + args.max_out_tokens_judge = _MIN_MT_BENCH_JUDGE_TOKENS questions_df = load_instructions("mt-bench", n_instructions=args.n_instructions) print( f"Generating multi-turn completions for MT-Bench with {args.model_A} and {args.model_B}." @@ -149,12 +198,16 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): questions_df=questions_df, ignore_cache=ignore_cache, ) + judge_model_kwargs = dict(args.engine_kwargs) + judge_model_kwargs["disable_thinking"] = True + judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, temperature=0.0, max_model_len=args.max_model_len, chat_template=args.chat_template, + **judge_model_kwargs, ) return _run_mt_bench_fastchat( args=args, @@ -162,4 +215,5 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, + constrained_plain_verdict=False, ) diff --git a/judgearena/slurm_costs.py b/judgearena/slurm_costs.py new file mode 100644 index 0000000..2f78726 --- /dev/null +++ b/judgearena/slurm_costs.py @@ -0,0 +1,481 @@ +from __future__ import annotations + +import argparse +import json +import subprocess +from collections.abc import Iterable +from dataclasses import asdict, dataclass +from pathlib import Path + +SACCT_FIELDS = ( + "JobID,JobName%100,Partition,Account,State,ElapsedRaw,Elapsed," + "AllocCPUS,AllocNodes,AllocTRES%100,ReqTRES%100" +) +RATE_METRIC_CHOICES = ( + "wall_hours", + "cpu_hours", + "gpu_hours", + "billing_hours", + "node_hours", +) + + +@dataclass(frozen=True) +class JobSource: + job_id: int + label: str + + +@dataclass(frozen=True) +class SacctAllocation: + allocation_id: str + root_job_id: int + job_name: str + partition: str + account: str + state: str + elapsed_seconds: int + elapsed: str + alloc_cpus: float + alloc_nodes: float + alloc_tres: dict[str, str] + req_tres: dict[str, str] + + +@dataclass(frozen=True) +class JobCostSummary: + job_id: int + label: str + partition: str + account: str + states: list[str] + allocation_count: int + wall_hours: float + cpu_hours: float + gpu_hours: float + billing_hours: float + node_hours: float + estimated_cost: float | None = None + + +def parse_tres_map(tres_spec: str) -> dict[str, str]: + values: dict[str, str] = {} + if not tres_spec: + return values + for raw_entry in tres_spec.split(","): + entry = raw_entry.strip() + if not entry or "=" not in entry: + continue + key, value = entry.split("=", 1) + values[key.strip()] = value.strip() + return values + + +def parse_elapsed_seconds(elapsed: str) -> int: + if not elapsed: + return 0 + n_days = 0 + time_part = elapsed + if "-" in elapsed: + days_part, time_part = elapsed.split("-", 1) + n_days = int(days_part) + hours_str, minutes_str, seconds_str = time_part.split(":") + return ( + n_days * 86400 + + int(hours_str) * 3600 + + int(minutes_str) * 60 + + int(seconds_str) + ) + + +def parse_sacct_allocations(sacct_output: str) -> list[SacctAllocation]: + allocations: list[SacctAllocation] = [] + for raw_line in sacct_output.splitlines(): + line = raw_line.strip() + if not line: + continue + parts = line.split("|") + if len(parts) != 11: + raise ValueError(f"Unexpected sacct row with {len(parts)} fields: {line}") + allocation_id = parts[0].strip() + root_job_text = allocation_id.split("_", 1)[0] + allocations.append( + SacctAllocation( + allocation_id=allocation_id, + root_job_id=int(root_job_text), + job_name=parts[1].strip(), + partition=parts[2].strip(), + account=parts[3].strip(), + state=parts[4].strip(), + elapsed_seconds=int(parts[5] or "0"), + elapsed=parts[6].strip(), + alloc_cpus=float(parts[7] or "0"), + alloc_nodes=float(parts[8] or "0"), + alloc_tres=parse_tres_map(parts[9]), + req_tres=parse_tres_map(parts[10]), + ) + ) + return allocations + + +def query_sacct_allocations(job_ids: Iterable[int]) -> list[SacctAllocation]: + unique_job_ids = [ + str(job_id) for job_id in dict.fromkeys(int(job_id) for job_id in job_ids) + ] + if not unique_job_ids: + return [] + try: + result = subprocess.run( + [ + "sacct", + "-X", + "--allocations", + "--parsable2", + "--noheader", + f"--format={SACCT_FIELDS}", + f"--jobs={','.join(unique_job_ids)}", + ], + check=False, + capture_output=True, + text=True, + ) + except FileNotFoundError as exc: + raise RuntimeError( + "Could not find `sacct`; run this on a machine with Slurm." + ) from exc + if result.returncode != 0: + message = ( + result.stderr.strip() or result.stdout.strip() or "unknown sacct error" + ) + raise RuntimeError(f"sacct failed: {message}") + return parse_sacct_allocations(result.stdout) + + +def load_job_source_from_path(job_path: str | Path) -> JobSource: + path = Path(job_path) + job_dir = path.parent if path.name == "jobid.json" else path + jobid_path = job_dir / "jobid.json" + if not jobid_path.is_file(): + raise FileNotFoundError(f"Missing jobid.json in {job_dir}") + job_id = int(json.loads(jobid_path.read_text())["jobid"]) + metadata_path = job_dir / "metadata.json" + if metadata_path.is_file(): + metadata = json.loads(metadata_path.read_text()) + label = str(metadata.get("jobname") or job_dir.name) + else: + label = job_dir.name + return JobSource(job_id=job_id, label=label) + + +def resolve_job_sources( + *, + job_ids: Iterable[int] | None = None, + job_paths: Iterable[str | Path] | None = None, +) -> list[JobSource]: + sources: dict[int, JobSource] = {} + ordered_ids: list[int] = [] + + for job_id in job_ids or []: + normalized_job_id = int(job_id) + if normalized_job_id in sources: + continue + sources[normalized_job_id] = JobSource( + job_id=normalized_job_id, + label=str(normalized_job_id), + ) + ordered_ids.append(normalized_job_id) + + for job_path in job_paths or []: + source = load_job_source_from_path(job_path) + if source.job_id not in sources: + ordered_ids.append(source.job_id) + sources[source.job_id] = source + + return [sources[job_id] for job_id in ordered_ids] + + +def _tres_quantity(tres_map: dict[str, str], key: str) -> float: + raw_value = tres_map.get(key) + if raw_value is None: + return 0.0 + numeric_chars: list[str] = [] + for char in raw_value: + if char.isdigit() or char in {".", "-"}: + numeric_chars.append(char) + continue + break + numeric_text = "".join(numeric_chars) + return float(numeric_text) if numeric_text else 0.0 + + +def summarize_job_costs( + sources: list[JobSource], + allocations: list[SacctAllocation], + *, + rate_metric: str | None = None, + hourly_rate: float | None = None, +) -> list[JobCostSummary]: + allocations_by_job_id: dict[int, list[SacctAllocation]] = { + source.job_id: [] for source in sources + } + for allocation in allocations: + if allocation.root_job_id in allocations_by_job_id: + allocations_by_job_id[allocation.root_job_id].append(allocation) + + missing_job_ids = [ + str(source.job_id) + for source in sources + if not allocations_by_job_id[source.job_id] + ] + if missing_job_ids: + raise RuntimeError( + "No sacct allocation rows returned for job IDs: " + + ", ".join(missing_job_ids) + ) + + summaries: list[JobCostSummary] = [] + for source in sources: + job_allocations = allocations_by_job_id[source.job_id] + wall_hours = sum(row.elapsed_seconds for row in job_allocations) / 3600.0 + cpu_hours = ( + sum(row.elapsed_seconds * row.alloc_cpus for row in job_allocations) + / 3600.0 + ) + gpu_hours = ( + sum( + row.elapsed_seconds * _tres_quantity(row.alloc_tres, "gres/gpu") + for row in job_allocations + ) + / 3600.0 + ) + billing_hours = ( + sum( + row.elapsed_seconds * _tres_quantity(row.alloc_tres, "billing") + for row in job_allocations + ) + / 3600.0 + ) + node_hours = ( + sum(row.elapsed_seconds * row.alloc_nodes for row in job_allocations) + / 3600.0 + ) + metric_value = ( + _summary_metric_value( + wall_hours=wall_hours, + cpu_hours=cpu_hours, + gpu_hours=gpu_hours, + billing_hours=billing_hours, + node_hours=node_hours, + rate_metric=rate_metric, + ) + if hourly_rate is not None + else None + ) + summaries.append( + JobCostSummary( + job_id=source.job_id, + label=source.label, + partition=",".join( + sorted({row.partition for row in job_allocations if row.partition}) + ), + account=",".join( + sorted({row.account for row in job_allocations if row.account}) + ), + states=sorted({row.state for row in job_allocations if row.state}), + allocation_count=len(job_allocations), + wall_hours=wall_hours, + cpu_hours=cpu_hours, + gpu_hours=gpu_hours, + billing_hours=billing_hours, + node_hours=node_hours, + estimated_cost=( + metric_value * hourly_rate + if metric_value is not None and hourly_rate is not None + else None + ), + ) + ) + return summaries + + +def _summary_metric_value( + *, + wall_hours: float, + cpu_hours: float, + gpu_hours: float, + billing_hours: float, + node_hours: float, + rate_metric: str | None, +) -> float: + metrics = { + "wall_hours": wall_hours, + "cpu_hours": cpu_hours, + "gpu_hours": gpu_hours, + "billing_hours": billing_hours, + "node_hours": node_hours, + } + if rate_metric is None: + raise ValueError("rate_metric must be set when hourly_rate is provided") + return metrics[rate_metric] + + +def total_summary( + summaries: list[JobCostSummary], *, hourly_rate: float | None = None +) -> JobCostSummary: + return JobCostSummary( + job_id=0, + label="TOTAL", + partition=",".join( + sorted({summary.partition for summary in summaries if summary.partition}) + ), + account=",".join( + sorted({summary.account for summary in summaries if summary.account}) + ), + states=sorted({state for summary in summaries for state in summary.states}), + allocation_count=sum(summary.allocation_count for summary in summaries), + wall_hours=sum(summary.wall_hours for summary in summaries), + cpu_hours=sum(summary.cpu_hours for summary in summaries), + gpu_hours=sum(summary.gpu_hours for summary in summaries), + billing_hours=sum(summary.billing_hours for summary in summaries), + node_hours=sum(summary.node_hours for summary in summaries), + estimated_cost=( + sum(summary.estimated_cost or 0.0 for summary in summaries) + if hourly_rate is not None + else None + ), + ) + + +def _format_float(value: float) -> str: + return f"{value:.3f}" + + +def _format_cost(value: float, currency: str) -> str: + return f"{currency} {value:.2f}" + + +def _tabular_rows( + summaries: list[JobCostSummary], *, currency: str, include_cost: bool +) -> list[dict[str, str]]: + rows: list[dict[str, str]] = [] + for summary in summaries: + row = { + "job": summary.label, + "job_id": str(summary.job_id), + "tasks": str(summary.allocation_count), + "state": ",".join(summary.states), + "gpu_h": _format_float(summary.gpu_hours), + "billing_h": _format_float(summary.billing_hours), + "cpu_h": _format_float(summary.cpu_hours), + "wall_h": _format_float(summary.wall_hours), + } + if include_cost and summary.estimated_cost is not None: + row["cost"] = _format_cost(summary.estimated_cost, currency) + rows.append(row) + return rows + + +def render_table(rows: list[dict[str, str]]) -> str: + if not rows: + return "" + headers = list(rows[0].keys()) + widths = {header: len(header) for header in headers} + for row in rows: + for header in headers: + widths[header] = max(widths[header], len(row[header])) + header_line = " ".join(f"{header:<{widths[header]}}" for header in headers) + separator_line = " ".join("-" * widths[header] for header in headers) + row_lines = [ + " ".join(f"{row[header]:<{widths[header]}}" for header in headers) + for row in rows + ] + return "\n".join([header_line, separator_line, *row_lines]) + + +def _summary_to_dict(summary: JobCostSummary) -> dict[str, object]: + return asdict(summary) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m judgearena.slurm_costs", + description="Summarize Slurm job usage and simple cost estimates.", + ) + parser.add_argument( + "--job-id", + action="append", + type=int, + default=[], + help="Root Slurm job ID to summarize. Repeatable.", + ) + parser.add_argument( + "--job-path", + action="append", + default=[], + help="Path to a slurmpilot job directory or its jobid.json. Repeatable.", + ) + parser.add_argument( + "--rate-metric", + choices=RATE_METRIC_CHOICES, + default="gpu_hours", + help="Metric used for the optional hourly rate conversion.", + ) + parser.add_argument( + "--hourly-rate", + type=float, + default=None, + help="Optional hourly rate applied to --rate-metric.", + ) + parser.add_argument( + "--currency", + default="EUR", + help="Currency label for the optional cost estimate.", + ) + parser.add_argument( + "--json", + action="store_true", + help="Print JSON instead of a text table.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + args = build_parser().parse_args(argv) + sources = resolve_job_sources(job_ids=args.job_id, job_paths=args.job_path) + if not sources: + raise SystemExit("Provide at least one --job-id or --job-path.") + allocations = query_sacct_allocations(source.job_id for source in sources) + summaries = summarize_job_costs( + sources, + allocations, + rate_metric=args.rate_metric, + hourly_rate=args.hourly_rate, + ) + total = total_summary(summaries, hourly_rate=args.hourly_rate) + if args.json: + payload = { + "jobs": [_summary_to_dict(summary) for summary in summaries], + "total": _summary_to_dict(total), + "rate_metric": args.rate_metric if args.hourly_rate is not None else None, + "hourly_rate": args.hourly_rate, + "currency": args.currency if args.hourly_rate is not None else None, + } + print(json.dumps(payload, indent=2, sort_keys=True)) + return 0 + + table_rows = _tabular_rows( + [*summaries, total], + currency=args.currency, + include_cost=args.hourly_rate is not None, + ) + print(render_table(table_rows)) + if args.hourly_rate is None: + print( + "\nNo hourly rate was provided. Pass --hourly-rate together with " + "--rate-metric to convert one of the reported hour metrics into money." + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/judgearena/utils.py b/judgearena/utils.py index 5fdffde..bb7aaa5 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -179,6 +179,15 @@ async def ainvoke(self, input, **invoke_kwargs): return self.message +_DISABLE_THINKING_PREFIX = "{%- set enable_thinking = false %}\n" + + +def _disable_thinking_chat_template(template: str) -> str: + if "{%- set enable_thinking = false" in template: + return template + return _DISABLE_THINKING_PREFIX + template + + class ChatVLLM: """VLLM wrapper that auto-detects whether to use chat() or generate(). @@ -206,6 +215,7 @@ def __init__( self.model_path = model self.max_tokens = max_tokens + disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) # Cap max_model_len to the model's max_position_embeddings so that # vLLM doesn't reject an overly large context window. @@ -250,7 +260,11 @@ def __init__( # 2. If tokenizer has one, use it → use chat() (pass None to vLLM) # 3. No template found → fall back to generate() for base models if chat_template: - self.chat_template = chat_template + self.chat_template = ( + _disable_thinking_chat_template(chat_template) + if disable_thinking + else chat_template + ) self._use_generate = False print(f"ChatVLLM: using explicit chat template for '{model}'") else: @@ -265,9 +279,18 @@ def __init__( self.chat_template = None self._use_generate = True else: - self.chat_template = None # let vLLM use the tokenizer's own + self.chat_template = ( + _disable_thinking_chat_template(tokenizer.chat_template) + if disable_thinking + else None + ) self._use_generate = False - print(f"ChatVLLM: using tokenizer's chat template for '{model}'") + if disable_thinking: + print( + f"ChatVLLM: using tokenizer chat template with thinking disabled for '{model}'" + ) + else: + print(f"ChatVLLM: using tokenizer's chat template for '{model}'") def set_temperature(self, temperature: float) -> None: from vllm import SamplingParams diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 75851a8..a31626f 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -1,4 +1,10 @@ +from types import SimpleNamespace + +import pandas as pd + import judgearena.instruction_dataset.mt_bench as mt_bench +import judgearena.mt_bench.fastchat_compat as fastchat_compat +import judgearena.mt_bench.mt_bench_utils as mt_bench_utils import judgearena.utils as utils @@ -62,3 +68,203 @@ def _contexts_snapshot_stub(**_kwargs): ] assert calls["contexts"] == 1 assert calls["mt_bench"] == 1 + + +def test_load_mt_bench_model_answers_reads_cached_baseline_file(tmp_path): + answer_path = tmp_path / "data" / "mt_bench" / "model_answer" / "gpt-4.jsonl" + answer_path.parent.mkdir(parents=True, exist_ok=True) + answer_path.write_text( + '{"question_id": 2, "choices": [{"turns": ["A2", "B2"]}]}\n' + '{"question_id": 1, "choices": [{"turns": ["A1"]}]}\n' + ) + + df_answers = mt_bench.load_mt_bench_model_answers("gpt-4", local_dir=tmp_path) + + assert df_answers.to_dict(orient="records") == [ + { + "instruction_index": 1, + "completion_turn_1": "A1", + "completion_turn_2": "", + }, + { + "instruction_index": 2, + "completion_turn_1": "A2", + "completion_turn_2": "B2", + }, + ] + + +def test_generate_mt_bench_completions_uses_pregenerated_baseline(monkeypatch): + questions_df = pd.DataFrame( + {"turn_1": ["Q1", "Q2"], "turn_2": ["Q1b", "Q2b"]}, + index=pd.Index([1, 2], name="instruction_index"), + ) + generated_models = [] + generation_kwargs = [] + + monkeypatch.setattr( + mt_bench_utils, "cache_function_dataframe", lambda fun, **_kwargs: fun() + ) + + def fake_generate_multiturn( + *, + questions, + model, + truncate_input_chars, + max_tokens, + use_tqdm, + max_model_len, + chat_template, + temperature_config, + **engine_kwargs, + ): + generated_models.append(model) + generation_kwargs.append(engine_kwargs) + return pd.DataFrame( + { + "instruction_index": [1, 2], + "completion_turn_1": ["Gen A1", "Gen A2"], + "completion_turn_2": ["Gen B1", "Gen B2"], + } + ) + + monkeypatch.setattr(mt_bench_utils, "generate_multiturn", fake_generate_multiturn) + monkeypatch.setattr( + mt_bench_utils, + "load_mt_bench_model_answers", + lambda model, n_instructions=None: ( + pd.DataFrame( + { + "instruction_index": [2, 1], + "completion_turn_1": ["Base A2", "Base A1"], + "completion_turn_2": ["Base B2", "Base B1"], + } + ) + if model == "gpt-4" + else None + ), + ) + + args = SimpleNamespace( + model_A="VLLM/example/model-a", + model_B="gpt-4", + n_instructions=2, + truncate_all_input_chars=8192, + max_out_tokens_models=1024, + use_tqdm=False, + max_model_len=16384, + chat_template=None, + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + completions_a, completions_b = mt_bench_utils._generate_mt_bench_completions( + args=args, + questions_df=questions_df, + ignore_cache=False, + ) + + assert generated_models == ["VLLM/example/model-a"] + assert generation_kwargs == [ + {"gpu_memory_utilization": 0.7, "language_model_only": True} + ] + assert completions_a.loc[1, "completion_turn_1"] == "Gen A1" + assert completions_b.loc[1, "completion_turn_1"] == "Base A1" + assert completions_b.loc[2, "completion_turn_2"] == "Base B2" + + +def test_parse_fastchat_verdict_accepts_plain_structured_labels(): + assert fastchat_compat._parse_fastchat_verdict("A") == "A" + assert fastchat_compat._parse_fastchat_verdict("B") == "B" + assert fastchat_compat._parse_fastchat_verdict("C") == "tie" + + +def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + + def fake_make_model( + *, model, max_tokens, temperature, max_model_len, chat_template, **kwargs + ): + captured["make_model"] = { + "model": model, + "max_tokens": max_tokens, + "temperature": temperature, + "max_model_len": max_model_len, + "chat_template": chat_template, + "kwargs": kwargs, + } + return object() + + monkeypatch.setattr(mt_bench_utils, "make_model", fake_make_model) + + def fake_run_mt_bench_fastchat(**kwargs): + captured["run_mt_bench_fastchat"] = kwargs + return pd.Series( + kwargs["questions_df"].index.to_list(), + dtype=float, + ) + + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_fastchat", + fake_run_mt_bench_fastchat, + ) + + args = SimpleNamespace( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + truncate_all_input_chars=8192, + max_out_tokens_models=1024, + max_out_tokens_judge=256, + use_tqdm=False, + max_model_len=16384, + chat_template=None, + provide_explanation=False, + swap_mode="fixed", + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert args.swap_mode == "both" + assert args.max_out_tokens_judge == 1024 + assert captured["make_model"]["max_tokens"] == 1024 + assert captured["make_model"]["kwargs"] == { + "disable_thinking": True, + "gpu_memory_utilization": 0.7, + "language_model_only": True, + } + assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" + assert captured["run_mt_bench_fastchat"]["constrained_plain_verdict"] is False diff --git a/tests/test_slurm_costs.py b/tests/test_slurm_costs.py new file mode 100644 index 0000000..94f173d --- /dev/null +++ b/tests/test_slurm_costs.py @@ -0,0 +1,77 @@ +import json + +import pytest + +from judgearena.slurm_costs import ( + JobSource, + _tres_quantity, + load_job_source_from_path, + parse_elapsed_seconds, + parse_sacct_allocations, + parse_tres_map, + resolve_job_sources, + summarize_job_costs, + total_summary, +) + + +def test_parse_tres_map_preserves_entries_and_extracts_numeric_quantities(): + tres_map = parse_tres_map("billing=2,cpu=2,gres/gpu=1,mem=125G,node=1") + + assert tres_map["mem"] == "125G" + assert _tres_quantity(tres_map, "billing") == 2.0 + assert _tres_quantity(tres_map, "gres/gpu") == 1.0 + + +def test_parse_elapsed_seconds_supports_day_prefix(): + assert parse_elapsed_seconds("1-02:03:04") == 93784 + + +def test_load_job_source_from_path_uses_metadata_jobname(tmp_path): + job_dir = tmp_path / "bench" / "alpaca-eval-2026-04-06-16-25-10" + job_dir.mkdir(parents=True) + (job_dir / "jobid.json").write_text(json.dumps({"jobid": 28707665})) + (job_dir / "metadata.json").write_text( + json.dumps({"jobname": "bench/alpaca-eval-2026-04-06-16-25-10"}) + ) + + source = load_job_source_from_path(job_dir) + resolved = resolve_job_sources(job_ids=[28707665], job_paths=[job_dir]) + + assert source == JobSource( + job_id=28707665, + label="bench/alpaca-eval-2026-04-06-16-25-10", + ) + assert resolved == [source] + + +def test_summarize_job_costs_aggregates_job_arrays_and_rate_conversion(): + sacct_output = "\n".join( + [ + "28707665_0|bench/alpaca-eval|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|60|00:01:00|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", + "28707665_1|bench/alpaca-eval|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|90|00:01:30|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", + "28708344_0|bench/arena-hard|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|120|00:02:00|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", + ] + ) + allocations = parse_sacct_allocations(sacct_output) + sources = [ + JobSource(job_id=28707665, label="bench/alpaca-eval"), + JobSource(job_id=28708344, label="bench/arena-hard"), + ] + + summaries = summarize_job_costs( + sources, + allocations, + rate_metric="gpu_hours", + hourly_rate=3.5, + ) + total = total_summary(summaries, hourly_rate=3.5) + + assert summaries[0].allocation_count == 2 + assert summaries[0].gpu_hours == pytest.approx(150 / 3600) + assert summaries[0].billing_hours == pytest.approx(300 / 3600) + assert summaries[0].estimated_cost == pytest.approx((150 / 3600) * 3.5) + assert summaries[1].gpu_hours == pytest.approx(120 / 3600) + assert total.gpu_hours == pytest.approx(270 / 3600) + assert total.cpu_hours == pytest.approx(540 / 3600) + assert total.estimated_cost == pytest.approx((270 / 3600) * 3.5) From ac243aa64c71c65234cde3a07be065abe33489f6 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Mon, 13 Apr 2026 15:25:12 +0200 Subject: [PATCH 09/28] use latest vllm with thinking tokens limits and thinking field in the output --- judgearena/evaluate.py | 28 +++++- judgearena/generate_and_evaluate.py | 5 ++ judgearena/mt_bench/fastchat_compat.py | 32 ++++++- judgearena/mt_bench/mt_bench_utils.py | 24 +++++- judgearena/prompts/mt_bench/system-base.txt | 2 +- judgearena/prompts/prompt.txt | 11 ++- judgearena/utils.py | 93 +++++++++++++++----- pyproject.toml | 4 +- tests/test_chat_vllm.py | 96 +++++++++++++++++++++ tests/test_local_completion_loading.py | 68 ++++++++++++++- tests/test_mt_bench_downloads.py | 57 +++++++++++- tests/test_regexp.py | 17 ++++ 12 files changed, 400 insertions(+), 37 deletions(-) create mode 100644 tests/test_chat_vllm.py diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index f65ee9c..4c1b9f0 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -16,6 +16,7 @@ data_root, do_inference, download_hf, + extract_json_object, read_df, truncate, ) @@ -32,6 +33,13 @@ def preference_from_scores(self, score_a: float, score_b: float) -> float: ) def parse_model_raw(self, judge_completion: str) -> float | None: + json_payload = extract_json_object(judge_completion) + if json_payload is not None: + score_a = self._coerce_score(json_payload.get("score_A")) + score_b = self._coerce_score(json_payload.get("score_B")) + if score_a is not None and score_b is not None: + return float(self.preference_from_scores(score_a, score_b)) + # lower case to avoid confusion, e.g. when "a" is used instead of "A" score_a = self.get_regexp_match( judge_completion.lower(), r'score.*?a[": *\n]*(-?\d+)' @@ -44,6 +52,19 @@ def parse_model_raw(self, judge_completion: str) -> float | None: else: return float(self.preference_from_scores(score_a, score_b)) + def _coerce_score(self, value: object) -> float | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return float(value) + if isinstance(value, float) and value.is_integer(): + return value + if isinstance(value, str): + match = re.fullmatch(r"\s*(-?\d+)\s*", value) + if match is not None: + return float(match.group(1)) + return None + def get_regexp_match(self, s: str, regex: str, group_index: int = 1): m = re.search(re.compile(regex), s) if m is None: @@ -54,6 +75,7 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): _PAIR_SCORE_MIN = 0 _PAIR_SCORE_MAX = 10 +_PAIR_REASONING_MAX_CHARS = 384 def build_pair_score_json_schema() -> dict: @@ -65,10 +87,14 @@ def build_pair_score_json_schema() -> dict: return { "type": "object", "properties": { + "reasoning": { + "type": "string", + "maxLength": _PAIR_REASONING_MAX_CHARS, + }, "score_A": score_field, "score_B": score_field, }, - "required": ["score_A", "score_B"], + "required": ["reasoning", "score_A", "score_B"], "additionalProperties": False, } diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index a8412e2..2f1afc4 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -30,6 +30,8 @@ read_df, ) +_DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET = 128 + def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None @@ -408,6 +410,9 @@ def main(args: CliArgs): judge_model_kwargs = dict(args.engine_kwargs) if not args.provide_explanation and args.judge_model.split("/")[0] == "VLLM": judge_model_kwargs["structured_outputs_json"] = build_pair_score_json_schema() + judge_model_kwargs.setdefault( + "thinking_token_budget", _DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) judge_chat_model = make_model( model=args.judge_model, diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 2920a57..9923a4e 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -3,6 +3,7 @@ from __future__ import annotations import math +import re from dataclasses import dataclass, replace from pathlib import Path from typing import Any, Literal @@ -11,7 +12,7 @@ from langchain_core.prompts import ChatPromptTemplate from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows -from judgearena.utils import do_inference +from judgearena.utils import do_inference, extract_json_object, strip_thinking_tags FASTCHAT_TEMPERATURE_CONFIG: dict[str, float] = { "writing": 0.7, @@ -61,6 +62,7 @@ class FastChatPairwisePrompt: 'Output only one final verdict token: "A" if assistant A is better, "B" ' 'if assistant B is better, and "C" for a tie.' ) +_PARTIAL_JSON_VERDICT_RE = re.compile(r'"verdict"\s*:\s*"(?P[ABC])"') def _load_prompt_text(filename: str) -> str: @@ -204,7 +206,27 @@ def _load_pairwise_prompt( def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: - stripped = judgment.strip() + json_payload = extract_json_object(judgment) + if json_payload is not None: + verdict = json_payload.get("verdict") + if isinstance(verdict, str): + normalized = verdict.strip().upper() + if normalized == "A": + return "A" + if normalized == "B": + return "B" + if normalized in {"C", "TIE"}: + return "tie" + + partial_json_match = _PARTIAL_JSON_VERDICT_RE.search(judgment) + if partial_json_match is not None: + if partial_json_match.group("verdict") == "A": + return "A" + if partial_json_match.group("verdict") == "B": + return "B" + return "tie" + + stripped = strip_thinking_tags(judgment).strip() if "[[A]]" in stripped or stripped == "A": return "A" if "[[B]]" in stripped or stripped == "B": @@ -233,8 +255,12 @@ def _conservative_winner( Declare a winner only if the two orderings agree; otherwise treat as tie. """ - if g1 == "error" or g2 == "error": + if g1 == "error" and g2 == "error": return "error", False + if g1 == "error": + return g2, False + if g2 == "error": + return g1, False if g1 == g2: return g1, False return "tie", True diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index a012cf5..9070d1a 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -32,7 +32,24 @@ # Use distinct first tokens for constrained decoding. The shared `[[` prefix # caused the MT-Bench judge to collapse to `[[A]]` on every comparison. -_MIN_MT_BENCH_JUDGE_TOKENS = 1024 +_MIN_MT_BENCH_JUDGE_TOKENS = 2048 +_DEFAULT_MT_BENCH_JUDGE_THINKING_TOKEN_BUDGET = 192 +_MT_BENCH_REASONING_MAX_CHARS = 384 + + +def build_mt_bench_verdict_json_schema() -> dict: + return { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "maxLength": _MT_BENCH_REASONING_MAX_CHARS, + }, + "verdict": {"type": "string", "enum": ["A", "B", "C"]}, + }, + "required": ["reasoning", "verdict"], + "additionalProperties": False, + } def _align_mt_bench_completions( @@ -199,7 +216,10 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): ignore_cache=ignore_cache, ) judge_model_kwargs = dict(args.engine_kwargs) - judge_model_kwargs["disable_thinking"] = True + judge_model_kwargs["structured_outputs_json"] = build_mt_bench_verdict_json_schema() + judge_model_kwargs.setdefault( + "thinking_token_budget", _DEFAULT_MT_BENCH_JUDGE_THINKING_TOKEN_BUDGET + ) judge_chat_model = make_model( model=args.judge_model, diff --git a/judgearena/prompts/mt_bench/system-base.txt b/judgearena/prompts/mt_bench/system-base.txt index b4aff2e..8a2e41d 100644 --- a/judgearena/prompts/mt_bench/system-base.txt +++ b/judgearena/prompts/mt_bench/system-base.txt @@ -1 +1 @@ -Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie. +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Output your response as valid JSON with two keys: "reasoning" for a concise rationale under 300 characters and "verdict" with exactly one of "A", "B", or "C", where "A" means assistant A is better, "B" means assistant B is better, and "C" means a tie. diff --git a/judgearena/prompts/prompt.txt b/judgearena/prompts/prompt.txt index 38021e6..060bee2 100644 --- a/judgearena/prompts/prompt.txt +++ b/judgearena/prompts/prompt.txt @@ -12,10 +12,13 @@ # Your output ## Format description -Your output should follow this format: +Your output should be valid JSON with exactly these keys: ``` -score_A: -score_B: +{{ + "reasoning": "", + "score_A": , + "score_B": +}} ``` -## Your output, do not repeat the input above{explanation_suffix} +## Your output, do not repeat the input above diff --git a/judgearena/utils.py b/judgearena/utils.py index bb7aaa5..39dde72 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -1,9 +1,12 @@ import asyncio +import json import os +import re import time import warnings from collections.abc import Callable from pathlib import Path +from typing import Any import pandas as pd from huggingface_hub import snapshot_download @@ -66,6 +69,49 @@ def safe_text(value: object, truncate_chars: int | None) -> str: return truncate(str(value), max_len=truncate_chars) +_THINK_BLOCK_RE = re.compile(r".*?", re.IGNORECASE | re.DOTALL) +_JSON_CODE_FENCE_RE = re.compile( + r"```(?:json)?\s*(?P\{.*?\})\s*```", + re.IGNORECASE | re.DOTALL, +) + + +def strip_thinking_tags(text: str | None) -> str: + if not isinstance(text, str): + return "" + return _THINK_BLOCK_RE.sub("", text) + + +def extract_json_object(text: str | None) -> dict[str, Any] | None: + """Best-effort JSON object extraction from model output. + + Handles raw JSON, fenced JSON blocks, and outputs that still contain leaked + `...` sections ahead of the machine-readable payload. + """ + + cleaned = strip_thinking_tags(text).strip() + if not cleaned: + return None + + candidates = [cleaned] + fenced_match = _JSON_CODE_FENCE_RE.search(cleaned) + if fenced_match is not None: + candidates.insert(0, fenced_match.group("payload")) + + decoder = json.JSONDecoder() + for candidate in candidates: + for idx, char in enumerate(candidate): + if char != "{": + continue + try: + parsed, _end = decoder.raw_decode(candidate[idx:]) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + return None + + def compute_pref_summary(prefs: pd.Series) -> dict[str, float | int]: """Compute win/loss/tie stats for preference series (0=A, 0.5=tie, 1=B).""" prefs = pd.Series(prefs, dtype="float64") @@ -179,15 +225,6 @@ async def ainvoke(self, input, **invoke_kwargs): return self.message -_DISABLE_THINKING_PREFIX = "{%- set enable_thinking = false %}\n" - - -def _disable_thinking_chat_template(template: str) -> str: - if "{%- set enable_thinking = false" in template: - return template - return _DISABLE_THINKING_PREFIX + template - - class ChatVLLM: """VLLM wrapper that auto-detects whether to use chat() or generate(). @@ -211,11 +248,16 @@ def __init__( **vllm_kwargs, ): from vllm import LLM, SamplingParams + from vllm.config.reasoning import ReasoningConfig from vllm.sampling_params import StructuredOutputsParams self.model_path = model self.max_tokens = max_tokens disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) + thinking_token_budget = vllm_kwargs.pop("thinking_token_budget", None) + self._chat_template_kwargs = ( + {"enable_thinking": False} if disable_thinking else None + ) # Cap max_model_len to the model's max_position_embeddings so that # vLLM doesn't reject an overly large context window. @@ -246,6 +288,13 @@ def __init__( "temperature": float(vllm_kwargs.pop("temperature", 0.6)), "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } + if thinking_token_budget is not None: + vllm_kwargs.setdefault("reasoning_config", ReasoningConfig()) + if "qwen3" in model.lower(): + vllm_kwargs.setdefault("reasoning_parser", "qwen3") + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) structured_outputs_json = vllm_kwargs.pop("structured_outputs_json", None) if structured_outputs_json is not None: self._sampling_params_kwargs["structured_outputs"] = ( @@ -260,13 +309,14 @@ def __init__( # 2. If tokenizer has one, use it → use chat() (pass None to vLLM) # 3. No template found → fall back to generate() for base models if chat_template: - self.chat_template = ( - _disable_thinking_chat_template(chat_template) - if disable_thinking - else chat_template - ) + self.chat_template = chat_template self._use_generate = False - print(f"ChatVLLM: using explicit chat template for '{model}'") + if disable_thinking: + print( + f"ChatVLLM: using explicit chat template with thinking disabled for '{model}'" + ) + else: + print(f"ChatVLLM: using explicit chat template for '{model}'") else: tokenizer = self.llm.get_tokenizer() if not getattr(tokenizer, "chat_template", None): @@ -278,12 +328,14 @@ def __init__( ) self.chat_template = None self._use_generate = True + if disable_thinking: + warnings.warn( + f"Model '{model}' has no chat template, so disable_thinking " + "cannot be applied when falling back to llm.generate().", + stacklevel=2, + ) else: - self.chat_template = ( - _disable_thinking_chat_template(tokenizer.chat_template) - if disable_thinking - else None - ) + self.chat_template = None self._use_generate = False if disable_thinking: print( @@ -365,6 +417,7 @@ def batch(self, inputs: list, **invoke_kwargs) -> list[str]: self.sampling_params, add_generation_prompt=True, chat_template=self.chat_template, + chat_template_kwargs=self._chat_template_kwargs, ) return [out.outputs[0].text for out in outputs] diff --git a/pyproject.toml b/pyproject.toml index 887f124..19de68f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,6 @@ indent-style = "space" [project.optional-dependencies] # vLLM on PyPI pins transformers<5; optional extra matches that so `uv lock` can resolve. -# Tested with vllm 0.18.1; StructuredOutputsParams(json=...) requires >= 0.17. -vllm = ["vllm>=0.17.0,<0.19.0", "transformers>=4.56.0,<5.0.0"] +# JudgeArena relies on v0.19+ for Qwen3.5 thinking_token_budget support and FP8 fixes. +vllm = ["vllm>=0.19.0,<1.0.0", "transformers>=4.56.0,<5.0.0"] llamacpp = ["llama-cpp-python>=0.3.0"] diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py new file mode 100644 index 0000000..73304e6 --- /dev/null +++ b/tests/test_chat_vllm.py @@ -0,0 +1,96 @@ +import sys +from types import SimpleNamespace + +import judgearena.utils as utils + + +def _install_fake_vllm(monkeypatch): + captured = {} + + class FakeSamplingParams: + def __init__(self, **kwargs): + captured["sampling_kwargs"] = kwargs + + class FakeStructuredOutputsParams: + def __init__(self, **kwargs): + captured["structured_kwargs"] = kwargs + self.kwargs = kwargs + + class FakeReasoningConfig: + pass + + class FakeLLM: + def __init__(self, *, model, trust_remote_code, **kwargs): + captured["llm_init"] = { + "model": model, + "trust_remote_code": trust_remote_code, + "kwargs": kwargs, + } + + def get_tokenizer(self): + return SimpleNamespace(chat_template="{{ messages }}") + + def chat(self, messages, sampling_params, **kwargs): + captured["chat_call"] = { + "messages": messages, + "sampling_params": sampling_params, + "kwargs": kwargs, + } + return [SimpleNamespace(outputs=[SimpleNamespace(text="ok")])] + + monkeypatch.setitem( + sys.modules, + "vllm", + SimpleNamespace(LLM=FakeLLM, SamplingParams=FakeSamplingParams), + ) + monkeypatch.setitem( + sys.modules, + "vllm.sampling_params", + SimpleNamespace(StructuredOutputsParams=FakeStructuredOutputsParams), + ) + monkeypatch.setitem( + sys.modules, + "vllm.config.reasoning", + SimpleNamespace(ReasoningConfig=FakeReasoningConfig), + ) + return captured, FakeReasoningConfig + + +def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatch): + captured, fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="Qwen/Qwen3.5-27B-FP8", + max_tokens=32, + structured_outputs_json={ + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False, + }, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 + assert captured["structured_kwargs"]["json"]["type"] == "object" + llm_kwargs = captured["llm_init"]["kwargs"] + assert llm_kwargs["reasoning_parser"] == "qwen3" + assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) + + +def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-27B-FP8", + max_tokens=16, + disable_thinking=True, + gpu_memory_utilization=0.7, + ) + + outputs = chat_model.batch(["hello"]) + + assert outputs == ["ok"] + assert captured["chat_call"]["kwargs"]["chat_template_kwargs"] == { + "enable_thinking": False + } diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index 8882906..6a921d5 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -10,7 +10,11 @@ def test_build_pair_score_json_schema_covers_valid_range(): schema = evaluate.build_pair_score_json_schema() assert schema["type"] == "object" - assert set(schema["required"]) == {"score_A", "score_B"} + assert set(schema["required"]) == {"reasoning", "score_A", "score_B"} + assert schema["properties"]["reasoning"] == { + "type": "string", + "maxLength": evaluate._PAIR_REASONING_MAX_CHARS, + } for key in ("score_A", "score_B"): assert schema["properties"][key]["type"] == "integer" assert schema["properties"][key]["minimum"] == 0 @@ -93,6 +97,68 @@ def fake_judge_and_parse_prefs( assert prefs.tolist() == [1.0, 1.0] +def test_main_passes_json_schema_and_thinking_budget_to_vllm_judge( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, + ) + + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = { + "model": model, + "max_tokens": max_tokens, + "max_model_len": max_model_len, + "chat_template": chat_template, + "kwargs": kwargs, + } + return object() + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + lambda **kwargs: ( + [{"judge_completion": '{"reasoning":"ok","score_A":1,"score_B":2}'}], + None, + pd.Series([1.0]), + ), + ) + + prefs = main_generate_and_eval( + CliArgs( + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [1.0] + assert captured["make_model"]["kwargs"]["structured_outputs_json"] == ( + evaluate.build_pair_score_json_schema() + ) + assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 128 + + def test_annotate_battles_warns_when_judge_completions_are_truncated( monkeypatch, capsys ): diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index a31626f..7bcadb0 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -178,6 +178,56 @@ def test_parse_fastchat_verdict_accepts_plain_structured_labels(): assert fastchat_compat._parse_fastchat_verdict("C") == "tie" +def test_parse_fastchat_verdict_accepts_json_and_strips_thinking(): + assert ( + fastchat_compat._parse_fastchat_verdict( + 'Need a longer chain.{"reasoning":"done","verdict":"B"}' + ) + == "B" + ) + assert ( + fastchat_compat._parse_fastchat_verdict( + '```json\n{"reasoning":"tie","verdict":"C"}\n```' + ) + == "tie" + ) + assert ( + fastchat_compat._parse_fastchat_verdict( + 'unfinished analysis {"reasoning":"cut short","verdict":"A"' + ) + == "A" + ) + + +def test_conservative_winner_uses_non_error_side_when_only_one_parse_fails(): + assert fastchat_compat._conservative_winner("model_A", "error") == ( + "model_A", + False, + ) + assert fastchat_compat._conservative_winner("error", "model_B") == ( + "model_B", + False, + ) + assert fastchat_compat._conservative_winner("error", "error") == ("error", False) + assert fastchat_compat._conservative_winner("model_A", "model_B") == ("tie", True) + + +def test_build_mt_bench_verdict_json_schema(): + schema = mt_bench_utils.build_mt_bench_verdict_json_schema() + + assert schema["type"] == "object" + assert set(schema["required"]) == {"reasoning", "verdict"} + assert schema["properties"]["reasoning"] == { + "type": "string", + "maxLength": mt_bench_utils._MT_BENCH_REASONING_MAX_CHARS, + } + assert schema["properties"]["verdict"] == { + "type": "string", + "enum": ["A", "B", "C"], + } + assert schema["additionalProperties"] is False + + def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): questions_df = pd.DataFrame( {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, @@ -259,12 +309,13 @@ def fake_run_mt_bench_fastchat(**kwargs): mt_bench_utils.run_mt_bench(args, ignore_cache=False) assert args.swap_mode == "both" - assert args.max_out_tokens_judge == 1024 - assert captured["make_model"]["max_tokens"] == 1024 + assert args.max_out_tokens_judge == 2048 + assert captured["make_model"]["max_tokens"] == 2048 assert captured["make_model"]["kwargs"] == { - "disable_thinking": True, "gpu_memory_utilization": 0.7, "language_model_only": True, + "structured_outputs_json": mt_bench_utils.build_mt_bench_verdict_json_schema(), + "thinking_token_budget": 192, } assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" assert captured["run_mt_bench_fastchat"]["constrained_plain_verdict"] is False diff --git a/tests/test_regexp.py b/tests/test_regexp.py index 39efa41..74e9407 100644 --- a/tests/test_regexp.py +++ b/tests/test_regexp.py @@ -38,3 +38,20 @@ def test_regexp(): assert pref == 0.5744425168116589 print(pref) + + +def test_pair_score_prefers_json_scores_over_reasoning_text(): + raw_text = """ + I would score assistant A as 2/10 if I stopped early. + { + "reasoning": "At first glance I might score assistant A as 2, but after comparing both answers carefully, assistant B is better.", + "score_A": 0, + "score_B": 10 + } + """ + + scorer = PairScore() + pref = scorer.parse_model_raw(raw_text) + + assert pref is not None + assert pref == 0.9525741268224333 From 319050d442ce87269040dc1b7e91bdabbcd32440 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 15 Apr 2026 00:05:56 +0200 Subject: [PATCH 10/28] thinking token handling improvements, mt-bench improvements, use mt-bench baseline from huggingface and update huggingface repo --- judgearena/evaluate.py | 36 +++---- judgearena/generate_and_evaluate.py | 11 ++- judgearena/instruction_dataset/__init__.py | 7 +- judgearena/instruction_dataset/m_arenahard.py | 4 +- judgearena/mt_bench/fastchat_compat.py | 97 ++++++++++--------- judgearena/mt_bench/mt_bench_utils.py | 59 +++++++---- judgearena/prompts/mt_bench/system-base.txt | 2 +- .../prompts/prompt-with-explanation.txt | 19 ++-- judgearena/prompts/prompt.txt | 1 - judgearena/utils.py | 22 ++++- tests/test_chat_vllm.py | 7 +- tests/test_local_completion_loading.py | 81 ++++++++++++++-- tests/test_mt_bench_downloads.py | 45 ++++++--- tests/test_regexp.py | 2 +- 14 files changed, 269 insertions(+), 124 deletions(-) diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 306a66a..984eebd 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -79,34 +79,39 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): _PAIR_SCORE_MIN = 0 _PAIR_SCORE_MAX = 10 -_PAIR_REASONING_MAX_CHARS = 384 +_PAIR_EXPLANATION_MAX_CHARS = 384 -def build_pair_score_json_schema() -> dict: +def build_pair_score_json_schema(*, include_explanation: bool = False) -> dict: score_field = { "type": "integer", "minimum": _PAIR_SCORE_MIN, "maximum": _PAIR_SCORE_MAX, } - return { - "type": "object", - "properties": { - "reasoning": { + properties: dict[str, object] = { + "score_A": score_field, + "score_B": score_field, + } + required = ["score_A", "score_B"] + if include_explanation: + properties = { + "explanation": { "type": "string", - "maxLength": _PAIR_REASONING_MAX_CHARS, + "maxLength": _PAIR_EXPLANATION_MAX_CHARS, }, - "score_A": score_field, - "score_B": score_field, - }, - "required": ["reasoning", "score_A", "score_B"], + **properties, + } + required = ["explanation", *required] + return { + "type": "object", + "properties": properties, + "required": required, "additionalProperties": False, } _COMPLETION_LABEL_SINGLE = "Answer" _COMPLETION_LABEL_MULTI_TURN = "Conversation with User" -_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" -_SCORE_FENCE = "\n```" def load_judge_system_and_user_prompt( @@ -124,11 +129,6 @@ def load_judge_system_and_user_prompt( "{completion_label}", _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, ) - user_prompt_template = user_prompt_template.replace( - "{explanation_suffix}", - _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, - ) - return system_prompt, user_prompt_template diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 6f84af3..109c648 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -27,6 +27,7 @@ from judgearena.mt_bench.mt_bench_utils import run_mt_bench from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, cache_function_dataframe, compute_pref_summary, data_root, @@ -35,8 +36,6 @@ read_df, ) -_DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET = 128 - def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None @@ -298,10 +297,12 @@ def main(args: CliArgs): print(f"Evaluating completions with judge {args.judge_model}.") judge_model_kwargs = dict(args.engine_kwargs) - if not args.provide_explanation and args.judge_model.split("/")[0] == "VLLM": - judge_model_kwargs["structured_outputs_json"] = build_pair_score_json_schema() + if args.judge_model.split("/")[0] == "VLLM": + judge_model_kwargs["structured_outputs_json"] = build_pair_score_json_schema( + include_explanation=args.provide_explanation + ) judge_model_kwargs.setdefault( - "thinking_token_budget", _DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET ) judge_chat_model = make_model( diff --git a/judgearena/instruction_dataset/__init__.py b/judgearena/instruction_dataset/__init__.py index 4681f0a..48fccd1 100644 --- a/judgearena/instruction_dataset/__init__.py +++ b/judgearena/instruction_dataset/__init__.py @@ -4,8 +4,6 @@ download_arena_hard, is_arena_hard_dataset, ) -from judgearena.instruction_dataset.m_arenahard import load_m_arenahard -from judgearena.utils import data_root, download_hf, read_df def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.DataFrame: @@ -15,6 +13,9 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat df_instructions = load_mt_bench() elif "m-arena-hard" in dataset: + from judgearena.instruction_dataset.m_arenahard import load_m_arenahard + from judgearena.utils import data_root + if dataset == "m-arena-hard": language = None else: @@ -62,6 +63,8 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat ) else: + from judgearena.utils import data_root, download_hf, read_df + assert dataset in [ "alpaca-eval", "arena-hard-v0.1", diff --git a/judgearena/instruction_dataset/m_arenahard.py b/judgearena/instruction_dataset/m_arenahard.py index 3d7b919..1c45f33 100644 --- a/judgearena/instruction_dataset/m_arenahard.py +++ b/judgearena/instruction_dataset/m_arenahard.py @@ -3,8 +3,6 @@ import pandas as pd from huggingface_hub import snapshot_download -from judgearena.utils import data_root - def load_m_arenahard(local_path, language: str | None = None): snapshot_download( @@ -54,4 +52,6 @@ def load_m_arenahard(local_path, language: str | None = None): if __name__ == "__main__": + from judgearena.utils import data_root + load_m_arenahard(local_path=data_root, language="EU") diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 9923a4e..41b1df6 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -4,7 +4,7 @@ import math import re -from dataclasses import dataclass, replace +from dataclasses import dataclass from pathlib import Path from typing import Any, Literal @@ -40,10 +40,22 @@ @dataclass(frozen=True) class FastChatPairwisePrompt: name: str - system_prompt: str + user_subject: str + task_description: str + begin_instruction: str user_prompt_template: str multi_turn: bool ref_based: bool + focus_line: str = "" + + def render_system_prompt(self, *, provide_explanation: bool) -> str: + return _build_system_prompt( + user_subject=self.user_subject, + task_description=self.task_description, + begin_instruction=self.begin_instruction, + focus_line=self.focus_line, + provide_explanation=provide_explanation, + ) _PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" @@ -53,15 +65,6 @@ class FastChatPairwisePrompt: _USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" _USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" -_BRACKETED_VERDICT_INSTRUCTION = ( - "After providing your explanation, output your final verdict by strictly " - 'following this format: "[[A]]" if assistant A is better, "[[B]]" if ' - 'assistant B is better, and "[[C]]" for a tie.' -) -_PLAIN_VERDICT_INSTRUCTION = ( - 'Output only one final verdict token: "A" if assistant A is better, "B" ' - 'if assistant B is better, and "C" for a tie.' -) _PARTIAL_JSON_VERDICT_RE = re.compile(r'"verdict"\s*:\s*"(?P[ABC])"') @@ -74,18 +77,12 @@ def _render_prompt_text(filename: str, **kwargs: str) -> str: return _load_prompt_text(filename).format(**kwargs) -def _structured_verdict_prompt( - prompt: FastChatPairwisePrompt, -) -> FastChatPairwisePrompt: - if _BRACKETED_VERDICT_INSTRUCTION not in prompt.system_prompt: - return prompt - return replace( - prompt, - system_prompt=prompt.system_prompt.replace( - _BRACKETED_VERDICT_INSTRUCTION, - _PLAIN_VERDICT_INSTRUCTION, - ), - ) +def _begin_instruction_for_mode( + begin_instruction: str, *, provide_explanation: bool +) -> str: + if provide_explanation: + return begin_instruction + return re.sub(r"\s+and provide a short explanation$", "", begin_instruction) def _build_system_prompt( @@ -94,14 +91,28 @@ def _build_system_prompt( task_description: str, begin_instruction: str, focus_line: str = "", + provide_explanation: bool, ) -> str: focus_segment = f"{focus_line} " if focus_line else "" + output_format_instruction = ( + 'Output your response as valid JSON with exactly two keys: "explanation" ' + 'for a concise rationale under 300 characters and "verdict" with exactly ' + 'one of "A", "B", or "C", where "A" means assistant A is better, "B" ' + 'means assistant B is better, and "C" means a tie.' + if provide_explanation + else 'Output your response as valid JSON with exactly one key: "verdict" ' + 'with exactly one of "A", "B", or "C", where "A" means assistant A is ' + 'better, "B" means assistant B is better, and "C" means a tie.' + ) return _render_prompt_text( _SYSTEM_BASE_FILE, user_subject=user_subject, task_description=task_description, focus_line=focus_segment, - begin_instruction=begin_instruction, + begin_instruction=_begin_instruction_for_mode( + begin_instruction, provide_explanation=provide_explanation + ), + output_format_instruction=output_format_instruction, ) @@ -128,18 +139,16 @@ def _load_pairwise_prompt( ) -> FastChatPairwisePrompt: return FastChatPairwisePrompt( name=name, + user_subject=system_user_subject, + task_description=system_task_description, + begin_instruction=system_begin_instruction, multi_turn=multi_turn, ref_based=ref_based, - system_prompt=_build_system_prompt( - user_subject=system_user_subject, - task_description=system_task_description, - begin_instruction=system_begin_instruction, - focus_line=system_focus_line, - ), user_prompt_template=_build_user_prompt_template( multi_turn=multi_turn, ref_based=ref_based, ), + focus_line=system_focus_line, ) @@ -255,12 +264,8 @@ def _conservative_winner( Declare a winner only if the two orderings agree; otherwise treat as tie. """ - if g1 == "error" and g2 == "error": + if g1 == "error" or g2 == "error": return "error", False - if g1 == "error": - return g2, False - if g2 == "error": - return g1, False if g1 == g2: return g1, False return "tie", True @@ -318,7 +323,7 @@ def _infer_by_prompt_groups( items: list[dict[str, Any]], use_tqdm: bool, swap_answers: bool, - constrained_plain_verdict: bool, + provide_explanation: bool, ) -> list[str]: """Run judge inference, grouping by prompt variant for batching.""" grouped_indices = _group_indices_by_prompt(items) @@ -326,10 +331,11 @@ def _infer_by_prompt_groups( judgments: list[str] = [""] * len(items) for _prompt_name, idxs in grouped_indices.items(): prompt: FastChatPairwisePrompt = items[idxs[0]]["prompt"] - if constrained_plain_verdict: - prompt = _structured_verdict_prompt(prompt) + system_prompt = prompt.render_system_prompt( + provide_explanation=provide_explanation + ) prompt_template = ChatPromptTemplate.from_messages( - [("system", prompt.system_prompt), ("user", prompt.user_prompt_template)] + [("system", system_prompt), ("user", prompt.user_prompt_template)] ) batch_kwargs = [] @@ -421,12 +427,14 @@ def _resolve_fastchat_item_result( judge_model: str, model_a: str, model_b: str, + provide_explanation: bool, ) -> tuple[dict[str, Any], dict[str, object], float, bool]: prompt: FastChatPairwisePrompt = item["prompt"] kwargs = item["prompt_kwargs"] g1_user_prompt = prompt.user_prompt_template.format(**kwargs) g1_verdict = _parse_fastchat_verdict(g1_raw) g1_winner = _map_verdict_to_winner(g1_verdict, swapped=False) + system_prompt = prompt.render_system_prompt(provide_explanation=provide_explanation) final_winner = g1_winner inconsistent = False @@ -438,7 +446,7 @@ def _resolve_fastchat_item_result( "model_B": model_b, "judge": judge_model, "prompt_name": prompt.name, - "system_prompt": prompt.system_prompt, + "system_prompt": system_prompt, "g1_user_prompt": g1_user_prompt, "g1_judgment": g1_raw, "g1_verdict": g1_verdict, @@ -488,9 +496,9 @@ def judge_mt_bench_pairwise_fastchat( swap_mode: str, truncate_input_chars: int | None, use_tqdm: bool, - constrained_plain_verdict: bool = False, + provide_explanation: bool = False, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: - """Pairwise MT-Bench judging compatible with FastChat's `[[A]]/[[B]]/[[C]]` format.""" + """Run FastChat-style MT-Bench pairwise judging with JSON verdict outputs.""" assert turns_mode in ("both", "single", "multi") assert swap_mode in ("fixed", "both") @@ -511,7 +519,7 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=False, - constrained_plain_verdict=constrained_plain_verdict, + provide_explanation=provide_explanation, ) g2_judgments: list[str] | None = None @@ -521,7 +529,7 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=True, - constrained_plain_verdict=constrained_plain_verdict, + provide_explanation=provide_explanation, ) annotations: list[dict[str, Any]] = [] @@ -539,6 +547,7 @@ def judge_mt_bench_pairwise_fastchat( judge_model=judge_model, model_a=model_a, model_b=model_b, + provide_explanation=provide_explanation, ) ) if inconsistent: diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index d71ebbb..846077d 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -24,30 +24,43 @@ judge_mt_bench_pairwise_fastchat, ) from judgearena.repro import _to_jsonable -from judgearena.utils import cache_function_dataframe, compute_pref_summary, make_model +from judgearena.utils import ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, + cache_function_dataframe, + compute_pref_summary, + make_model, +) if TYPE_CHECKING: from judgearena.generate_and_evaluate import CliArgs -# Use distinct first tokens for constrained decoding. The shared `[[` prefix -# caused the MT-Bench judge to collapse to `[[A]]` on every comparison. +# MT-Bench judge prompts need headroom for budgeted thinking and the final JSON. _MIN_MT_BENCH_JUDGE_TOKENS = 2048 -_DEFAULT_MT_BENCH_JUDGE_THINKING_TOKEN_BUDGET = 192 -_MT_BENCH_REASONING_MAX_CHARS = 384 +_MT_BENCH_EXPLANATION_MAX_CHARS = 384 -def build_mt_bench_verdict_json_schema() -> dict: - return { - "type": "object", - "properties": { - "reasoning": { +def build_mt_bench_verdict_json_schema( + *, include_explanation: bool = False +) -> dict[str, object]: + """Return the MT-Bench judge schema for verdict-only or verdict+explanation.""" + properties: dict[str, object] = { + "verdict": {"type": "string", "enum": ["A", "B", "C"]}, + } + required = ["verdict"] + if include_explanation: + properties = { + "explanation": { "type": "string", - "maxLength": _MT_BENCH_REASONING_MAX_CHARS, + "maxLength": _MT_BENCH_EXPLANATION_MAX_CHARS, }, - "verdict": {"type": "string", "enum": ["A", "B", "C"]}, - }, - "required": ["reasoning", "verdict"], + **properties, + } + required = ["explanation", *required] + return { + "type": "object", + "properties": properties, + "required": required, "additionalProperties": False, } @@ -55,6 +68,7 @@ def build_mt_bench_verdict_json_schema() -> dict: def _align_mt_bench_completions( *, questions_df: pd.DataFrame, completions: pd.DataFrame, model_name: str ) -> pd.DataFrame: + """Align cached or generated MT-Bench completions to the question order.""" indexed = completions.set_index("instruction_index") missing_ids = questions_df.index.difference(indexed.index) if not missing_ids.empty: @@ -71,6 +85,7 @@ def _generate_mt_bench_completions( questions_df: pd.DataFrame, ignore_cache: bool, ) -> tuple[pd.DataFrame, pd.DataFrame]: + """Load baseline MT-Bench answers or generate fresh multi-turn outputs.""" cache_prefix = "mt-bench" def _run_generation(model_name: str) -> pd.DataFrame: @@ -114,6 +129,7 @@ def _load_or_generate(model_name: str) -> pd.DataFrame: def _build_mt_bench_result_name(args: CliArgs, suffix: str | None = None) -> str: + """Build a filesystem-safe MT-Bench result artifact prefix.""" name = f"{args.dataset}-{args.model_A}-{args.model_B}-{args.judge_model}" name += f"-{args.swap_mode}" if suffix: @@ -128,6 +144,7 @@ def _save_mt_bench_results( annotations_df: pd.DataFrame, name_suffix: str | None = None, ) -> None: + """Persist MT-Bench arguments, annotations, and aggregate results.""" name = _build_mt_bench_result_name(args, suffix=name_suffix) res_folder = Path(args.result_folder) / name res_folder.mkdir(parents=True, exist_ok=True) @@ -148,8 +165,8 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, - constrained_plain_verdict: bool, ) -> pd.Series: + """Run FastChat-style MT-Bench judging and save the resulting artifacts.""" prefs, annotations, combined_metadata, num_inconsistent = ( judge_mt_bench_pairwise_fastchat( judge_chat_model=judge_chat_model, @@ -163,7 +180,7 @@ def _run_mt_bench_fastchat( swap_mode=args.swap_mode, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, - constrained_plain_verdict=constrained_plain_verdict, + provide_explanation=args.provide_explanation, ) ) @@ -201,7 +218,8 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): args.swap_mode = "both" if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: print( - "MT-Bench judge prompts require room for explanation plus verdict; " + "MT-Bench judge prompts require room for budgeted thinking and the " + "final JSON verdict; " f"overriding max_out_tokens_judge from {args.max_out_tokens_judge} " f"to {_MIN_MT_BENCH_JUDGE_TOKENS}." ) @@ -216,9 +234,11 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): ignore_cache=ignore_cache, ) judge_model_kwargs = dict(args.engine_kwargs) - judge_model_kwargs["structured_outputs_json"] = build_mt_bench_verdict_json_schema() + judge_model_kwargs["structured_outputs_json"] = build_mt_bench_verdict_json_schema( + include_explanation=args.provide_explanation + ) judge_model_kwargs.setdefault( - "thinking_token_budget", _DEFAULT_MT_BENCH_JUDGE_THINKING_TOKEN_BUDGET + "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET ) judge_chat_model = make_model( @@ -235,5 +255,4 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, - constrained_plain_verdict=False, ) diff --git a/judgearena/prompts/mt_bench/system-base.txt b/judgearena/prompts/mt_bench/system-base.txt index 8a2e41d..7cf5120 100644 --- a/judgearena/prompts/mt_bench/system-base.txt +++ b/judgearena/prompts/mt_bench/system-base.txt @@ -1 +1 @@ -Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Output your response as valid JSON with two keys: "reasoning" for a concise rationale under 300 characters and "verdict" with exactly one of "A", "B", or "C", where "A" means assistant A is better, "B" means assistant B is better, and "C" means a tie. +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. {output_format_instruction} diff --git a/judgearena/prompts/prompt-with-explanation.txt b/judgearena/prompts/prompt-with-explanation.txt index 6600f51..804e766 100644 --- a/judgearena/prompts/prompt-with-explanation.txt +++ b/judgearena/prompts/prompt-with-explanation.txt @@ -1,21 +1,24 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's Answer|> +<|The Start of Assistant A's {completion_label}|> {completion_A} -<|The End of Assistant A's Answer|> +<|The End of Assistant A's {completion_label}|> -<|The Start of Assistant B's Answer|> +<|The Start of Assistant B's {completion_label}|> {completion_B} -<|The End of Assistant B's Answer|> +<|The End of Assistant B's {completion_label}|> # Your output ## Format description -Your output should follow this format: +Your output should be valid JSON with exactly these keys: ``` -score_A: -score_B: +{{ + "explanation": "", + "score_A": , + "score_B": +}} ``` -## Your output, do not repeat the input above, first starts with an explanation of your judgement +## Your output, do not repeat the input above diff --git a/judgearena/prompts/prompt.txt b/judgearena/prompts/prompt.txt index 060bee2..ce7ee4e 100644 --- a/judgearena/prompts/prompt.txt +++ b/judgearena/prompts/prompt.txt @@ -15,7 +15,6 @@ Your output should be valid JSON with exactly these keys: ``` {{ - "reasoning": "", "score_A": , "score_B": }} diff --git a/judgearena/utils.py b/judgearena/utils.py index 5691477..c3b2ee1 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -31,6 +31,12 @@ def _data_root_path() -> Path: data_root = _data_root_path() +DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET = 512 +VLLM_QWEN_REASONING_START_STR = "" +VLLM_QWEN_REASONING_END_STR = ( + "I have to give the solution based on the thinking directly now." +) + def set_langchain_cache(): set_llm_cache(SQLiteCache(database_path=str(data_root / ".langchain.db"))) @@ -40,7 +46,7 @@ def download_hf(name: str, local_path: Path): local_path.mkdir(exist_ok=True, parents=True) # downloads the model from huggingface into `local_path` folder snapshot_download( - repo_id="geoalgo/llmjudge", + repo_id="judge-arena/judge-arena-dataset", repo_type="dataset", allow_patterns=f"*{name}*", local_dir=local_path, @@ -134,6 +140,7 @@ def safe_text(value: object, truncate_chars: int | None) -> str: def strip_thinking_tags(text: str | None) -> str: + """Remove full `...` blocks from raw model output.""" if not isinstance(text, str): return "" return _THINK_BLOCK_RE.sub("", text) @@ -143,7 +150,8 @@ def extract_json_object(text: str | None) -> dict[str, Any] | None: """Best-effort JSON object extraction from model output. Handles raw JSON, fenced JSON blocks, and outputs that still contain leaked - `...` sections ahead of the machine-readable payload. + reasoning text such as `...{...}` ahead of the + machine-readable payload. """ cleaned = strip_thinking_tags(text).strip() @@ -318,9 +326,17 @@ def __init__( "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } if thinking_token_budget is not None: - vllm_kwargs.setdefault("reasoning_config", ReasoningConfig()) if "qwen3" in model.lower(): + vllm_kwargs.setdefault( + "reasoning_config", + ReasoningConfig( + reasoning_start_str=VLLM_QWEN_REASONING_START_STR, + reasoning_end_str=VLLM_QWEN_REASONING_END_STR, + ), + ) vllm_kwargs.setdefault("reasoning_parser", "qwen3") + else: + vllm_kwargs.setdefault("reasoning_config", ReasoningConfig()) self._sampling_params_kwargs["thinking_token_budget"] = int( thinking_token_budget ) diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index 73304e6..5fc2d5e 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -17,7 +17,8 @@ def __init__(self, **kwargs): self.kwargs = kwargs class FakeReasoningConfig: - pass + def __init__(self, **kwargs): + captured["reasoning_config_kwargs"] = kwargs class FakeLLM: def __init__(self, *, model, trust_remote_code, **kwargs): @@ -74,6 +75,10 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 assert captured["structured_kwargs"]["json"]["type"] == "object" + assert captured["reasoning_config_kwargs"] == { + "reasoning_start_str": utils.VLLM_QWEN_REASONING_START_STR, + "reasoning_end_str": utils.VLLM_QWEN_REASONING_END_STR, + } llm_kwargs = captured["llm_init"]["kwargs"] assert llm_kwargs["reasoning_parser"] == "qwen3" assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index 6a921d5..3e37366 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -6,14 +6,26 @@ from judgearena.generate_and_evaluate import main as main_generate_and_eval -def test_build_pair_score_json_schema_covers_valid_range(): +def test_build_pair_score_json_schema_covers_valid_range_without_explanation(): schema = evaluate.build_pair_score_json_schema() assert schema["type"] == "object" - assert set(schema["required"]) == {"reasoning", "score_A", "score_B"} - assert schema["properties"]["reasoning"] == { + assert set(schema["required"]) == {"score_A", "score_B"} + for key in ("score_A", "score_B"): + assert schema["properties"][key]["type"] == "integer" + assert schema["properties"][key]["minimum"] == 0 + assert schema["properties"][key]["maximum"] == 10 + assert schema["additionalProperties"] is False + + +def test_build_pair_score_json_schema_covers_valid_range_with_explanation(): + schema = evaluate.build_pair_score_json_schema(include_explanation=True) + + assert schema["type"] == "object" + assert set(schema["required"]) == {"explanation", "score_A", "score_B"} + assert schema["properties"]["explanation"] == { "type": "string", - "maxLength": evaluate._PAIR_REASONING_MAX_CHARS, + "maxLength": evaluate._PAIR_EXPLANATION_MAX_CHARS, } for key in ("score_A", "score_B"): assert schema["properties"][key]["type"] == "integer" @@ -135,7 +147,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs generate_and_evaluate, "judge_and_parse_prefs", lambda **kwargs: ( - [{"judge_completion": '{"reasoning":"ok","score_A":1,"score_B":2}'}], + [{"judge_completion": '{"score_A":1,"score_B":2}'}], None, pd.Series([1.0]), ), @@ -156,7 +168,64 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs assert captured["make_model"]["kwargs"]["structured_outputs_json"] == ( evaluate.build_pair_score_json_schema() ) - assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 128 + assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 + + +def test_main_passes_explanation_schema_to_vllm_judge_when_requested( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, + ) + + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = kwargs + return object() + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + lambda **kwargs: ( + [{"judge_completion": '{"explanation":"ok","score_A":1,"score_B":2}'}], + None, + pd.Series([1.0]), + ), + ) + + prefs = main_generate_and_eval( + CliArgs( + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + provide_explanation=True, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [1.0] + assert captured["make_model"]["structured_outputs_json"] == ( + evaluate.build_pair_score_json_schema(include_explanation=True) + ) + assert captured["make_model"]["thinking_token_budget"] == 512 def test_annotate_battles_warns_when_judge_completions_are_truncated( diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index d416d82..bc6f724 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -193,45 +193,66 @@ def test_parse_fastchat_verdict_accepts_plain_structured_labels(): def test_parse_fastchat_verdict_accepts_json_and_strips_thinking(): assert ( fastchat_compat._parse_fastchat_verdict( - 'Need a longer chain.{"reasoning":"done","verdict":"B"}' + 'Need a longer chain.{"explanation":"done","verdict":"B"}' ) == "B" ) assert ( fastchat_compat._parse_fastchat_verdict( - '```json\n{"reasoning":"tie","verdict":"C"}\n```' + '```json\n{"explanation":"tie","verdict":"C"}\n```' ) == "tie" ) assert ( fastchat_compat._parse_fastchat_verdict( - 'unfinished analysis {"reasoning":"cut short","verdict":"A"' + 'unfinished analysis {"explanation":"cut short","verdict":"A"' ) == "A" ) -def test_conservative_winner_uses_non_error_side_when_only_one_parse_fails(): +def test_pair_v2_system_prompt_omits_explanation_when_disabled(): + rendered = fastchat_compat._PAIR_V2.render_system_prompt(provide_explanation=False) + + assert "provide a short explanation" not in rendered + assert 'exactly one key: "verdict"' in rendered + + +def test_conservative_winner_marks_one_sided_parse_failures_as_error(): assert fastchat_compat._conservative_winner("model_A", "error") == ( - "model_A", + "error", False, ) assert fastchat_compat._conservative_winner("error", "model_B") == ( - "model_B", + "error", False, ) assert fastchat_compat._conservative_winner("error", "error") == ("error", False) assert fastchat_compat._conservative_winner("model_A", "model_B") == ("tie", True) -def test_build_mt_bench_verdict_json_schema(): +def test_build_mt_bench_verdict_json_schema_without_explanation(): schema = mt_bench_utils.build_mt_bench_verdict_json_schema() assert schema["type"] == "object" - assert set(schema["required"]) == {"reasoning", "verdict"} - assert schema["properties"]["reasoning"] == { + assert set(schema["required"]) == {"verdict"} + assert schema["properties"] == { + "verdict": { + "type": "string", + "enum": ["A", "B", "C"], + } + } + assert schema["additionalProperties"] is False + + +def test_build_mt_bench_verdict_json_schema_with_explanation(): + schema = mt_bench_utils.build_mt_bench_verdict_json_schema(include_explanation=True) + + assert schema["type"] == "object" + assert set(schema["required"]) == {"explanation", "verdict"} + assert schema["properties"]["explanation"] == { "type": "string", - "maxLength": mt_bench_utils._MT_BENCH_REASONING_MAX_CHARS, + "maxLength": mt_bench_utils._MT_BENCH_EXPLANATION_MAX_CHARS, } assert schema["properties"]["verdict"] == { "type": "string", @@ -327,7 +348,7 @@ def fake_run_mt_bench_fastchat(**kwargs): "gpu_memory_utilization": 0.7, "language_model_only": True, "structured_outputs_json": mt_bench_utils.build_mt_bench_verdict_json_schema(), - "thinking_token_budget": 192, + "thinking_token_budget": 512, } assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" - assert captured["run_mt_bench_fastchat"]["constrained_plain_verdict"] is False + assert "constrained_plain_verdict" not in captured["run_mt_bench_fastchat"] diff --git a/tests/test_regexp.py b/tests/test_regexp.py index 74e9407..0f7a868 100644 --- a/tests/test_regexp.py +++ b/tests/test_regexp.py @@ -44,7 +44,7 @@ def test_pair_score_prefers_json_scores_over_reasoning_text(): raw_text = """ I would score assistant A as 2/10 if I stopped early. { - "reasoning": "At first glance I might score assistant A as 2, but after comparing both answers carefully, assistant B is better.", + "explanation": "At first glance I might score assistant A as 2, but after comparing both answers carefully, assistant B is better.", "score_A": 0, "score_B": 10 } From cb7ada5aee0f617e19dfde98b884448088205b4d Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 15 Apr 2026 16:46:06 +0200 Subject: [PATCH 11/28] Revert to free form generation, and use thinking token budget with regex stripping since the structured output wasn't working for isolating thinking tokens anyway --- README.md | 27 + judgearena/evaluate.py | 106 ++-- judgearena/generate.py | 36 +- judgearena/generate_and_evaluate.py | 32 +- judgearena/mt_bench/fastchat_compat.py | 60 +- judgearena/mt_bench/mt_bench_utils.py | 119 ++-- judgearena/openrouter_reference_pricing.py | 571 ++++++++++++++++++ .../prompts/prompt-with-explanation.txt | 11 +- judgearena/prompts/prompt.txt | 10 +- judgearena/repro.py | 3 + judgearena/utils.py | 120 ++-- tests/test_chat_vllm.py | 18 +- tests/test_local_completion_loading.py | 78 +-- tests/test_mt_bench_downloads.py | 84 +-- tests/test_openrouter_reference_pricing.py | 209 +++++++ tests/test_regexp.py | 39 +- 16 files changed, 1200 insertions(+), 323 deletions(-) create mode 100644 judgearena/openrouter_reference_pricing.py create mode 100644 tests/test_openrouter_reference_pricing.py diff --git a/README.md b/README.md index 508ac9f..c7840ce 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,25 @@ The evaluation scripts expose four different length controls with different role - `--max_out_tokens_judge`: generation token budget for the judge completion (reasoning + score output). - `--max_model_len`: optional vLLM context-window limit (prompt + generated tokens), applied to vLLM models; this should be greater than or equal to the two `max_out_tokens_*` values. +### OpenRouter Reference Pricing For Local Runs + +JudgeArena can estimate an `openrouter_reference_cost_usd` for local runs by combining: + +- locally counted prompt and completion tokens +- OpenRouter's public model pricing from `GET /api/v1/models` + +This is a reference price, not actual billed spend from either OpenRouter or your cluster. + +Reference pricing is only applied when the local model has an exact OpenRouter match using one of: + +- the OpenRouter model `id` +- the OpenRouter `canonical_slug` +- the model `hugging_face_id` + +If no exact match exists, JudgeArena still records token totals but leaves the reference price unset. + +The aggregated pricing summary is printed to stdout and stored in `run-metadata.v1.json` under `pricing_reference`. + ### Engine-Specific Configuration (`--engine_kwargs`) Some providers expose additional engine-level knobs (for example, vLLM allows configuring tensor parallelism or GPU memory utilization). @@ -273,6 +292,14 @@ Datasets are stored in: - `$JUDGEARENA_DATA` if set; otherwise `$OPENJURY_DATA` if set (legacy) - `~/judgearena-data/` if neither variable is set +If compute nodes do not have internet access, refresh the cached OpenRouter price book on the login node before launching jobs: + +```bash +uv run python -m judgearena.openrouter_reference_pricing --refresh +``` + +The benchmark launcher in `slurmpilot_scripts/launch_benchmark_eval.py` also attempts to warm this cache automatically. The cache is stored under `$JUDGEARENA_DATA/reference_pricing/openrouter_models.json` unless `JUDGEARENA_OPENROUTER_PRICE_CACHE` overrides it. + ## 🛠️ Development To maintain code quality, we use **pre-commit** hooks. Run this once to set them up: diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 984eebd..395aa13 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -14,14 +14,20 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( compute_pref_summary, data_root, do_inference, download_hf, - extract_json_object, + infer_model_spec_from_instance, read_df, + strip_thinking_tags, truncate, ) @@ -37,13 +43,7 @@ def preference_from_scores(self, score_a: float, score_b: float) -> float: ) def parse_model_raw(self, judge_completion: str) -> float | None: - json_payload = extract_json_object(judge_completion) - if json_payload is not None: - score_a = self._coerce_score(json_payload.get("score_A")) - score_b = self._coerce_score(json_payload.get("score_B")) - if score_a is not None and score_b is not None: - return float(self.preference_from_scores(score_a, score_b)) - + judge_completion = strip_thinking_tags(judge_completion) # lower case to avoid confusion, e.g. when "a" is used instead of "A" score_a = self.get_regexp_match( judge_completion.lower(), r'score.*?a[": *\n]*(-?\d+)' @@ -52,23 +52,18 @@ def parse_model_raw(self, judge_completion: str) -> float | None: judge_completion.lower(), r'score.*?b[": *\n]*(-?\d+)' ) if score_a is None or score_b is None: - return None + verdict_match = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion) + if verdict_match is None: + return None + bracketed_verdict = verdict_match.group(1).lower() + return { + "a": 0.0, + "b": 1.0, + "c": 0.5, + }[bracketed_verdict] else: return float(self.preference_from_scores(score_a, score_b)) - def _coerce_score(self, value: object) -> float | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return float(value) - if isinstance(value, float) and value.is_integer(): - return value - if isinstance(value, str): - match = re.fullmatch(r"\s*(-?\d+)\s*", value) - if match is not None: - return float(match.group(1)) - return None - def get_regexp_match(self, s: str, regex: str, group_index: int = 1): m = re.search(re.compile(regex), s) if m is None: @@ -77,41 +72,10 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): return float(m.group(group_index).strip(" ")) -_PAIR_SCORE_MIN = 0 -_PAIR_SCORE_MAX = 10 -_PAIR_EXPLANATION_MAX_CHARS = 384 - - -def build_pair_score_json_schema(*, include_explanation: bool = False) -> dict: - score_field = { - "type": "integer", - "minimum": _PAIR_SCORE_MIN, - "maximum": _PAIR_SCORE_MAX, - } - properties: dict[str, object] = { - "score_A": score_field, - "score_B": score_field, - } - required = ["score_A", "score_B"] - if include_explanation: - properties = { - "explanation": { - "type": "string", - "maxLength": _PAIR_EXPLANATION_MAX_CHARS, - }, - **properties, - } - required = ["explanation", *required] - return { - "type": "object", - "properties": properties, - "required": required, - "additionalProperties": False, - } - - _COMPLETION_LABEL_SINGLE = "Answer" _COMPLETION_LABEL_MULTI_TURN = "Conversation with User" +_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" +_SCORE_FENCE = "\n```" def load_judge_system_and_user_prompt( @@ -129,6 +93,10 @@ def load_judge_system_and_user_prompt( "{completion_label}", _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, ) + user_prompt_template = user_prompt_template.replace( + "{explanation_suffix}", + _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + ) return system_prompt, user_prompt_template @@ -230,6 +198,8 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): from langchain_together.llms import Together judge_chat_model = Together(model="meta-llama/Llama-3.3-70B-Instruct-Turbo") + judge_model_spec = infer_model_spec_from_instance(judge_chat_model) + usage_tracker = OpenRouterReferencePricingTracker() unique_string = dataset + "-" + datetime.now().strftime("%Y%m%d_%H%M%S") output_folder = data_root / "judge-evals" / unique_string @@ -250,6 +220,9 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): use_tqdm=use_tqdm, truncate_input_chars=truncate_input_chars, provide_explanation=provide_explanation, + usage_tracker=usage_tracker, + usage_phase="judge", + usage_model_spec=judge_model_spec, ) # Pairwise judge results @@ -266,6 +239,13 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): print(f"{method_A} against {method_B}:\n{results}") with open(output_folder / "results.json", "w") as f: json.dump(_to_jsonable(results), f, allow_nan=False) + pricing_reference = None + if judge_model_spec is not None: + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={"judge": judge_model_spec}, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) run_metadata = { "dataset": dataset, @@ -293,6 +273,7 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): judge_system_prompt=judge_system_prompt, judge_user_prompt_template=judge_user_prompt_template, started_at_utc=run_started_at, + pricing_reference=pricing_reference, ) except OSError as e: print(f"Warning: failed to write run metadata: {e}") @@ -317,6 +298,9 @@ def annotate_battles( truncate_input_chars: int | None = 8192, use_tqdm: bool = False, provide_explanation: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, ) -> list[JudgeAnnotation]: """ Directly evaluate from list of instructions and completions @@ -387,6 +371,9 @@ def annotate_battles( chat_model=judge_chat_model, inputs=inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) annotations = [] @@ -421,6 +408,9 @@ def judge_and_parse_prefs( user_prompt_template: str | None = None, truncate_input_chars: int = 8192, use_tqdm: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, ) -> tuple[list[JudgeAnnotation], list[JudgeAnnotation] | None, pd.Series]: """Run judge annotation and parse preferences, handling swap_mode='both'. @@ -446,6 +436,9 @@ def judge_and_parse_prefs( user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) annotations_reversed = None @@ -460,6 +453,9 @@ def judge_and_parse_prefs( user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) def _none_to_nan(x): diff --git a/judgearena/generate.py b/judgearena/generate.py index 5fe1666..9720fad 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -15,6 +15,8 @@ def generate_instructions( max_tokens: int | None = 32768, use_tqdm: bool = True, system_prompt: str | None = None, + usage_tracker=None, + usage_phase: str | None = None, **engine_kwargs, ) -> pd.DataFrame: chat_model = make_model(model, max_tokens=max_tokens, **engine_kwargs) @@ -41,6 +43,9 @@ def generate_instructions( chat_model=chat_model, inputs=inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, ) df_outputs = pd.DataFrame( data={ @@ -69,6 +74,8 @@ def _infer_grouped_by_temperature( inputs: list, temperatures: list[float], use_tqdm: bool, + usage_tracker=None, + usage_phase: str | None = None, ) -> list[str]: outputs: list[str] = [""] * len(inputs) groups: dict[float, list[int]] = {} @@ -91,6 +98,9 @@ def _infer_grouped_by_temperature( chat_model=group_model, inputs=group_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model_spec, ) for i, out in zip(idxs, group_outs, strict=True): outputs[i] = out @@ -105,6 +115,8 @@ def generate_multiturn( max_tokens: int | None = 8192, use_tqdm: bool = True, temperature_config: dict[str, float] | None = None, + usage_tracker=None, + usage_phase: str | None = None, **model_kwargs, ) -> pd.DataFrame: """Generate two-turn completions for MT-Bench style questions.""" @@ -148,12 +160,17 @@ def generate_multiturn( inputs=turn1_inputs, temperatures=temperatures, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, ) else: completions_turn_1 = do_inference( chat_model=chat_model, inputs=turn1_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, ) turn2_inputs = [] @@ -195,12 +212,17 @@ def generate_multiturn( inputs=turn2_inputs, temperatures=temperatures, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, ) else: completions_turn_2 = do_inference( chat_model=chat_model, inputs=turn2_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model, ) return pd.DataFrame( @@ -218,20 +240,26 @@ def generate_base( truncate_input_chars: int | None = 8192, max_tokens: int | None = 32768, use_tqdm: bool = False, + usage_tracker=None, + usage_phase: str | None = None, **engine_kwargs, ) -> pd.DataFrame: - model = make_model(model, max_tokens=max_tokens, **engine_kwargs) + model_spec = model + model = make_model(model_spec, max_tokens=max_tokens, **engine_kwargs) inputs = [ truncate(instruction, max_len=truncate_input_chars) for instruction in instructions ] - completions = model.batch( + completions = do_inference( + chat_model=model, inputs=inputs, - max_tokens=max_tokens, + use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=model_spec, ) - completions = [x.content if hasattr(x, "content") else x for x in completions] df_outputs = pd.DataFrame( data={ diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 109c648..df58fa3 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -13,11 +13,7 @@ import pandas as pd from judgearena.cli_common import BaseCliArgs, add_common_arguments, parse_engine_kwargs -from judgearena.evaluate import ( - build_pair_score_json_schema, - judge_and_parse_prefs, - resolve_judge_prompts, -) +from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions from judgearena.instruction_dataset.arena_hard import ( @@ -25,6 +21,11 @@ is_arena_hard_dataset, ) from judgearena.mt_bench.mt_bench_utils import run_mt_bench +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, @@ -196,6 +197,7 @@ def main(args: CliArgs): """ run_started_at = datetime.now(UTC) + usage_tracker = OpenRouterReferencePricingTracker() print( f"Using dataset {args.dataset} and evaluating models {args.model_A} and {args.model_B}." ) @@ -264,6 +266,8 @@ def main(args: CliArgs): instructions=instructions, model=args.model_A, use_tqdm=args.use_tqdm, + usage_tracker=usage_tracker, + usage_phase="generation_model_A", ), ignore_cache=ignore_cache, cache_name=f"{args.dataset}_{args.model_A}_{args.n_instructions}", @@ -283,6 +287,8 @@ def main(args: CliArgs): instructions=instructions, model=args.model_B, use_tqdm=args.use_tqdm, + usage_tracker=usage_tracker, + usage_phase="generation_model_B", ), ignore_cache=ignore_cache, cache_name=f"{args.dataset}_{args.model_B}_{args.n_instructions}", @@ -298,9 +304,6 @@ def main(args: CliArgs): judge_model_kwargs = dict(args.engine_kwargs) if args.judge_model.split("/")[0] == "VLLM": - judge_model_kwargs["structured_outputs_json"] = build_pair_score_json_schema( - include_explanation=args.provide_explanation - ) judge_model_kwargs.setdefault( "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET ) @@ -353,6 +356,9 @@ def main(args: CliArgs): user_prompt_template=judge_user_prompt_template, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, + usage_tracker=usage_tracker, + usage_phase="judge", + usage_model_spec=args.judge_model, ) df = pd.DataFrame(annotations) @@ -386,6 +392,15 @@ def main(args: CliArgs): } print(f"{args.model_A} vs {args.model_B} judged by {args.judge_model}") print_results(results) + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={ + "generation_model_A": args.model_A, + "generation_model_B": args.model_B, + "judge": args.judge_model, + }, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) with open(res_folder / f"results-{name}.json", "w") as f: json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) @@ -410,6 +425,7 @@ def main(args: CliArgs): judge_system_prompt=effective_judge_system_prompt, judge_user_prompt_template=judge_user_prompt_template, started_at_utc=run_started_at, + pricing_reference=pricing_reference, ) except OSError as e: print(f"Warning: failed to write run metadata: {e}") diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 41b1df6..18fcd5d 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -12,7 +12,8 @@ from langchain_core.prompts import ChatPromptTemplate from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows -from judgearena.utils import do_inference, extract_json_object, strip_thinking_tags +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker +from judgearena.utils import do_inference, strip_thinking_tags FASTCHAT_TEMPERATURE_CONFIG: dict[str, float] = { "writing": 0.7, @@ -65,8 +66,6 @@ def render_system_prompt(self, *, provide_explanation: bool) -> str: _USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" _USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" -_PARTIAL_JSON_VERDICT_RE = re.compile(r'"verdict"\s*:\s*"(?P[ABC])"') - def _load_prompt_text(filename: str) -> str: path = _PROMPTS_DIR / filename @@ -95,14 +94,13 @@ def _build_system_prompt( ) -> str: focus_segment = f"{focus_line} " if focus_line else "" output_format_instruction = ( - 'Output your response as valid JSON with exactly two keys: "explanation" ' - 'for a concise rationale under 300 characters and "verdict" with exactly ' - 'one of "A", "B", or "C", where "A" means assistant A is better, "B" ' - 'means assistant B is better, and "C" means a tie.' + "After providing your explanation, output your final verdict by strictly " + 'following this format: "[[A]]" if assistant A is better, "[[B]]" if ' + 'assistant B is better, and "[[C]]" for a tie.' if provide_explanation - else 'Output your response as valid JSON with exactly one key: "verdict" ' - 'with exactly one of "A", "B", or "C", where "A" means assistant A is ' - 'better, "B" means assistant B is better, and "C" means a tie.' + else "Output your final verdict by strictly following this format: " + '"[[A]]" if assistant A is better, "[[B]]" if assistant B is better, ' + 'and "[[C]]" for a tie.' ) return _render_prompt_text( _SYSTEM_BASE_FILE, @@ -215,32 +213,12 @@ def _load_pairwise_prompt( def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: - json_payload = extract_json_object(judgment) - if json_payload is not None: - verdict = json_payload.get("verdict") - if isinstance(verdict, str): - normalized = verdict.strip().upper() - if normalized == "A": - return "A" - if normalized == "B": - return "B" - if normalized in {"C", "TIE"}: - return "tie" - - partial_json_match = _PARTIAL_JSON_VERDICT_RE.search(judgment) - if partial_json_match is not None: - if partial_json_match.group("verdict") == "A": - return "A" - if partial_json_match.group("verdict") == "B": - return "B" - return "tie" - stripped = strip_thinking_tags(judgment).strip() - if "[[A]]" in stripped or stripped == "A": + if "[[A]]" in stripped: return "A" - if "[[B]]" in stripped or stripped == "B": + if "[[B]]" in stripped: return "B" - if "[[C]]" in stripped or stripped == "C": + if "[[C]]" in stripped: return "tie" return "error" @@ -324,6 +302,9 @@ def _infer_by_prompt_groups( use_tqdm: bool, swap_answers: bool, provide_explanation: bool, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, ) -> list[str]: """Run judge inference, grouping by prompt variant for batching.""" grouped_indices = _group_indices_by_prompt(items) @@ -350,6 +331,9 @@ def _infer_by_prompt_groups( chat_model=judge_chat_model, inputs=prompt_inputs, use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, ) for i, out in zip(idxs, outs, strict=True): judgments[i] = str(out) @@ -497,8 +481,10 @@ def judge_mt_bench_pairwise_fastchat( truncate_input_chars: int | None, use_tqdm: bool, provide_explanation: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: - """Run FastChat-style MT-Bench pairwise judging with JSON verdict outputs.""" + """Run FastChat-style MT-Bench pairwise judging with bracketed verdict outputs.""" assert turns_mode in ("both", "single", "multi") assert swap_mode in ("fixed", "both") @@ -520,6 +506,9 @@ def judge_mt_bench_pairwise_fastchat( use_tqdm=use_tqdm, swap_answers=False, provide_explanation=provide_explanation, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, ) g2_judgments: list[str] | None = None @@ -530,6 +519,9 @@ def judge_mt_bench_pairwise_fastchat( use_tqdm=use_tqdm, swap_answers=True, provide_explanation=provide_explanation, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, ) annotations: list[dict[str, Any]] = [] diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 846077d..2a9d290 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -9,7 +9,7 @@ import json import os from dataclasses import asdict -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING @@ -23,7 +23,12 @@ FASTCHAT_TEMPERATURE_CONFIG, judge_mt_bench_pairwise_fastchat, ) -from judgearena.repro import _to_jsonable +from judgearena.openrouter_reference_pricing import ( + OpenRouterReferencePricingTracker, + build_openrouter_reference_pricing_summary, + format_openrouter_reference_pricing_summary, +) +from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, cache_function_dataframe, @@ -35,34 +40,10 @@ from judgearena.generate_and_evaluate import CliArgs -# MT-Bench judge prompts need headroom for budgeted thinking and the final JSON. -_MIN_MT_BENCH_JUDGE_TOKENS = 2048 -_MT_BENCH_EXPLANATION_MAX_CHARS = 384 - - -def build_mt_bench_verdict_json_schema( - *, include_explanation: bool = False -) -> dict[str, object]: - """Return the MT-Bench judge schema for verdict-only or verdict+explanation.""" - properties: dict[str, object] = { - "verdict": {"type": "string", "enum": ["A", "B", "C"]}, - } - required = ["verdict"] - if include_explanation: - properties = { - "explanation": { - "type": "string", - "maxLength": _MT_BENCH_EXPLANATION_MAX_CHARS, - }, - **properties, - } - required = ["explanation", *required] - return { - "type": "object", - "properties": properties, - "required": required, - "additionalProperties": False, - } +# Original MT-Bench prompts include a visible explanation before the final verdict, +# and Qwen can spend thousands of visible tokens after reasoning ends on turn 2. +_MIN_MT_BENCH_JUDGE_TOKENS = 24576 +_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN = 28672 def _align_mt_bench_completions( @@ -84,11 +65,12 @@ def _generate_mt_bench_completions( args: CliArgs, questions_df: pd.DataFrame, ignore_cache: bool, + usage_tracker: OpenRouterReferencePricingTracker, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Load baseline MT-Bench answers or generate fresh multi-turn outputs.""" cache_prefix = "mt-bench" - def _run_generation(model_name: str) -> pd.DataFrame: + def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: return generate_multiturn( questions=questions_df, model=model_name, @@ -98,10 +80,12 @@ def _run_generation(model_name: str) -> pd.DataFrame: max_model_len=args.max_model_len, chat_template=args.chat_template, temperature_config=FASTCHAT_TEMPERATURE_CONFIG, + usage_tracker=usage_tracker, + usage_phase=usage_phase, **args.engine_kwargs, ) - def _load_or_generate(model_name: str) -> pd.DataFrame: + def _load_or_generate(model_name: str, usage_phase: str) -> pd.DataFrame: loaded_answers = load_mt_bench_model_answers( model_name, n_instructions=args.n_instructions ) @@ -113,7 +97,7 @@ def _load_or_generate(model_name: str) -> pd.DataFrame: model_name=model_name, ) generated_answers = cache_function_dataframe( - lambda: _run_generation(model_name), + lambda: _run_generation(model_name, usage_phase), ignore_cache=ignore_cache, cache_name=f"{cache_prefix}_{model_name}_{args.n_instructions}", ) @@ -123,8 +107,8 @@ def _load_or_generate(model_name: str) -> pd.DataFrame: model_name=model_name, ) - completions_a = _load_or_generate(args.model_A) - completions_b = _load_or_generate(args.model_B) + completions_a = _load_or_generate(args.model_A, "generation_model_A") + completions_b = _load_or_generate(args.model_B, "generation_model_B") return completions_a, completions_b @@ -142,6 +126,9 @@ def _save_mt_bench_results( args: CliArgs, results: dict[str, object], annotations_df: pd.DataFrame, + questions_df: pd.DataFrame, + pricing_reference: dict[str, object] | None, + started_at_utc: datetime, name_suffix: str | None = None, ) -> None: """Persist MT-Bench arguments, annotations, and aggregate results.""" @@ -157,6 +144,20 @@ def _save_mt_bench_results( with open(res_folder / f"results-{name}.json", "w") as f: json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) + write_run_metadata( + output_dir=res_folder, + entrypoint="judgearena.mt_bench.mt_bench_utils.run_mt_bench", + run=asdict(args), + results=results, + input_payloads={ + "instruction_index": questions_df.index.tolist(), + "turn_1": questions_df["turn_1"].tolist(), + "turn_2": questions_df["turn_2"].tolist(), + }, + started_at_utc=started_at_utc, + pricing_reference=pricing_reference, + ) + def _run_mt_bench_fastchat( *, @@ -165,6 +166,9 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, + provide_explanation: bool, + usage_tracker: OpenRouterReferencePricingTracker, + started_at_utc: datetime, ) -> pd.Series: """Run FastChat-style MT-Bench judging and save the resulting artifacts.""" prefs, annotations, combined_metadata, num_inconsistent = ( @@ -180,7 +184,9 @@ def _run_mt_bench_fastchat( swap_mode=args.swap_mode, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, - provide_explanation=args.provide_explanation, + provide_explanation=provide_explanation, + usage_tracker=usage_tracker, + usage_phase="judge", ) ) @@ -199,10 +205,22 @@ def _run_mt_bench_fastchat( "user": os.getenv("USER", ""), } print_results(results) + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={ + "generation_model_A": args.model_A, + "generation_model_B": args.model_B, + "judge": args.judge_model, + }, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) _save_mt_bench_results( args=args, results=results, annotations_df=pd.DataFrame(annotations), + questions_df=questions_df, + pricing_reference=pricing_reference, + started_at_utc=started_at_utc, name_suffix="mtbench", ) return prefs @@ -210,6 +228,13 @@ def _run_mt_bench_fastchat( def run_mt_bench(args: CliArgs, ignore_cache: bool): """MT-Bench pipeline with FastChat-compatible pairwise judging.""" + run_started_at = datetime.now(UTC) + usage_tracker = OpenRouterReferencePricingTracker() + if not args.provide_explanation: + print( + "MT-Bench ignores provide_explanation=False and keeps the original " + "FastChat-style explanation-plus-verdict prompt." + ) if args.swap_mode != "both": print( "MT-Bench requires swap_mode='both' to match FastChat and correct " @@ -218,8 +243,8 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): args.swap_mode = "both" if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: print( - "MT-Bench judge prompts require room for budgeted thinking and the " - "final JSON verdict; " + "MT-Bench judge prompts require room for budgeted thinking, the " + "original explanation, and the final verdict; " f"overriding max_out_tokens_judge from {args.max_out_tokens_judge} " f"to {_MIN_MT_BENCH_JUDGE_TOKENS}." ) @@ -232,11 +257,20 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): args=args, questions_df=questions_df, ignore_cache=ignore_cache, + usage_tracker=usage_tracker, ) + if ( + args.max_model_len is not None + and args.max_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + ): + print( + "MT-Bench judge prompts require a larger total context window for " + "prompt plus completion; " + f"overriding max_model_len from {args.max_model_len} " + f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN} for the judge." + ) + args.max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN judge_model_kwargs = dict(args.engine_kwargs) - judge_model_kwargs["structured_outputs_json"] = build_mt_bench_verdict_json_schema( - include_explanation=args.provide_explanation - ) judge_model_kwargs.setdefault( "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET ) @@ -255,4 +289,7 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, + provide_explanation=True, + usage_tracker=usage_tracker, + started_at_utc=run_started_at, ) diff --git a/judgearena/openrouter_reference_pricing.py b/judgearena/openrouter_reference_pricing.py new file mode 100644 index 0000000..0463f14 --- /dev/null +++ b/judgearena/openrouter_reference_pricing.py @@ -0,0 +1,571 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import urllib.error +import urllib.request +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models" +OPENROUTER_PRICE_CACHE_ENV = "JUDGEARENA_OPENROUTER_PRICE_CACHE" +DEFAULT_CACHE_RELATIVE_PATH = Path("reference_pricing") / "openrouter_models.json" +_KNOWN_PROVIDER_PREFIXES = frozenset( + { + "ChatOpenAI", + "Dummy", + "LlamaCpp", + "OpenAI", + "OpenRouter", + "Together", + "VLLM", + } +) +_UNAPPLIED_PRICE_COMPONENTS = ( + "image", + "input_cache_read", + "input_cache_write", + "internal_reasoning", + "web_search", +) +_LOCAL_VARIANT_SUFFIX_RE = re.compile( + r"(?i)(?:[-_](?:fp8|fp16|bf16|int8|int4|int3|awq|gptq(?:[-_][a-z0-9]+)*))+$" +) + + +def _data_root_path() -> Path: + raw = os.environ.get("JUDGEARENA_DATA") or os.environ.get("OPENJURY_DATA") + if raw: + return Path(raw).expanduser() + return Path("~/judgearena-data/").expanduser() + + +def get_openrouter_price_cache_path() -> Path: + raw = os.environ.get(OPENROUTER_PRICE_CACHE_ENV) + if raw: + return Path(raw).expanduser() + return _data_root_path() / DEFAULT_CACHE_RELATIVE_PATH + + +def _utc_now_iso() -> str: + return datetime.now(UTC).isoformat() + + +def _as_price_float(raw_value: object) -> float: + if raw_value in (None, ""): + return 0.0 + return float(raw_value) + + +@dataclass(frozen=True) +class OpenRouterModelPricing: + prompt: float + completion: float + request: float = 0.0 + image: float = 0.0 + web_search: float = 0.0 + internal_reasoning: float = 0.0 + input_cache_read: float = 0.0 + input_cache_write: float = 0.0 + + +@dataclass(frozen=True) +class OpenRouterModelEntry: + model_id: str + canonical_slug: str | None + hugging_face_id: str | None + name: str + pricing: OpenRouterModelPricing + + def exact_match_candidates(self) -> tuple[str, ...]: + candidates = [self.model_id] + if self.canonical_slug: + candidates.append(self.canonical_slug) + if self.hugging_face_id: + candidates.append(self.hugging_face_id) + return tuple(candidates) + + +@dataclass(frozen=True) +class OpenRouterPriceCatalog: + source_url: str + fetched_at_utc: str | None + cache_path: str + models: tuple[OpenRouterModelEntry, ...] + + +@dataclass(frozen=True) +class TokenUsageRecord: + phase: str + model_spec: str + prompt_tokens: int + completion_tokens: int + requests: int = 1 + + +class OpenRouterReferencePricingTracker: + def __init__(self) -> None: + self._records: list[TokenUsageRecord] = [] + + @property + def records(self) -> list[TokenUsageRecord]: + return list(self._records) + + def has_records(self) -> bool: + return bool(self._records) + + def record_batch_from_model( + self, + *, + phase: str, + model_spec: str, + chat_model: object, + inputs: list, + outputs: list[str], + ) -> bool: + if not hasattr(chat_model, "count_prompt_tokens_batch") or not hasattr( + chat_model, "count_completion_tokens_batch" + ): + return False + + prompt_tokens = chat_model.count_prompt_tokens_batch(inputs) + completion_tokens = chat_model.count_completion_tokens_batch(outputs) + if len(prompt_tokens) != len(completion_tokens) or len(prompt_tokens) != len( + outputs + ): + raise ValueError("Prompt/completion token counts must align with outputs.") + + for prompt_count, completion_count in zip( + prompt_tokens, completion_tokens, strict=True + ): + self._records.append( + TokenUsageRecord( + phase=phase, + model_spec=model_spec, + prompt_tokens=int(prompt_count), + completion_tokens=int(completion_count), + ) + ) + return True + + +def _parse_catalog_model(raw_model: dict[str, Any]) -> OpenRouterModelEntry: + raw_pricing = raw_model.get("pricing") or {} + pricing = OpenRouterModelPricing( + prompt=_as_price_float(raw_pricing.get("prompt")), + completion=_as_price_float(raw_pricing.get("completion")), + request=_as_price_float(raw_pricing.get("request")), + image=_as_price_float(raw_pricing.get("image")), + web_search=_as_price_float(raw_pricing.get("web_search")), + internal_reasoning=_as_price_float(raw_pricing.get("internal_reasoning")), + input_cache_read=_as_price_float(raw_pricing.get("input_cache_read")), + input_cache_write=_as_price_float(raw_pricing.get("input_cache_write")), + ) + return OpenRouterModelEntry( + model_id=str(raw_model["id"]), + canonical_slug=( + str(raw_model["canonical_slug"]) + if raw_model.get("canonical_slug") is not None + else None + ), + hugging_face_id=( + str(raw_model["hugging_face_id"]) + if raw_model.get("hugging_face_id") is not None + else None + ), + name=str(raw_model.get("name") or raw_model["id"]), + pricing=pricing, + ) + + +def parse_openrouter_catalog_payload( + payload: dict[str, Any], + *, + fetched_at_utc: str | None = None, + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + raw_models = payload.get("models") + if raw_models is None: + raw_models = payload.get("data") + if not isinstance(raw_models, list): + raise ValueError("OpenRouter models payload is missing a `data` list.") + return OpenRouterPriceCatalog( + source_url=str(payload.get("source_url") or OPENROUTER_MODELS_URL), + fetched_at_utc=( + str(payload["fetched_at_utc"]) + if payload.get("fetched_at_utc") is not None + else fetched_at_utc + ), + cache_path=str(cache_path or payload.get("cache_path") or ""), + models=tuple(_parse_catalog_model(model) for model in raw_models), + ) + + +def _cache_payload_from_raw_response(raw_payload: dict[str, Any]) -> dict[str, Any]: + return { + "source_url": OPENROUTER_MODELS_URL, + "fetched_at_utc": _utc_now_iso(), + "models": raw_payload.get("data", []), + } + + +def _fetch_openrouter_catalog_payload(timeout_seconds: float = 30.0) -> dict[str, Any]: + request = urllib.request.Request(OPENROUTER_MODELS_URL) + api_key = os.environ.get("OPENROUTER_API_KEY") + if api_key: + request.add_header("Authorization", f"Bearer {api_key}") + with urllib.request.urlopen(request, timeout=timeout_seconds) as response: + return json.loads(response.read().decode("utf-8")) + + +def load_openrouter_price_catalog( + *, + refresh: bool = False, + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + resolved_cache_path = ( + Path(cache_path) + if cache_path is not None + else get_openrouter_price_cache_path() + ) + resolved_cache_path.parent.mkdir(parents=True, exist_ok=True) + if refresh or not resolved_cache_path.is_file(): + fetched_payload = _fetch_openrouter_catalog_payload() + cache_payload = _cache_payload_from_raw_response(fetched_payload) + with open(resolved_cache_path, "w", encoding="utf-8") as handle: + json.dump(cache_payload, handle, indent=2, sort_keys=True) + with open(resolved_cache_path, encoding="utf-8") as handle: + cached_payload = json.load(handle) + return parse_openrouter_catalog_payload( + cached_payload, + cache_path=resolved_cache_path, + ) + + +def load_openrouter_price_catalog_with_fallback( + *, + refresh: bool = False, + cache_path: str | Path | None = None, +) -> tuple[OpenRouterPriceCatalog | None, str | None]: + resolved_cache_path = ( + Path(cache_path) + if cache_path is not None + else get_openrouter_price_cache_path() + ) + try: + catalog = load_openrouter_price_catalog( + refresh=refresh, + cache_path=resolved_cache_path, + ) + except (OSError, ValueError, json.JSONDecodeError, urllib.error.URLError) as exc: + if resolved_cache_path.is_file(): + try: + catalog = load_openrouter_price_catalog( + refresh=False, + cache_path=resolved_cache_path, + ) + return ( + catalog, + f"Using cached OpenRouter price catalog after refresh failed: {exc}", + ) + except (OSError, ValueError, json.JSONDecodeError) as cached_exc: + return None, ( + "OpenRouter price catalog refresh and cache load failed: " + f"{cached_exc}" + ) + return None, f"OpenRouter price catalog unavailable: {exc}" + return catalog, None + + +def _strip_provider_prefix(model_spec: str) -> str | None: + if not model_spec: + return None + if "/" not in model_spec: + return model_spec + provider, remainder = model_spec.split("/", 1) + if provider in _KNOWN_PROVIDER_PREFIXES: + return remainder + return model_spec + + +def _candidate_match_variants(candidate: str) -> tuple[str, ...]: + variants = [candidate] + owner_prefix = "" + model_name = candidate + if "/" in candidate: + owner, model_name = candidate.rsplit("/", 1) + owner_prefix = f"{owner}/" + normalized_model_name = _LOCAL_VARIANT_SUFFIX_RE.sub("", model_name) + if normalized_model_name and normalized_model_name != model_name: + variants.append(f"{owner_prefix}{normalized_model_name}") + return tuple(dict.fromkeys(variants)) + + +def find_openrouter_match( + catalog: OpenRouterPriceCatalog, + model_spec: str, +) -> tuple[OpenRouterModelEntry | None, str | None]: + candidate = _strip_provider_prefix(model_spec) + if not candidate: + return None, None + candidate_variants = _candidate_match_variants(candidate) + lowered_exact = candidate_variants[0].casefold() + lowered_normalized = {variant.casefold() for variant in candidate_variants[1:]} + for model in catalog.models: + lowered_candidates = { + match_candidate.casefold() + for match_candidate in model.exact_match_candidates() + } + if lowered_exact in lowered_candidates: + return model, "exact_case_insensitive" + if lowered_normalized.intersection(lowered_candidates): + return model, "local_variant_suffix_stripped" + return None, None + + +def find_exact_openrouter_match( + catalog: OpenRouterPriceCatalog, + model_spec: str, +) -> OpenRouterModelEntry | None: + matched_model, match_strategy = find_openrouter_match(catalog, model_spec) + if match_strategy == "exact_case_insensitive": + return matched_model + return None + + +def _sum_phase_records(records: list[TokenUsageRecord]) -> dict[str, int]: + prompt_tokens = sum(record.prompt_tokens for record in records) + completion_tokens = sum(record.completion_tokens for record in records) + requests = sum(record.requests for record in records) + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + "request_count": requests, + } + + +def _ignored_pricing_components(pricing: OpenRouterModelPricing) -> list[str]: + ignored: list[str] = [] + for field_name in _UNAPPLIED_PRICE_COMPONENTS: + if getattr(pricing, field_name) != 0.0: + ignored.append(field_name) + return ignored + + +def _phase_summary_for_unmatched( + *, + model_spec: str, + usage_totals: dict[str, int], + pricing_status: str, +) -> dict[str, Any]: + return { + "model_spec": model_spec, + "pricing_status": pricing_status, + **usage_totals, + "openrouter_model_id": None, + "openrouter_canonical_slug": None, + "openrouter_hugging_face_id": None, + "openrouter_reference_cost_usd": None, + "applied_pricing_usd": None, + "ignored_pricing_components": [], + } + + +def build_openrouter_reference_pricing_summary( + *, + tracker: OpenRouterReferencePricingTracker, + phase_model_specs: dict[str, str], + refresh_catalog: bool = False, + cache_path: str | Path | None = None, +) -> dict[str, Any]: + phase_records: dict[str, list[TokenUsageRecord]] = { + phase: [record for record in tracker.records if record.phase == phase] + for phase in phase_model_specs + } + should_load_catalog = any(phase_records.values()) + catalog: OpenRouterPriceCatalog | None = None + catalog_warning: str | None = None + if should_load_catalog: + catalog, catalog_warning = load_openrouter_price_catalog_with_fallback( + refresh=refresh_catalog, + cache_path=cache_path, + ) + + phase_summaries: dict[str, dict[str, Any]] = {} + priced_costs: list[float] = [] + for phase, model_spec in phase_model_specs.items(): + records = phase_records[phase] + usage_totals = _sum_phase_records(records) + if not records: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="no_runtime_token_data", + ) + continue + if catalog is None: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="price_catalog_unavailable", + ) + continue + + matched_model, match_strategy = find_openrouter_match(catalog, model_spec) + if matched_model is None: + phase_summaries[phase] = _phase_summary_for_unmatched( + model_spec=model_spec, + usage_totals=usage_totals, + pricing_status="no_exact_openrouter_match", + ) + continue + + ignored_components = _ignored_pricing_components(matched_model.pricing) + phase_cost = ( + usage_totals["prompt_tokens"] * matched_model.pricing.prompt + + usage_totals["completion_tokens"] * matched_model.pricing.completion + + usage_totals["request_count"] * matched_model.pricing.request + ) + priced_costs.append(phase_cost) + if match_strategy == "local_variant_suffix_stripped": + base_status = "matched_openrouter_model_after_variant_normalization" + else: + base_status = "matched_exact_openrouter_model" + phase_summaries[phase] = { + "model_spec": model_spec, + "pricing_status": ( + base_status if not ignored_components else f"{base_status}_partial" + ), + **usage_totals, + "openrouter_model_id": matched_model.model_id, + "openrouter_canonical_slug": matched_model.canonical_slug, + "openrouter_hugging_face_id": matched_model.hugging_face_id, + "openrouter_reference_cost_usd": phase_cost, + "applied_pricing_usd": { + "prompt": matched_model.pricing.prompt, + "completion": matched_model.pricing.completion, + "request": matched_model.pricing.request, + }, + "ignored_pricing_components": ignored_components, + } + + total_prompt_tokens = sum( + phase_summary["prompt_tokens"] for phase_summary in phase_summaries.values() + ) + total_completion_tokens = sum( + phase_summary["completion_tokens"] for phase_summary in phase_summaries.values() + ) + total_request_count = sum( + phase_summary["request_count"] for phase_summary in phase_summaries.values() + ) + total_reference_cost = sum(priced_costs) if priced_costs else None + + return { + "pricing_model": "openrouter_reference", + "pricing_currency": "USD", + "catalog_source_url": OPENROUTER_MODELS_URL, + "catalog_cache_path": str( + cache_path if cache_path is not None else get_openrouter_price_cache_path() + ), + "catalog_fetched_at_utc": catalog.fetched_at_utc if catalog else None, + "catalog_warning": catalog_warning, + "exact_match_policy": { + "strategy": "exact_case_insensitive", + "match_fields": ["id", "canonical_slug", "hugging_face_id"], + "fallback_normalizations": ["strip_common_local_quantization_suffixes"], + }, + "phases": phase_summaries, + "total": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + "request_count": total_request_count, + "openrouter_reference_cost_usd": total_reference_cost, + }, + } + + +def format_openrouter_reference_pricing_summary(summary: dict[str, Any]) -> str: + lines = ["OpenRouter reference pricing:"] + for phase, phase_summary in summary["phases"].items(): + phase_cost = phase_summary.get("openrouter_reference_cost_usd") + cost_str = f" | usd={phase_cost:.6f}" if phase_cost is not None else "" + lines.append( + " " + + f"{phase}: status={phase_summary['pricing_status']}" + + f" | prompt={phase_summary['prompt_tokens']}" + + f" | completion={phase_summary['completion_tokens']}" + + f" | total={phase_summary['total_tokens']}" + + cost_str + ) + total = summary["total"] + total_cost = total.get("openrouter_reference_cost_usd") + total_cost_str = f" | usd={total_cost:.6f}" if total_cost is not None else "" + lines.append( + " total:" + + f" prompt={total['prompt_tokens']}" + + f" | completion={total['completion_tokens']}" + + f" | total={total['total_tokens']}" + + total_cost_str + ) + warning = summary.get("catalog_warning") + if warning: + lines.append(f" warning: {warning}") + return "\n".join(lines) + + +def refresh_openrouter_price_catalog( + cache_path: str | Path | None = None, +) -> OpenRouterPriceCatalog: + return load_openrouter_price_catalog(refresh=True, cache_path=cache_path) + + +def _build_cli_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m judgearena.openrouter_reference_pricing", + description="Refresh or inspect the cached OpenRouter model pricing catalog.", + ) + parser.add_argument( + "--refresh", + action="store_true", + help="Force-refresh the cached OpenRouter models catalog.", + ) + parser.add_argument( + "--model", + default=None, + help="Optional local model spec to resolve against the cached catalog.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + args = _build_cli_parser().parse_args(argv) + catalog = load_openrouter_price_catalog(refresh=args.refresh) + print( + json.dumps( + { + "catalog_source_url": catalog.source_url, + "catalog_fetched_at_utc": catalog.fetched_at_utc, + "catalog_cache_path": catalog.cache_path, + "model_count": len(catalog.models), + "matched_model": ( + asdict(find_exact_openrouter_match(catalog, args.model)) + if args.model + and find_exact_openrouter_match(catalog, args.model) is not None + else None + ), + }, + indent=2, + sort_keys=True, + ) + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/judgearena/prompts/prompt-with-explanation.txt b/judgearena/prompts/prompt-with-explanation.txt index 804e766..3d9eb41 100644 --- a/judgearena/prompts/prompt-with-explanation.txt +++ b/judgearena/prompts/prompt-with-explanation.txt @@ -12,13 +12,10 @@ # Your output ## Format description -Your output should be valid JSON with exactly these keys: +Your output should follow this format: ``` -{{ - "explanation": "", - "score_A": , - "score_B": -}} +score_A: +score_B: ``` -## Your output, do not repeat the input above +## Your output, do not repeat the input above, first starts with an explanation of your judgement diff --git a/judgearena/prompts/prompt.txt b/judgearena/prompts/prompt.txt index ce7ee4e..38021e6 100644 --- a/judgearena/prompts/prompt.txt +++ b/judgearena/prompts/prompt.txt @@ -12,12 +12,10 @@ # Your output ## Format description -Your output should be valid JSON with exactly these keys: +Your output should follow this format: ``` -{{ - "score_A": , - "score_B": -}} +score_A: +score_B: ``` -## Your output, do not repeat the input above +## Your output, do not repeat the input above{explanation_suffix} diff --git a/judgearena/repro.py b/judgearena/repro.py index 9468c14..72b0059 100644 --- a/judgearena/repro.py +++ b/judgearena/repro.py @@ -231,6 +231,7 @@ def write_run_metadata( judge_system_prompt: str | None = None, judge_user_prompt_template: str | None = None, started_at_utc: datetime | None = None, + pricing_reference: dict[str, Any] | None = None, metadata_filename: str = METADATA_FILENAME, ) -> Path: """Write run metadata JSON and return the output path.""" @@ -282,6 +283,8 @@ def write_run_metadata( judge_user_prompt_template_hash = _hash_string_sha256(judge_user_prompt_template) if judge_user_prompt_template_hash: metadata["judge_user_prompt_template_sha256"] = judge_user_prompt_template_hash + if pricing_reference is not None: + metadata["pricing_reference"] = _to_jsonable(pricing_reference) metadata["artifacts"] = _collect_artifacts( output_path, metadata_filename=metadata_filename diff --git a/judgearena/utils.py b/judgearena/utils.py index c3b2ee1..1026961 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -1,12 +1,10 @@ import asyncio -import json import os import re import time import warnings from collections.abc import Callable from pathlib import Path -from typing import Any import pandas as pd from huggingface_hub import snapshot_download @@ -20,6 +18,7 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker def _data_root_path() -> Path: @@ -133,10 +132,6 @@ def safe_text(value: object, truncate_chars: int | None) -> str: _THINK_BLOCK_RE = re.compile(r".*?", re.IGNORECASE | re.DOTALL) -_JSON_CODE_FENCE_RE = re.compile( - r"```(?:json)?\s*(?P\{.*?\})\s*```", - re.IGNORECASE | re.DOTALL, -) def strip_thinking_tags(text: str | None) -> str: @@ -146,38 +141,14 @@ def strip_thinking_tags(text: str | None) -> str: return _THINK_BLOCK_RE.sub("", text) -def extract_json_object(text: str | None) -> dict[str, Any] | None: - """Best-effort JSON object extraction from model output. - - Handles raw JSON, fenced JSON blocks, and outputs that still contain leaked - reasoning text such as `...{...}` ahead of the - machine-readable payload. - """ - - cleaned = strip_thinking_tags(text).strip() - if not cleaned: - return None - - candidates = [cleaned] - fenced_match = _JSON_CODE_FENCE_RE.search(cleaned) - if fenced_match is not None: - candidates.insert(0, fenced_match.group("payload")) - - decoder = json.JSONDecoder() - for candidate in candidates: - for idx, char in enumerate(candidate): - if char != "{": - continue - try: - parsed, _end = decoder.raw_decode(candidate[idx:]) - except json.JSONDecodeError: - continue - if isinstance(parsed, dict): - return parsed - return None - - -def do_inference(chat_model, inputs, use_tqdm: bool = False): +def do_inference( + chat_model, + inputs, + use_tqdm: bool = False, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, +): # Retries on rate-limit/server errors with exponential backoff. # Async path retries individual calls; batch path splits into 4^attempt chunks on failure. invoke_kwargs = { @@ -244,6 +215,24 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): # is it because of using Chat and barebones models? # when using OpenAI, the output is AIMessage not a string... res = [x.content if hasattr(x, "content") else x for x in res] + if ( + usage_tracker is not None + and usage_phase is not None + and usage_model_spec is not None + ): + try: + usage_tracker.record_batch_from_model( + phase=usage_phase, + model_spec=usage_model_spec, + chat_model=chat_model, + inputs=list(inputs), + outputs=res, + ) + except Exception as e: + print( + f"Warning: failed to record token usage for phase " + f"'{usage_phase}' ({usage_model_spec}): {e}" + ) return res @@ -286,7 +275,6 @@ def __init__( ): from vllm import LLM, SamplingParams from vllm.config.reasoning import ReasoningConfig - from vllm.sampling_params import StructuredOutputsParams self.model_path = model self.max_tokens = max_tokens @@ -340,14 +328,10 @@ def __init__( self._sampling_params_kwargs["thinking_token_budget"] = int( thinking_token_budget ) - structured_outputs_json = vllm_kwargs.pop("structured_outputs_json", None) - if structured_outputs_json is not None: - self._sampling_params_kwargs["structured_outputs"] = ( - StructuredOutputsParams(json=structured_outputs_json) - ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) + self.tokenizer = self.llm.get_tokenizer() # Resolve chat template: # 1. Explicit override always wins → use chat() with that template @@ -363,8 +347,7 @@ def __init__( else: print(f"ChatVLLM: using explicit chat template for '{model}'") else: - tokenizer = self.llm.get_tokenizer() - if not getattr(tokenizer, "chat_template", None): + if not getattr(self.tokenizer, "chat_template", None): warnings.warn( f"Model '{model}' tokenizer does not define a chat template. " f"Falling back to llm.generate() (no chat formatting). " @@ -466,6 +449,39 @@ def batch(self, inputs: list, **invoke_kwargs) -> list[str]: ) return [out.outputs[0].text for out in outputs] + def _count_chat_prompt_tokens(self, messages: list[dict]) -> int: + tokenizer_kwargs: dict[str, object] = { + "tokenize": True, + "add_generation_prompt": True, + } + if self.chat_template is not None: + tokenizer_kwargs["chat_template"] = self.chat_template + if self._chat_template_kwargs is not None: + tokenizer_kwargs["chat_template_kwargs"] = self._chat_template_kwargs + try: + token_ids = self.tokenizer.apply_chat_template(messages, **tokenizer_kwargs) + except TypeError: + tokenizer_kwargs.pop("chat_template_kwargs", None) + token_ids = self.tokenizer.apply_chat_template(messages, **tokenizer_kwargs) + return len(token_ids) + + def count_prompt_tokens_batch(self, inputs: list) -> list[int]: + counts: list[int] = [] + for input_item in inputs: + if self._use_generate: + counts.append(len(self.tokenizer.encode(self._to_raw_text(input_item)))) + else: + counts.append( + self._count_chat_prompt_tokens(self._to_messages(input_item)) + ) + return counts + + def count_completion_tokens_batch(self, outputs: list[str]) -> list[int]: + return [ + len(self.tokenizer.encode(output, add_special_tokens=False)) + for output in outputs + ] + def invoke(self, input_item, **invoke_kwargs) -> str: """Process a single input.""" results = self.batch([input_item], **invoke_kwargs) @@ -553,6 +569,18 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): return model_cls_dict[model_provider](**engine_kwargs) +def infer_model_spec_from_instance(model: object) -> str | None: + if isinstance(model, DummyModel): + return model.name + model_path = getattr(model, "model_path", None) + if isinstance(model_path, str): + return f"VLLM/{model_path}" + model_name = getattr(model, "model_name", None) or getattr(model, "model", None) + if isinstance(model_name, str): + return f"{model.__class__.__name__}/{model_name}" + return None + + def download_all(): print(f"Downloading all dataset in {data_root}") local_path_tables = data_root / "tables" diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index 5fc2d5e..7407087 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -11,11 +11,6 @@ class FakeSamplingParams: def __init__(self, **kwargs): captured["sampling_kwargs"] = kwargs - class FakeStructuredOutputsParams: - def __init__(self, **kwargs): - captured["structured_kwargs"] = kwargs - self.kwargs = kwargs - class FakeReasoningConfig: def __init__(self, **kwargs): captured["reasoning_config_kwargs"] = kwargs @@ -44,11 +39,6 @@ def chat(self, messages, sampling_params, **kwargs): "vllm", SimpleNamespace(LLM=FakeLLM, SamplingParams=FakeSamplingParams), ) - monkeypatch.setitem( - sys.modules, - "vllm.sampling_params", - SimpleNamespace(StructuredOutputsParams=FakeStructuredOutputsParams), - ) monkeypatch.setitem( sys.modules, "vllm.config.reasoning", @@ -63,18 +53,12 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc utils.ChatVLLM( model="Qwen/Qwen3.5-27B-FP8", max_tokens=32, - structured_outputs_json={ - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False, - }, thinking_token_budget=64, gpu_memory_utilization=0.7, ) assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 - assert captured["structured_kwargs"]["json"]["type"] == "object" + assert "structured_outputs" not in captured["sampling_kwargs"] assert captured["reasoning_config_kwargs"] == { "reasoning_start_str": utils.VLLM_QWEN_REASONING_START_STR, "reasoning_end_str": utils.VLLM_QWEN_REASONING_END_STR, diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index 3e37366..c60af40 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -6,32 +6,25 @@ from judgearena.generate_and_evaluate import main as main_generate_and_eval -def test_build_pair_score_json_schema_covers_valid_range_without_explanation(): - schema = evaluate.build_pair_score_json_schema() - - assert schema["type"] == "object" - assert set(schema["required"]) == {"score_A", "score_B"} - for key in ("score_A", "score_B"): - assert schema["properties"][key]["type"] == "integer" - assert schema["properties"][key]["minimum"] == 0 - assert schema["properties"][key]["maximum"] == 10 - assert schema["additionalProperties"] is False - - -def test_build_pair_score_json_schema_covers_valid_range_with_explanation(): - schema = evaluate.build_pair_score_json_schema(include_explanation=True) - - assert schema["type"] == "object" - assert set(schema["required"]) == {"explanation", "score_A", "score_B"} - assert schema["properties"]["explanation"] == { - "type": "string", - "maxLength": evaluate._PAIR_EXPLANATION_MAX_CHARS, - } - for key in ("score_A", "score_B"): - assert schema["properties"][key]["type"] == "integer" - assert schema["properties"][key]["minimum"] == 0 - assert schema["properties"][key]["maximum"] == 10 - assert schema["additionalProperties"] is False +def test_load_judge_prompt_without_explanation_uses_freeform_scores(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=False + ) + + assert "valid JSON" not in user_prompt + assert "score_A:" in user_prompt + assert "score_B:" in user_prompt + + +def test_load_judge_prompt_with_explanation_uses_freeform_scores(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=True + ) + + assert "valid JSON" not in user_prompt + assert "first starts with an explanation of your judgement" in user_prompt + assert "score_A:" in user_prompt + assert "score_B:" in user_prompt def test_main_aligns_local_reference_by_instruction_index(tmp_path, monkeypatch): @@ -76,6 +69,9 @@ def fake_judge_and_parse_prefs( user_prompt_template, truncate_input_chars, use_tqdm, + usage_tracker, + usage_phase, + usage_model_spec, ): captured["instructions"] = instructions captured["completions_A"] = completions_A @@ -109,9 +105,7 @@ def fake_judge_and_parse_prefs( assert prefs.tolist() == [1.0, 1.0] -def test_main_passes_json_schema_and_thinking_budget_to_vllm_judge( - tmp_path, monkeypatch -): +def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): instructions = pd.DataFrame( {"instruction": ["Instruction A"]}, index=pd.Index([1], name="instruction_index"), @@ -147,7 +141,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs generate_and_evaluate, "judge_and_parse_prefs", lambda **kwargs: ( - [{"judge_completion": '{"score_A":1,"score_B":2}'}], + [{"judge_completion": "score_A: 1\nscore_B: 2"}], None, pd.Series([1.0]), ), @@ -165,13 +159,11 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs ) assert prefs.tolist() == [1.0] - assert captured["make_model"]["kwargs"]["structured_outputs_json"] == ( - evaluate.build_pair_score_json_schema() - ) + assert "structured_outputs_json" not in captured["make_model"]["kwargs"] assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 -def test_main_passes_explanation_schema_to_vllm_judge_when_requested( +def test_main_passes_thinking_budget_to_vllm_judge_when_explanation_requested( tmp_path, monkeypatch ): instructions = pd.DataFrame( @@ -203,7 +195,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs generate_and_evaluate, "judge_and_parse_prefs", lambda **kwargs: ( - [{"judge_completion": '{"explanation":"ok","score_A":1,"score_B":2}'}], + [{"judge_completion": "Explanation: ok\nscore_A: 1\nscore_B: 2"}], None, pd.Series([1.0]), ), @@ -222,9 +214,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs ) assert prefs.tolist() == [1.0] - assert captured["make_model"]["structured_outputs_json"] == ( - evaluate.build_pair_score_json_schema(include_explanation=True) - ) + assert "structured_outputs_json" not in captured["make_model"] assert captured["make_model"]["thinking_token_budget"] == 512 @@ -233,7 +223,15 @@ def test_annotate_battles_warns_when_judge_completions_are_truncated( ): captured = {} - def fake_do_inference(*, chat_model, inputs, use_tqdm): + def fake_do_inference( + *, + chat_model, + inputs, + use_tqdm, + usage_tracker, + usage_phase, + usage_model_spec, + ): captured["judge_prompt"] = inputs[0].to_messages()[1].content return ["score_A: 0\nscore_B: 10"] @@ -255,5 +253,7 @@ def fake_do_inference(*, chat_model, inputs, use_tqdm): assert "Ans" in captured["judge_prompt"] assert "Answer A" not in captured["judge_prompt"] assert "Answer B" not in captured["judge_prompt"] + assert "valid JSON" not in captured["judge_prompt"] + assert "score_A:" in captured["judge_prompt"] assert annotations[0].completion_A == "Answer A" assert annotations[0].completion_B == "Answer B" diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index bc6f724..7925070 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -128,6 +128,8 @@ def fake_generate_multiturn( max_model_len, chat_template, temperature_config, + usage_tracker, + usage_phase, **engine_kwargs, ): generated_models.append(model) @@ -173,6 +175,7 @@ def fake_generate_multiturn( args=args, questions_df=questions_df, ignore_cache=False, + usage_tracker=object(), ) assert generated_models == ["VLLM/example/model-a"] @@ -184,38 +187,30 @@ def fake_generate_multiturn( assert completions_b.loc[2, "completion_turn_2"] == "Base B2" -def test_parse_fastchat_verdict_accepts_plain_structured_labels(): - assert fastchat_compat._parse_fastchat_verdict("A") == "A" - assert fastchat_compat._parse_fastchat_verdict("B") == "B" - assert fastchat_compat._parse_fastchat_verdict("C") == "tie" - - -def test_parse_fastchat_verdict_accepts_json_and_strips_thinking(): - assert ( - fastchat_compat._parse_fastchat_verdict( - 'Need a longer chain.{"explanation":"done","verdict":"B"}' - ) - == "B" - ) +def test_parse_fastchat_verdict_accepts_bracketed_verdicts_after_thinking(): assert ( fastchat_compat._parse_fastchat_verdict( - '```json\n{"explanation":"tie","verdict":"C"}\n```' - ) - == "tie" - ) - assert ( - fastchat_compat._parse_fastchat_verdict( - 'unfinished analysis {"explanation":"cut short","verdict":"A"' + "Need a longer chain.[[A]]" ) == "A" ) + assert fastchat_compat._parse_fastchat_verdict("[[B]]") == "B" + assert fastchat_compat._parse_fastchat_verdict("[[C]]") == "tie" + +def test_parse_fastchat_verdict_marks_non_bracketed_outputs_as_error(): + assert fastchat_compat._parse_fastchat_verdict("A") == "error" + assert fastchat_compat._parse_fastchat_verdict('{"verdict":"B"}') == "error" -def test_pair_v2_system_prompt_omits_explanation_when_disabled(): - rendered = fastchat_compat._PAIR_V2.render_system_prompt(provide_explanation=False) - assert "provide a short explanation" not in rendered - assert 'exactly one key: "verdict"' in rendered +def test_pair_v2_system_prompt_matches_original_fastchat_contract(): + rendered = fastchat_compat._PAIR_V2.render_system_prompt(provide_explanation=True) + + assert "provide a short explanation" in rendered + assert "valid JSON" not in rendered + assert '"[[A]]"' in rendered + assert '"[[B]]"' in rendered + assert '"[[C]]"' in rendered def test_conservative_winner_marks_one_sided_parse_failures_as_error(): @@ -231,36 +226,6 @@ def test_conservative_winner_marks_one_sided_parse_failures_as_error(): assert fastchat_compat._conservative_winner("model_A", "model_B") == ("tie", True) -def test_build_mt_bench_verdict_json_schema_without_explanation(): - schema = mt_bench_utils.build_mt_bench_verdict_json_schema() - - assert schema["type"] == "object" - assert set(schema["required"]) == {"verdict"} - assert schema["properties"] == { - "verdict": { - "type": "string", - "enum": ["A", "B", "C"], - } - } - assert schema["additionalProperties"] is False - - -def test_build_mt_bench_verdict_json_schema_with_explanation(): - schema = mt_bench_utils.build_mt_bench_verdict_json_schema(include_explanation=True) - - assert schema["type"] == "object" - assert set(schema["required"]) == {"explanation", "verdict"} - assert schema["properties"]["explanation"] == { - "type": "string", - "maxLength": mt_bench_utils._MT_BENCH_EXPLANATION_MAX_CHARS, - } - assert schema["properties"]["verdict"] == { - "type": "string", - "enum": ["A", "B", "C"], - } - assert schema["additionalProperties"] is False - - def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): questions_df = pd.DataFrame( {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, @@ -276,7 +241,7 @@ def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): monkeypatch.setattr( mt_bench_utils, "_generate_mt_bench_completions", - lambda args, questions_df, ignore_cache: ( + lambda args, questions_df, ignore_cache, usage_tracker: ( pd.DataFrame( { "completion_turn_1": ["A1"], @@ -342,13 +307,14 @@ def fake_run_mt_bench_fastchat(**kwargs): mt_bench_utils.run_mt_bench(args, ignore_cache=False) assert args.swap_mode == "both" - assert args.max_out_tokens_judge == 2048 - assert captured["make_model"]["max_tokens"] == 2048 + assert args.max_out_tokens_judge == 24576 + assert args.max_model_len == 28672 + assert captured["make_model"]["max_tokens"] == 24576 + assert captured["make_model"]["max_model_len"] == 28672 assert captured["make_model"]["kwargs"] == { "gpu_memory_utilization": 0.7, "language_model_only": True, - "structured_outputs_json": mt_bench_utils.build_mt_bench_verdict_json_schema(), "thinking_token_budget": 512, } assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" - assert "constrained_plain_verdict" not in captured["run_mt_bench_fastchat"] + assert captured["run_mt_bench_fastchat"]["provide_explanation"] is True diff --git a/tests/test_openrouter_reference_pricing.py b/tests/test_openrouter_reference_pricing.py new file mode 100644 index 0000000..1dbe0c2 --- /dev/null +++ b/tests/test_openrouter_reference_pricing.py @@ -0,0 +1,209 @@ +import json + +import judgearena.openrouter_reference_pricing as pricing +from judgearena.repro import write_run_metadata +from judgearena.utils import do_inference + + +class CountingModel: + def batch(self, inputs, **invoke_kwargs): + return [f"output-{idx}" for idx, _input in enumerate(inputs)] + + def count_prompt_tokens_batch(self, inputs): + return [len(str(input_item)) for input_item in inputs] + + def count_completion_tokens_batch(self, outputs): + return [len(output) for output in outputs] + + +def test_do_inference_records_token_usage(): + tracker = pricing.OpenRouterReferencePricingTracker() + model = CountingModel() + + outputs = do_inference( + chat_model=model, + inputs=["abc", "de"], + usage_tracker=tracker, + usage_phase="judge", + usage_model_spec="VLLM/org/model", + ) + + assert outputs == ["output-0", "output-1"] + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="VLLM/org/model", + prompt_tokens=3, + completion_tokens=8, + requests=1, + ), + pricing.TokenUsageRecord( + phase="judge", + model_spec="VLLM/org/model", + prompt_tokens=2, + completion_tokens=8, + requests=1, + ), + ] + + +def test_build_reference_pricing_summary_uses_exact_match_and_reports_partial_cost( + monkeypatch, +): + catalog = pricing.parse_openrouter_catalog_payload( + { + "data": [ + { + "id": "openrouter/example-model", + "canonical_slug": "openrouter/example-model", + "hugging_face_id": "Org/Example-Model", + "name": "Example Model", + "pricing": { + "prompt": "0.001", + "completion": "0.002", + "request": "0.01", + "internal_reasoning": "0.5", + }, + } + ], + "fetched_at_utc": "2026-04-07T00:00:00+00:00", + } + ) + monkeypatch.setattr( + pricing, + "load_openrouter_price_catalog_with_fallback", + lambda **kwargs: (catalog, None), + ) + + tracker = pricing.OpenRouterReferencePricingTracker() + tracker._records.extend( + [ + pricing.TokenUsageRecord( + phase="generation_model_A", + model_spec="VLLM/Org/Example-Model", + prompt_tokens=100, + completion_tokens=20, + ), + pricing.TokenUsageRecord( + phase="generation_model_A", + model_spec="VLLM/Org/Example-Model", + prompt_tokens=50, + completion_tokens=5, + ), + pricing.TokenUsageRecord( + phase="generation_model_B", + model_spec="VLLM/No/Match", + prompt_tokens=10, + completion_tokens=2, + ), + ] + ) + + summary = pricing.build_openrouter_reference_pricing_summary( + tracker=tracker, + phase_model_specs={ + "generation_model_A": "VLLM/Org/Example-Model", + "generation_model_B": "VLLM/No/Match", + "judge": "VLLM/No/Runtime", + }, + ) + + matched = summary["phases"]["generation_model_A"] + assert matched["openrouter_model_id"] == "openrouter/example-model" + assert matched["pricing_status"] == "matched_exact_openrouter_model_partial" + assert matched["prompt_tokens"] == 150 + assert matched["completion_tokens"] == 25 + assert matched["request_count"] == 2 + assert matched["openrouter_reference_cost_usd"] == 0.22 + assert matched["ignored_pricing_components"] == ["internal_reasoning"] + + unmatched = summary["phases"]["generation_model_B"] + assert unmatched["pricing_status"] == "no_exact_openrouter_match" + assert unmatched["openrouter_reference_cost_usd"] is None + + no_runtime = summary["phases"]["judge"] + assert no_runtime["pricing_status"] == "no_runtime_token_data" + assert no_runtime["total_tokens"] == 0 + + assert summary["total"]["openrouter_reference_cost_usd"] == 0.22 + + +def test_build_reference_pricing_summary_matches_quantized_local_variant( + monkeypatch, +): + catalog = pricing.parse_openrouter_catalog_payload( + { + "data": [ + { + "id": "qwen/qwen3.5-27b", + "canonical_slug": "qwen/qwen3.5-27b", + "hugging_face_id": "Qwen/Qwen3.5-27B", + "name": "Qwen: Qwen3.5-27B", + "pricing": { + "prompt": "0.001", + "completion": "0.002", + "request": "0.01", + }, + } + ], + "fetched_at_utc": "2026-04-15T00:00:00+00:00", + } + ) + monkeypatch.setattr( + pricing, + "load_openrouter_price_catalog_with_fallback", + lambda **kwargs: (catalog, None), + ) + + tracker = pricing.OpenRouterReferencePricingTracker() + tracker._records.append( + pricing.TokenUsageRecord( + phase="generation_model_A", + model_spec="VLLM/Qwen/Qwen3.5-27B-FP8", + prompt_tokens=10, + completion_tokens=5, + ) + ) + + summary = pricing.build_openrouter_reference_pricing_summary( + tracker=tracker, + phase_model_specs={ + "generation_model_A": "VLLM/Qwen/Qwen3.5-27B-FP8", + "judge": "VLLM/No/Runtime", + }, + ) + + matched = summary["phases"]["generation_model_A"] + assert ( + matched["pricing_status"] + == "matched_openrouter_model_after_variant_normalization" + ) + assert matched["openrouter_model_id"] == "qwen/qwen3.5-27b" + assert matched["openrouter_reference_cost_usd"] == 0.03 + assert summary["exact_match_policy"]["fallback_normalizations"] == [ + "strip_common_local_quantization_suffixes" + ] + + +def test_write_run_metadata_includes_pricing_reference(tmp_path, monkeypatch): + monkeypatch.setattr( + "judgearena.repro._get_dependency_versions", + lambda *args, **kwargs: {}, + ) + monkeypatch.setattr("judgearena.repro._get_git_hash", lambda *args, **kwargs: None) + + metadata_path = write_run_metadata( + output_dir=tmp_path, + entrypoint="judgearena.test", + run={"dataset": "alpaca-eval"}, + pricing_reference={ + "pricing_model": "openrouter_reference", + "total": {"openrouter_reference_cost_usd": 1.23}, + }, + ) + + metadata = json.loads(metadata_path.read_text()) + assert metadata["pricing_reference"]["pricing_model"] == "openrouter_reference" + assert ( + metadata["pricing_reference"]["total"]["openrouter_reference_cost_usd"] == 1.23 + ) diff --git a/tests/test_regexp.py b/tests/test_regexp.py index 0f7a868..23af4d5 100644 --- a/tests/test_regexp.py +++ b/tests/test_regexp.py @@ -40,14 +40,16 @@ def test_regexp(): print(pref) -def test_pair_score_prefers_json_scores_over_reasoning_text(): +def test_pair_score_ignores_scores_inside_thinking_tags(): raw_text = """ - I would score assistant A as 2/10 if I stopped early. - { - "explanation": "At first glance I might score assistant A as 2, but after comparing both answers carefully, assistant B is better.", - "score_A": 0, - "score_B": 10 - } + + Early draft: + score_A: 2 + score_B: 1 + + Explanation: Assistant B is clearly better overall. + score_A: 0 + score_B: 10 """ scorer = PairScore() @@ -55,3 +57,26 @@ def test_pair_score_prefers_json_scores_over_reasoning_text(): assert pref is not None assert pref == 0.9525741268224333 + + +def test_pair_score_falls_back_to_bracketed_verdicts(): + scorer = PairScore() + + assert scorer.parse_model_raw("Explanation: ok\n[[A]]") == 0.0 + assert scorer.parse_model_raw("Explanation: ok\n[[B]]") == 1.0 + assert scorer.parse_model_raw("Explanation: ok\n[[C]]") == 0.5 + + +def test_pair_score_ignores_thinking_tags_before_bracketed_verdict(): + raw_text = """ + + score_A: 0 + score_B: 10 + + Concise verdict only. + [[B]] + """ + + scorer = PairScore() + + assert scorer.parse_model_raw(raw_text) == 1.0 From 84faa057ca4754991ad7fdc8f97118725649e794 Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Fri, 17 Apr 2026 14:06:38 +0200 Subject: [PATCH 12/28] revert unnecessary changes and relics from earlier trials --- README.md | 27 - judgearena/generate_and_evaluate.py | 40 +- judgearena/mt_bench/fastchat_compat.py | 65 +-- judgearena/mt_bench/mt_bench_utils.py | 18 +- judgearena/openrouter_reference_pricing.py | 11 + .../prompts/prompt-with-explanation.txt | 8 +- judgearena/slurm_costs.py | 481 ------------------ judgearena/utils.py | 220 ++++++-- .../launch_kislurm_qwen35_smoke.py | 65 +++ tests/test_chat_vllm.py | 44 ++ tests/test_local_completion_loading.py | 115 ++--- tests/test_mt_bench_downloads.py | 4 +- tests/test_slurm_costs.py | 77 --- 13 files changed, 392 insertions(+), 783 deletions(-) delete mode 100644 judgearena/slurm_costs.py create mode 100644 slurmpilot_scripts/launch_kislurm_qwen35_smoke.py delete mode 100644 tests/test_slurm_costs.py diff --git a/README.md b/README.md index c7840ce..508ac9f 100644 --- a/README.md +++ b/README.md @@ -82,25 +82,6 @@ The evaluation scripts expose four different length controls with different role - `--max_out_tokens_judge`: generation token budget for the judge completion (reasoning + score output). - `--max_model_len`: optional vLLM context-window limit (prompt + generated tokens), applied to vLLM models; this should be greater than or equal to the two `max_out_tokens_*` values. -### OpenRouter Reference Pricing For Local Runs - -JudgeArena can estimate an `openrouter_reference_cost_usd` for local runs by combining: - -- locally counted prompt and completion tokens -- OpenRouter's public model pricing from `GET /api/v1/models` - -This is a reference price, not actual billed spend from either OpenRouter or your cluster. - -Reference pricing is only applied when the local model has an exact OpenRouter match using one of: - -- the OpenRouter model `id` -- the OpenRouter `canonical_slug` -- the model `hugging_face_id` - -If no exact match exists, JudgeArena still records token totals but leaves the reference price unset. - -The aggregated pricing summary is printed to stdout and stored in `run-metadata.v1.json` under `pricing_reference`. - ### Engine-Specific Configuration (`--engine_kwargs`) Some providers expose additional engine-level knobs (for example, vLLM allows configuring tensor parallelism or GPU memory utilization). @@ -292,14 +273,6 @@ Datasets are stored in: - `$JUDGEARENA_DATA` if set; otherwise `$OPENJURY_DATA` if set (legacy) - `~/judgearena-data/` if neither variable is set -If compute nodes do not have internet access, refresh the cached OpenRouter price book on the login node before launching jobs: - -```bash -uv run python -m judgearena.openrouter_reference_pricing --refresh -``` - -The benchmark launcher in `slurmpilot_scripts/launch_benchmark_eval.py` also attempts to warm this cache automatically. The cache is stored under `$JUDGEARENA_DATA/reference_pricing/openrouter_models.json` unless `JUDGEARENA_OPENROUTER_PRICE_CACHE` overrides it. - ## 🛠️ Development To maintain code quality, we use **pre-commit** hooks. Run this once to set them up: diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index df58fa3..16aca13 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -35,46 +35,23 @@ download_hf, make_model, read_df, + should_default_thinking_token_budget, ) def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None ) -> pd.DataFrame | None: - """Try loading pre-existing completions from the dataset or a local file. + """Try loading pre-existing completions from the dataset. Some datasets (e.g. alpaca-eval) ship with completions for well-known models such as ``gpt4_1106_preview``. When ``model`` matches a column in ``model_outputs/{dataset}.csv.zip``, those completions are returned directly so that no model instantiation / generation is needed. - ``model`` may also be a local dataframe path. Local files must contain - ``instruction_index`` and ``output`` columns. - Returns a DataFrame with columns ``completion`` and ``instruction_index``, or ``None`` when no pre-existing completions are found. """ - local_path = Path(model) - if local_path.exists(): - print(f"Loading completions from local path '{local_path}'.") - df_outputs = read_df(local_path) - required_columns = {"instruction_index", "output"} - missing_columns = required_columns.difference(df_outputs.columns) - if missing_columns: - missing_columns_list = ", ".join(sorted(missing_columns)) - raise ValueError( - f"Local completion file '{local_path}' is missing required columns: " - f"{missing_columns_list}." - ) - - df_outputs = df_outputs.loc[:, ["instruction_index", "output"]].rename( - columns={"output": "completion"} - ) - df_outputs.loc[:, "completion"] = df_outputs.loc[:, "completion"].fillna("") - if n_instructions is not None: - df_outputs = df_outputs.head(n_instructions) - return df_outputs - local_path_tables = data_root / "tables" if is_arena_hard_dataset(dataset): download_arena_hard(dataset=dataset, local_tables_path=local_path_tables) @@ -304,9 +281,16 @@ def main(args: CliArgs): judge_model_kwargs = dict(args.engine_kwargs) if args.judge_model.split("/")[0] == "VLLM": - judge_model_kwargs.setdefault( - "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET - ) + judge_model_name = "/".join(args.judge_model.split("/")[1:]) + if ( + "thinking_token_budget" not in judge_model_kwargs + and should_default_thinking_token_budget( + judge_model_name, judge_model_kwargs + ) + ): + judge_model_kwargs["thinking_token_budget"] = ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) judge_chat_model = make_model( model=args.judge_model, diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 18fcd5d..5bf23f5 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -3,7 +3,6 @@ from __future__ import annotations import math -import re from dataclasses import dataclass from pathlib import Path from typing import Any, Literal @@ -41,22 +40,10 @@ @dataclass(frozen=True) class FastChatPairwisePrompt: name: str - user_subject: str - task_description: str - begin_instruction: str + system_prompt: str user_prompt_template: str multi_turn: bool ref_based: bool - focus_line: str = "" - - def render_system_prompt(self, *, provide_explanation: bool) -> str: - return _build_system_prompt( - user_subject=self.user_subject, - task_description=self.task_description, - begin_instruction=self.begin_instruction, - focus_line=self.focus_line, - provide_explanation=provide_explanation, - ) _PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" @@ -76,41 +63,25 @@ def _render_prompt_text(filename: str, **kwargs: str) -> str: return _load_prompt_text(filename).format(**kwargs) -def _begin_instruction_for_mode( - begin_instruction: str, *, provide_explanation: bool -) -> str: - if provide_explanation: - return begin_instruction - return re.sub(r"\s+and provide a short explanation$", "", begin_instruction) - - def _build_system_prompt( *, user_subject: str, task_description: str, begin_instruction: str, focus_line: str = "", - provide_explanation: bool, ) -> str: focus_segment = f"{focus_line} " if focus_line else "" - output_format_instruction = ( - "After providing your explanation, output your final verdict by strictly " - 'following this format: "[[A]]" if assistant A is better, "[[B]]" if ' - 'assistant B is better, and "[[C]]" for a tie.' - if provide_explanation - else "Output your final verdict by strictly following this format: " - '"[[A]]" if assistant A is better, "[[B]]" if assistant B is better, ' - 'and "[[C]]" for a tie.' - ) return _render_prompt_text( _SYSTEM_BASE_FILE, user_subject=user_subject, task_description=task_description, focus_line=focus_segment, - begin_instruction=_begin_instruction_for_mode( - begin_instruction, provide_explanation=provide_explanation + begin_instruction=begin_instruction, + output_format_instruction=( + "After providing your explanation, output your final verdict by strictly following this format: " + '"[[A]]" if assistant A is better, "[[B]]" if assistant B is better, ' + 'and "[[C]]" for a tie.' ), - output_format_instruction=output_format_instruction, ) @@ -137,16 +108,18 @@ def _load_pairwise_prompt( ) -> FastChatPairwisePrompt: return FastChatPairwisePrompt( name=name, - user_subject=system_user_subject, - task_description=system_task_description, - begin_instruction=system_begin_instruction, multi_turn=multi_turn, ref_based=ref_based, + system_prompt=_build_system_prompt( + user_subject=system_user_subject, + task_description=system_task_description, + begin_instruction=system_begin_instruction, + focus_line=system_focus_line, + ), user_prompt_template=_build_user_prompt_template( multi_turn=multi_turn, ref_based=ref_based, ), - focus_line=system_focus_line, ) @@ -301,7 +274,6 @@ def _infer_by_prompt_groups( items: list[dict[str, Any]], use_tqdm: bool, swap_answers: bool, - provide_explanation: bool, usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, usage_model_spec: str | None = None, @@ -312,11 +284,8 @@ def _infer_by_prompt_groups( judgments: list[str] = [""] * len(items) for _prompt_name, idxs in grouped_indices.items(): prompt: FastChatPairwisePrompt = items[idxs[0]]["prompt"] - system_prompt = prompt.render_system_prompt( - provide_explanation=provide_explanation - ) prompt_template = ChatPromptTemplate.from_messages( - [("system", system_prompt), ("user", prompt.user_prompt_template)] + [("system", prompt.system_prompt), ("user", prompt.user_prompt_template)] ) batch_kwargs = [] @@ -411,14 +380,12 @@ def _resolve_fastchat_item_result( judge_model: str, model_a: str, model_b: str, - provide_explanation: bool, ) -> tuple[dict[str, Any], dict[str, object], float, bool]: prompt: FastChatPairwisePrompt = item["prompt"] kwargs = item["prompt_kwargs"] g1_user_prompt = prompt.user_prompt_template.format(**kwargs) g1_verdict = _parse_fastchat_verdict(g1_raw) g1_winner = _map_verdict_to_winner(g1_verdict, swapped=False) - system_prompt = prompt.render_system_prompt(provide_explanation=provide_explanation) final_winner = g1_winner inconsistent = False @@ -430,7 +397,7 @@ def _resolve_fastchat_item_result( "model_B": model_b, "judge": judge_model, "prompt_name": prompt.name, - "system_prompt": system_prompt, + "system_prompt": prompt.system_prompt, "g1_user_prompt": g1_user_prompt, "g1_judgment": g1_raw, "g1_verdict": g1_verdict, @@ -480,7 +447,6 @@ def judge_mt_bench_pairwise_fastchat( swap_mode: str, truncate_input_chars: int | None, use_tqdm: bool, - provide_explanation: bool = False, usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: @@ -505,7 +471,6 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=False, - provide_explanation=provide_explanation, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=judge_model, @@ -518,7 +483,6 @@ def judge_mt_bench_pairwise_fastchat( items=items, use_tqdm=use_tqdm, swap_answers=True, - provide_explanation=provide_explanation, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=judge_model, @@ -539,7 +503,6 @@ def judge_mt_bench_pairwise_fastchat( judge_model=judge_model, model_a=model_a, model_b=model_b, - provide_explanation=provide_explanation, ) ) if inconsistent: diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 2a9d290..e86310c 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -34,6 +34,7 @@ cache_function_dataframe, compute_pref_summary, make_model, + should_default_thinking_token_budget, ) if TYPE_CHECKING: @@ -166,7 +167,6 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, - provide_explanation: bool, usage_tracker: OpenRouterReferencePricingTracker, started_at_utc: datetime, ) -> pd.Series: @@ -184,7 +184,6 @@ def _run_mt_bench_fastchat( swap_mode=args.swap_mode, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, - provide_explanation=provide_explanation, usage_tracker=usage_tracker, usage_phase="judge", ) @@ -271,9 +270,17 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): ) args.max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN judge_model_kwargs = dict(args.engine_kwargs) - judge_model_kwargs.setdefault( - "thinking_token_budget", DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET - ) + if args.judge_model.split("/")[0] == "VLLM": + judge_model_name = "/".join(args.judge_model.split("/")[1:]) + if ( + "thinking_token_budget" not in judge_model_kwargs + and should_default_thinking_token_budget( + judge_model_name, judge_model_kwargs + ) + ): + judge_model_kwargs["thinking_token_budget"] = ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) judge_chat_model = make_model( model=args.judge_model, @@ -289,7 +296,6 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, - provide_explanation=True, usage_tracker=usage_tracker, started_at_utc=run_started_at, ) diff --git a/judgearena/openrouter_reference_pricing.py b/judgearena/openrouter_reference_pricing.py index 0463f14..cf13820 100644 --- a/judgearena/openrouter_reference_pricing.py +++ b/judgearena/openrouter_reference_pricing.py @@ -1,3 +1,14 @@ +"""Reference pricing utilities for local JudgeArena runs. + +This module counts local prompt/completion tokens and, when an exact +OpenRouter model match exists, attaches a comparable public-price estimate. +Refresh the cached catalog on a machine with internet access via +`uv run python -m judgearena.openrouter_reference_pricing --refresh`. +By default the cache lives under +`$JUDGEARENA_DATA/reference_pricing/openrouter_models.json`, unless +`JUDGEARENA_OPENROUTER_PRICE_CACHE` overrides it. +""" + from __future__ import annotations import argparse diff --git a/judgearena/prompts/prompt-with-explanation.txt b/judgearena/prompts/prompt-with-explanation.txt index 3d9eb41..6600f51 100644 --- a/judgearena/prompts/prompt-with-explanation.txt +++ b/judgearena/prompts/prompt-with-explanation.txt @@ -1,13 +1,13 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's {completion_label}|> +<|The Start of Assistant A's Answer|> {completion_A} -<|The End of Assistant A's {completion_label}|> +<|The End of Assistant A's Answer|> -<|The Start of Assistant B's {completion_label}|> +<|The Start of Assistant B's Answer|> {completion_B} -<|The End of Assistant B's {completion_label}|> +<|The End of Assistant B's Answer|> # Your output diff --git a/judgearena/slurm_costs.py b/judgearena/slurm_costs.py deleted file mode 100644 index 2f78726..0000000 --- a/judgearena/slurm_costs.py +++ /dev/null @@ -1,481 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import subprocess -from collections.abc import Iterable -from dataclasses import asdict, dataclass -from pathlib import Path - -SACCT_FIELDS = ( - "JobID,JobName%100,Partition,Account,State,ElapsedRaw,Elapsed," - "AllocCPUS,AllocNodes,AllocTRES%100,ReqTRES%100" -) -RATE_METRIC_CHOICES = ( - "wall_hours", - "cpu_hours", - "gpu_hours", - "billing_hours", - "node_hours", -) - - -@dataclass(frozen=True) -class JobSource: - job_id: int - label: str - - -@dataclass(frozen=True) -class SacctAllocation: - allocation_id: str - root_job_id: int - job_name: str - partition: str - account: str - state: str - elapsed_seconds: int - elapsed: str - alloc_cpus: float - alloc_nodes: float - alloc_tres: dict[str, str] - req_tres: dict[str, str] - - -@dataclass(frozen=True) -class JobCostSummary: - job_id: int - label: str - partition: str - account: str - states: list[str] - allocation_count: int - wall_hours: float - cpu_hours: float - gpu_hours: float - billing_hours: float - node_hours: float - estimated_cost: float | None = None - - -def parse_tres_map(tres_spec: str) -> dict[str, str]: - values: dict[str, str] = {} - if not tres_spec: - return values - for raw_entry in tres_spec.split(","): - entry = raw_entry.strip() - if not entry or "=" not in entry: - continue - key, value = entry.split("=", 1) - values[key.strip()] = value.strip() - return values - - -def parse_elapsed_seconds(elapsed: str) -> int: - if not elapsed: - return 0 - n_days = 0 - time_part = elapsed - if "-" in elapsed: - days_part, time_part = elapsed.split("-", 1) - n_days = int(days_part) - hours_str, minutes_str, seconds_str = time_part.split(":") - return ( - n_days * 86400 - + int(hours_str) * 3600 - + int(minutes_str) * 60 - + int(seconds_str) - ) - - -def parse_sacct_allocations(sacct_output: str) -> list[SacctAllocation]: - allocations: list[SacctAllocation] = [] - for raw_line in sacct_output.splitlines(): - line = raw_line.strip() - if not line: - continue - parts = line.split("|") - if len(parts) != 11: - raise ValueError(f"Unexpected sacct row with {len(parts)} fields: {line}") - allocation_id = parts[0].strip() - root_job_text = allocation_id.split("_", 1)[0] - allocations.append( - SacctAllocation( - allocation_id=allocation_id, - root_job_id=int(root_job_text), - job_name=parts[1].strip(), - partition=parts[2].strip(), - account=parts[3].strip(), - state=parts[4].strip(), - elapsed_seconds=int(parts[5] or "0"), - elapsed=parts[6].strip(), - alloc_cpus=float(parts[7] or "0"), - alloc_nodes=float(parts[8] or "0"), - alloc_tres=parse_tres_map(parts[9]), - req_tres=parse_tres_map(parts[10]), - ) - ) - return allocations - - -def query_sacct_allocations(job_ids: Iterable[int]) -> list[SacctAllocation]: - unique_job_ids = [ - str(job_id) for job_id in dict.fromkeys(int(job_id) for job_id in job_ids) - ] - if not unique_job_ids: - return [] - try: - result = subprocess.run( - [ - "sacct", - "-X", - "--allocations", - "--parsable2", - "--noheader", - f"--format={SACCT_FIELDS}", - f"--jobs={','.join(unique_job_ids)}", - ], - check=False, - capture_output=True, - text=True, - ) - except FileNotFoundError as exc: - raise RuntimeError( - "Could not find `sacct`; run this on a machine with Slurm." - ) from exc - if result.returncode != 0: - message = ( - result.stderr.strip() or result.stdout.strip() or "unknown sacct error" - ) - raise RuntimeError(f"sacct failed: {message}") - return parse_sacct_allocations(result.stdout) - - -def load_job_source_from_path(job_path: str | Path) -> JobSource: - path = Path(job_path) - job_dir = path.parent if path.name == "jobid.json" else path - jobid_path = job_dir / "jobid.json" - if not jobid_path.is_file(): - raise FileNotFoundError(f"Missing jobid.json in {job_dir}") - job_id = int(json.loads(jobid_path.read_text())["jobid"]) - metadata_path = job_dir / "metadata.json" - if metadata_path.is_file(): - metadata = json.loads(metadata_path.read_text()) - label = str(metadata.get("jobname") or job_dir.name) - else: - label = job_dir.name - return JobSource(job_id=job_id, label=label) - - -def resolve_job_sources( - *, - job_ids: Iterable[int] | None = None, - job_paths: Iterable[str | Path] | None = None, -) -> list[JobSource]: - sources: dict[int, JobSource] = {} - ordered_ids: list[int] = [] - - for job_id in job_ids or []: - normalized_job_id = int(job_id) - if normalized_job_id in sources: - continue - sources[normalized_job_id] = JobSource( - job_id=normalized_job_id, - label=str(normalized_job_id), - ) - ordered_ids.append(normalized_job_id) - - for job_path in job_paths or []: - source = load_job_source_from_path(job_path) - if source.job_id not in sources: - ordered_ids.append(source.job_id) - sources[source.job_id] = source - - return [sources[job_id] for job_id in ordered_ids] - - -def _tres_quantity(tres_map: dict[str, str], key: str) -> float: - raw_value = tres_map.get(key) - if raw_value is None: - return 0.0 - numeric_chars: list[str] = [] - for char in raw_value: - if char.isdigit() or char in {".", "-"}: - numeric_chars.append(char) - continue - break - numeric_text = "".join(numeric_chars) - return float(numeric_text) if numeric_text else 0.0 - - -def summarize_job_costs( - sources: list[JobSource], - allocations: list[SacctAllocation], - *, - rate_metric: str | None = None, - hourly_rate: float | None = None, -) -> list[JobCostSummary]: - allocations_by_job_id: dict[int, list[SacctAllocation]] = { - source.job_id: [] for source in sources - } - for allocation in allocations: - if allocation.root_job_id in allocations_by_job_id: - allocations_by_job_id[allocation.root_job_id].append(allocation) - - missing_job_ids = [ - str(source.job_id) - for source in sources - if not allocations_by_job_id[source.job_id] - ] - if missing_job_ids: - raise RuntimeError( - "No sacct allocation rows returned for job IDs: " - + ", ".join(missing_job_ids) - ) - - summaries: list[JobCostSummary] = [] - for source in sources: - job_allocations = allocations_by_job_id[source.job_id] - wall_hours = sum(row.elapsed_seconds for row in job_allocations) / 3600.0 - cpu_hours = ( - sum(row.elapsed_seconds * row.alloc_cpus for row in job_allocations) - / 3600.0 - ) - gpu_hours = ( - sum( - row.elapsed_seconds * _tres_quantity(row.alloc_tres, "gres/gpu") - for row in job_allocations - ) - / 3600.0 - ) - billing_hours = ( - sum( - row.elapsed_seconds * _tres_quantity(row.alloc_tres, "billing") - for row in job_allocations - ) - / 3600.0 - ) - node_hours = ( - sum(row.elapsed_seconds * row.alloc_nodes for row in job_allocations) - / 3600.0 - ) - metric_value = ( - _summary_metric_value( - wall_hours=wall_hours, - cpu_hours=cpu_hours, - gpu_hours=gpu_hours, - billing_hours=billing_hours, - node_hours=node_hours, - rate_metric=rate_metric, - ) - if hourly_rate is not None - else None - ) - summaries.append( - JobCostSummary( - job_id=source.job_id, - label=source.label, - partition=",".join( - sorted({row.partition for row in job_allocations if row.partition}) - ), - account=",".join( - sorted({row.account for row in job_allocations if row.account}) - ), - states=sorted({row.state for row in job_allocations if row.state}), - allocation_count=len(job_allocations), - wall_hours=wall_hours, - cpu_hours=cpu_hours, - gpu_hours=gpu_hours, - billing_hours=billing_hours, - node_hours=node_hours, - estimated_cost=( - metric_value * hourly_rate - if metric_value is not None and hourly_rate is not None - else None - ), - ) - ) - return summaries - - -def _summary_metric_value( - *, - wall_hours: float, - cpu_hours: float, - gpu_hours: float, - billing_hours: float, - node_hours: float, - rate_metric: str | None, -) -> float: - metrics = { - "wall_hours": wall_hours, - "cpu_hours": cpu_hours, - "gpu_hours": gpu_hours, - "billing_hours": billing_hours, - "node_hours": node_hours, - } - if rate_metric is None: - raise ValueError("rate_metric must be set when hourly_rate is provided") - return metrics[rate_metric] - - -def total_summary( - summaries: list[JobCostSummary], *, hourly_rate: float | None = None -) -> JobCostSummary: - return JobCostSummary( - job_id=0, - label="TOTAL", - partition=",".join( - sorted({summary.partition for summary in summaries if summary.partition}) - ), - account=",".join( - sorted({summary.account for summary in summaries if summary.account}) - ), - states=sorted({state for summary in summaries for state in summary.states}), - allocation_count=sum(summary.allocation_count for summary in summaries), - wall_hours=sum(summary.wall_hours for summary in summaries), - cpu_hours=sum(summary.cpu_hours for summary in summaries), - gpu_hours=sum(summary.gpu_hours for summary in summaries), - billing_hours=sum(summary.billing_hours for summary in summaries), - node_hours=sum(summary.node_hours for summary in summaries), - estimated_cost=( - sum(summary.estimated_cost or 0.0 for summary in summaries) - if hourly_rate is not None - else None - ), - ) - - -def _format_float(value: float) -> str: - return f"{value:.3f}" - - -def _format_cost(value: float, currency: str) -> str: - return f"{currency} {value:.2f}" - - -def _tabular_rows( - summaries: list[JobCostSummary], *, currency: str, include_cost: bool -) -> list[dict[str, str]]: - rows: list[dict[str, str]] = [] - for summary in summaries: - row = { - "job": summary.label, - "job_id": str(summary.job_id), - "tasks": str(summary.allocation_count), - "state": ",".join(summary.states), - "gpu_h": _format_float(summary.gpu_hours), - "billing_h": _format_float(summary.billing_hours), - "cpu_h": _format_float(summary.cpu_hours), - "wall_h": _format_float(summary.wall_hours), - } - if include_cost and summary.estimated_cost is not None: - row["cost"] = _format_cost(summary.estimated_cost, currency) - rows.append(row) - return rows - - -def render_table(rows: list[dict[str, str]]) -> str: - if not rows: - return "" - headers = list(rows[0].keys()) - widths = {header: len(header) for header in headers} - for row in rows: - for header in headers: - widths[header] = max(widths[header], len(row[header])) - header_line = " ".join(f"{header:<{widths[header]}}" for header in headers) - separator_line = " ".join("-" * widths[header] for header in headers) - row_lines = [ - " ".join(f"{row[header]:<{widths[header]}}" for header in headers) - for row in rows - ] - return "\n".join([header_line, separator_line, *row_lines]) - - -def _summary_to_dict(summary: JobCostSummary) -> dict[str, object]: - return asdict(summary) - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - prog="python -m judgearena.slurm_costs", - description="Summarize Slurm job usage and simple cost estimates.", - ) - parser.add_argument( - "--job-id", - action="append", - type=int, - default=[], - help="Root Slurm job ID to summarize. Repeatable.", - ) - parser.add_argument( - "--job-path", - action="append", - default=[], - help="Path to a slurmpilot job directory or its jobid.json. Repeatable.", - ) - parser.add_argument( - "--rate-metric", - choices=RATE_METRIC_CHOICES, - default="gpu_hours", - help="Metric used for the optional hourly rate conversion.", - ) - parser.add_argument( - "--hourly-rate", - type=float, - default=None, - help="Optional hourly rate applied to --rate-metric.", - ) - parser.add_argument( - "--currency", - default="EUR", - help="Currency label for the optional cost estimate.", - ) - parser.add_argument( - "--json", - action="store_true", - help="Print JSON instead of a text table.", - ) - return parser - - -def main(argv: list[str] | None = None) -> int: - args = build_parser().parse_args(argv) - sources = resolve_job_sources(job_ids=args.job_id, job_paths=args.job_path) - if not sources: - raise SystemExit("Provide at least one --job-id or --job-path.") - allocations = query_sacct_allocations(source.job_id for source in sources) - summaries = summarize_job_costs( - sources, - allocations, - rate_metric=args.rate_metric, - hourly_rate=args.hourly_rate, - ) - total = total_summary(summaries, hourly_rate=args.hourly_rate) - if args.json: - payload = { - "jobs": [_summary_to_dict(summary) for summary in summaries], - "total": _summary_to_dict(total), - "rate_metric": args.rate_metric if args.hourly_rate is not None else None, - "hourly_rate": args.hourly_rate, - "currency": args.currency if args.hourly_rate is not None else None, - } - print(json.dumps(payload, indent=2, sort_keys=True)) - return 0 - - table_rows = _tabular_rows( - [*summaries, total], - currency=args.currency, - include_cost=args.hourly_rate is not None, - ) - print(render_table(table_rows)) - if args.hourly_rate is None: - print( - "\nNo hourly rate was provided. Pass --hourly-rate together with " - "--rate-metric to convert one of the reported hour metrics into money." - ) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/judgearena/utils.py b/judgearena/utils.py index 1026961..d0a98af 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -4,6 +4,7 @@ import time import warnings from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path import pandas as pd @@ -37,6 +38,129 @@ def _data_root_path() -> Path: ) +@dataclass(frozen=True) +class ReasoningModelDefaults: + reasoning_parser: str + reasoning_config_kwargs: dict[str, str] | None = None + enabled_chat_template_kwargs: dict[str, object] | None = None + disabled_chat_template_kwargs: dict[str, object] | None = None + + +_REASONING_MODEL_DEFAULTS: tuple[ + tuple[tuple[str, ...], ReasoningModelDefaults], ... +] = ( + ( + ("qwen3",), + ReasoningModelDefaults( + reasoning_parser="qwen3", + reasoning_config_kwargs={ + "reasoning_start_str": VLLM_QWEN_REASONING_START_STR, + "reasoning_end_str": VLLM_QWEN_REASONING_END_STR, + }, + disabled_chat_template_kwargs={"enable_thinking": False}, + ), + ), + (("qwq-32b",), ReasoningModelDefaults(reasoning_parser="deepseek_r1")), + ( + ("deepseek-r1", "r1-distill"), + ReasoningModelDefaults(reasoning_parser="deepseek_r1"), + ), + ( + ("deepseek-v3.1",), + ReasoningModelDefaults( + reasoning_parser="deepseek_v3", + enabled_chat_template_kwargs={"thinking": True}, + disabled_chat_template_kwargs={"thinking": False}, + ), + ), + ( + ("ernie-4.5", "ernie4.5"), + ReasoningModelDefaults(reasoning_parser="ernie45"), + ), + (("glm-4.5", "glm4.5"), ReasoningModelDefaults(reasoning_parser="glm45")), + ( + ("holo2",), + ReasoningModelDefaults( + reasoning_parser="holo2", + disabled_chat_template_kwargs={"thinking": False}, + ), + ), + ( + ("hunyuan-a13b",), + ReasoningModelDefaults(reasoning_parser="hunyuan_a13b"), + ), + ( + ("granite-3.2",), + ReasoningModelDefaults( + reasoning_parser="granite", + enabled_chat_template_kwargs={"thinking": True}, + disabled_chat_template_kwargs={"thinking": False}, + ), + ), + ( + ("minimax-m2",), + ReasoningModelDefaults(reasoning_parser="minimax_m2_append_think"), + ), +) + + +def get_reasoning_model_defaults(model_name: str) -> ReasoningModelDefaults | None: + """Return JudgeArena's explicit reasoning defaults for known model families.""" + normalized = model_name.lower() + for markers, defaults in _REASONING_MODEL_DEFAULTS: + if any(marker in normalized for marker in markers): + return defaults + return None + + +def should_default_thinking_token_budget( + model_name: str, vllm_kwargs: dict[str, object] +) -> bool: + """Return True when JudgeArena should auto-apply a thinking-token budget.""" + return ( + get_reasoning_model_defaults(model_name) is not None + or "reasoning_parser" in vllm_kwargs + or "reasoning_config" in vllm_kwargs + ) + + +def _resolve_chat_template_kwargs( + *, + explicit_chat_template_kwargs: dict[str, object] | None, + reasoning_defaults: ReasoningModelDefaults | None, + enable_reasoning: bool, + disable_thinking: bool, +) -> dict[str, object] | None: + chat_template_kwargs = dict(explicit_chat_template_kwargs or {}) + explicit_keys = set(chat_template_kwargs) + + if enable_reasoning and not disable_thinking and reasoning_defaults is not None: + for key, value in ( + reasoning_defaults.enabled_chat_template_kwargs or {} + ).items(): + chat_template_kwargs.setdefault(key, value) + + if disable_thinking: + disabled_defaults = ( + reasoning_defaults.disabled_chat_template_kwargs + if reasoning_defaults is not None + else {"enable_thinking": False} + ) + for key, value in disabled_defaults.items(): + if key not in explicit_keys: + chat_template_kwargs[key] = value + + return chat_template_kwargs or None + + +def _attach_provider_metadata(model_instance: object, provider_name: str) -> object: + try: + model_instance._judgearena_provider = provider_name + except Exception: + pass + return model_instance + + def set_langchain_cache(): set_llm_cache(SQLiteCache(database_path=str(data_root / ".langchain.db"))) @@ -280,8 +404,19 @@ def __init__( self.max_tokens = max_tokens disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) thinking_token_budget = vllm_kwargs.pop("thinking_token_budget", None) - self._chat_template_kwargs = ( - {"enable_thinking": False} if disable_thinking else None + explicit_chat_template_kwargs = vllm_kwargs.pop("chat_template_kwargs", None) + reasoning_defaults = get_reasoning_model_defaults(model) + explicit_reasoning_settings = ( + "reasoning_parser" in vllm_kwargs or "reasoning_config" in vllm_kwargs + ) + enable_reasoning = ( + explicit_reasoning_settings or thinking_token_budget is not None + ) + self._chat_template_kwargs = _resolve_chat_template_kwargs( + explicit_chat_template_kwargs=explicit_chat_template_kwargs, + reasoning_defaults=reasoning_defaults, + enable_reasoning=enable_reasoning, + disable_thinking=disable_thinking, ) # Cap max_model_len to the model's max_position_embeddings so that @@ -314,20 +449,28 @@ def __init__( "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } if thinking_token_budget is not None: - if "qwen3" in model.lower(): - vllm_kwargs.setdefault( - "reasoning_config", - ReasoningConfig( - reasoning_start_str=VLLM_QWEN_REASONING_START_STR, - reasoning_end_str=VLLM_QWEN_REASONING_END_STR, - ), + if reasoning_defaults is None and not explicit_reasoning_settings: + warnings.warn( + f"Model '{model}' is not in JudgeArena's supported reasoning-family map. " + "Ignoring thinking_token_budget unless reasoning_parser or " + "reasoning_config is provided explicitly.", + stacklevel=2, ) - vllm_kwargs.setdefault("reasoning_parser", "qwen3") else: - vllm_kwargs.setdefault("reasoning_config", ReasoningConfig()) - self._sampling_params_kwargs["thinking_token_budget"] = int( - thinking_token_budget - ) + if reasoning_defaults is not None: + vllm_kwargs.setdefault( + "reasoning_parser", reasoning_defaults.reasoning_parser + ) + if reasoning_defaults.reasoning_config_kwargs is not None: + vllm_kwargs.setdefault( + "reasoning_config", + ReasoningConfig( + **reasoning_defaults.reasoning_config_kwargs + ), + ) + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) @@ -340,12 +483,7 @@ def __init__( if chat_template: self.chat_template = chat_template self._use_generate = False - if disable_thinking: - print( - f"ChatVLLM: using explicit chat template with thinking disabled for '{model}'" - ) - else: - print(f"ChatVLLM: using explicit chat template for '{model}'") + print(f"ChatVLLM: using explicit chat template for '{model}'") else: if not getattr(self.tokenizer, "chat_template", None): warnings.warn( @@ -365,12 +503,7 @@ def __init__( else: self.chat_template = None self._use_generate = False - if disable_thinking: - print( - f"ChatVLLM: using tokenizer chat template with thinking disabled for '{model}'" - ) - else: - print(f"ChatVLLM: using tokenizer's chat template for '{model}'") + print(f"ChatVLLM: using tokenizer's chat template for '{model}'") def set_temperature(self, temperature: float) -> None: from vllm import SamplingParams @@ -517,7 +650,7 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): model_provider = model.split("/")[0] if model_provider == "Dummy": - return DummyModel(model) + return _attach_provider_metadata(DummyModel(model), model_provider) model_name = "/".join(model.split("/")[1:]) print(f"Loading {model_provider}(model={model_name})") @@ -527,18 +660,24 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): engine_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} engine_kwargs["chat_template"] = engine_kwargs.get("chat_template", None) - return ChatVLLM( - model=model_name, - **engine_kwargs, + return _attach_provider_metadata( + ChatVLLM( + model=model_name, + **engine_kwargs, + ), + model_provider, ) if model_provider == "OpenRouter": # Special case we need to override API url and key - return ChatOpenAI( - api_key=os.getenv("OPENROUTER_API_KEY"), - base_url="https://openrouter.ai/api/v1", - model=model_name, - **engine_kwargs, + return _attach_provider_metadata( + ChatOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1", + model=model_name, + **engine_kwargs, + ), + model_provider, ) else: model_classes = [ @@ -566,17 +705,24 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): assert model_provider in model_cls_dict, ( f"{model_provider} not available, choose among {list(model_cls_dict.keys())}" ) - return model_cls_dict[model_provider](**engine_kwargs) + return _attach_provider_metadata( + model_cls_dict[model_provider](**engine_kwargs), model_provider + ) def infer_model_spec_from_instance(model: object) -> str | None: if isinstance(model, DummyModel): return model.name + provider_name = getattr(model, "_judgearena_provider", None) model_path = getattr(model, "model_path", None) if isinstance(model_path, str): - return f"VLLM/{model_path}" + if isinstance(provider_name, str): + return f"{provider_name}/{model_path}" + return model_path model_name = getattr(model, "model_name", None) or getattr(model, "model", None) if isinstance(model_name, str): + if isinstance(provider_name, str): + return f"{provider_name}/{model_name}" return f"{model.__class__.__name__}/{model_name}" return None diff --git a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py new file mode 100644 index 0000000..86704c4 --- /dev/null +++ b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py @@ -0,0 +1,65 @@ +from pathlib import Path + +from slurmpilot import JobCreationInfo, SlurmPilot, unify + +CLUSTER = "kislurm" +REMOTE_PROJECT_ROOT = Path("/work/dlclarge1/lushtake-hiwi/JudgeArena") +LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent +PYTHON_BINARY = REMOTE_PROJECT_ROOT / ".venv" / "bin" / "python" +ENTRYPOINT = "generate_and_evaluate.py" +SRC_DIR = str(LOCAL_PROJECT_ROOT / "judgearena") + +# Use L40S partitions from the all_dlc / ml_dlc families. +PARTITION_ALL_DLC_L40S = "testdlc2_gpu-l40s" +PARTITION_ML_DLC_L40S = "mldlc2_gpu-l40s" + +# Same weights as `VLLM/Qwen/Qwen3.5-27B-FP8`; repo-id loading fails offline in vLLM +# without a resolved revision — point at the HF hub snapshot dir under `HF_HOME`. +QWEN35_27B_FP8_SNAPSHOT = ( + "/work/dlclarge1/lushtake-hiwi/.cache/huggingface/hub/" + "models--Qwen--Qwen3.5-27B-FP8/snapshots/" + "2e1b21350ce589fcaafbb3c7d7eac526a7aed582" +) +JUDGE_MODEL = f"VLLM//{QWEN35_27B_FP8_SNAPSHOT.lstrip('/')}" + + +def submit_smoke_job(partition: str = PARTITION_ALL_DLC_L40S) -> tuple[str, str, int]: + slurm = SlurmPilot(clusters=[CLUSTER]) + dataset = "alpaca-eval" + jobname = unify("qwen3.5-smoke/judgearena-canonical", method="date") + + job_info = JobCreationInfo( + cluster=CLUSTER, + partition=partition, + jobname=jobname, + entrypoint=ENTRYPOINT, + python_binary=str(PYTHON_BINARY), + python_args={ + "dataset": dataset, + "model_A": "Dummy/no_answer", + "model_B": "Dummy/open_answer", + "judge_model": JUDGE_MODEL, + "n_instructions": 1, + "max_out_tokens_judge": 64, + }, + src_dir=SRC_DIR, + n_cpus=1, + max_runtime_minutes=20, + env={ + "HF_HUB_OFFLINE": "1", + # Ensure Hugging Face uses the shared cache location that + # already contains the Qwen3.5 FP8 checkpoint. + "HF_HOME": "/work/dlclarge1/lushtake-hiwi", + "JUDGEARENA_DATA": "/work/dlclarge1/lushtake-hiwi/judgearena-data", + }, + ) + job_id = slurm.schedule_job(job_info) + print(f"Submitted {dataset}: jobname={job_info.jobname}, job_id={job_id}") + return dataset, job_info.jobname, job_id + + +if __name__ == "__main__": + print(f"Using LOCAL_PROJECT_ROOT={LOCAL_PROJECT_ROOT}") + print(f"Using REMOTE_PROJECT_ROOT={REMOTE_PROJECT_ROOT}") + print(f"Using PYTHON_BINARY={PYTHON_BINARY}") + submit_smoke_job(partition=PARTITION_ALL_DLC_L40S) diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index 7407087..6621ee3 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -1,6 +1,8 @@ import sys from types import SimpleNamespace +import pytest + import judgearena.utils as utils @@ -83,3 +85,45 @@ def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch) assert captured["chat_call"]["kwargs"]["chat_template_kwargs"] == { "enable_thinking": False } + + +def test_chat_vllm_enables_family_specific_chat_template_kwargs(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + chat_model = utils.ChatVLLM( + model="deepseek-ai/DeepSeek-V3.1", + max_tokens=16, + thinking_token_budget=32, + gpu_memory_utilization=0.7, + ) + + outputs = chat_model.batch(["hello"]) + + assert outputs == ["ok"] + assert captured["sampling_kwargs"]["thinking_token_budget"] == 32 + assert captured["llm_init"]["kwargs"]["reasoning_parser"] == "deepseek_v3" + assert captured["chat_call"]["kwargs"]["chat_template_kwargs"] == {"thinking": True} + + +def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + with pytest.warns(UserWarning, match="supported reasoning-family map"): + utils.ChatVLLM( + model="meta-llama/Llama-3.3-70B-Instruct", + max_tokens=32, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert "thinking_token_budget" not in captured["sampling_kwargs"] + assert "reasoning_parser" not in captured["llm_init"]["kwargs"] + assert "reasoning_config" not in captured["llm_init"]["kwargs"] + + +def test_infer_model_spec_uses_attached_provider_name(): + model = SimpleNamespace( + _judgearena_provider="LlamaCpp", + model_path="./models/model.gguf", + ) + + assert utils.infer_model_spec_from_instance(model) == "LlamaCpp/./models/model.gguf" diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index c60af40..91f87dc 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -27,85 +27,67 @@ def test_load_judge_prompt_with_explanation_uses_freeform_scores(): assert "score_B:" in user_prompt -def test_main_aligns_local_reference_by_instruction_index(tmp_path, monkeypatch): +def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): instructions = pd.DataFrame( - {"instruction": ["Instruction B", "Instruction A"]}, - index=pd.Index(["b", "a"], name="instruction_index"), + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), ) - reference_path = tmp_path / "m-arena-hard-en-reference.csv" - pd.DataFrame( - { - "instruction_index": ["a", "b"], - "output": ["Answer A", "Answer B"], - } - ).to_csv(reference_path, index=False) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} monkeypatch.setattr( generate_and_evaluate, "load_instructions", - lambda dataset, n_instructions=None: ( - instructions.head(n_instructions) - if n_instructions is not None - else instructions - ), + lambda dataset, n_instructions=None: instructions, ) monkeypatch.setattr( generate_and_evaluate, - "cache_function_dataframe", - lambda fun, **_kwargs: fun(), + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, ) - captured = {} - - def fake_judge_and_parse_prefs( - *, - judge_chat_model, - instructions, - completions_A, - completions_B, - swap_mode, - provide_explanation, - system_prompt, - user_prompt_template, - truncate_input_chars, - use_tqdm, - usage_tracker, - usage_phase, - usage_model_spec, - ): - captured["instructions"] = instructions - captured["completions_A"] = completions_A - captured["completions_B"] = completions_B - annotations = [{"judge_completion": "score A: 0 score B: 10"}] * len( - instructions - ) - prefs = pd.Series([1.0] * len(instructions)) - return annotations, [], prefs + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = { + "model": model, + "max_tokens": max_tokens, + "max_model_len": max_model_len, + "chat_template": chat_template, + "kwargs": kwargs, + } + return object() + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) monkeypatch.setattr( generate_and_evaluate, "judge_and_parse_prefs", - fake_judge_and_parse_prefs, + lambda **kwargs: ( + [{"judge_completion": "score_A: 1\nscore_B: 2"}], + None, + pd.Series([1.0]), + ), ) prefs = main_generate_and_eval( CliArgs( - dataset="m-arena-hard-en", - model_A="Dummy/no answer", - model_B=str(reference_path), - judge_model="Dummy/score A: 0 score B: 10", - n_instructions=2, + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, result_folder=str(tmp_path / "results"), ) ) - assert captured["instructions"] == ["Instruction B", "Instruction A"] - assert captured["completions_A"] == ["no answer", "no answer"] - assert captured["completions_B"] == ["Answer B", "Answer A"] - assert prefs.tolist() == [1.0, 1.0] + assert prefs.tolist() == [1.0] + assert "structured_outputs_json" not in captured["make_model"]["kwargs"] + assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 -def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): +def test_main_passes_thinking_budget_to_vllm_judge_when_explanation_requested( + tmp_path, monkeypatch +): instructions = pd.DataFrame( {"instruction": ["Instruction A"]}, index=pd.Index([1], name="instruction_index"), @@ -127,13 +109,7 @@ def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): ) def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): - captured["make_model"] = { - "model": model, - "max_tokens": max_tokens, - "max_model_len": max_model_len, - "chat_template": chat_template, - "kwargs": kwargs, - } + captured["make_model"] = kwargs return object() monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) @@ -141,7 +117,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs generate_and_evaluate, "judge_and_parse_prefs", lambda **kwargs: ( - [{"judge_completion": "score_A: 1\nscore_B: 2"}], + [{"judge_completion": "Explanation: ok\nscore_A: 1\nscore_B: 2"}], None, pd.Series([1.0]), ), @@ -154,16 +130,17 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs model_B="Dummy/model-b", judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", n_instructions=1, + provide_explanation=True, result_folder=str(tmp_path / "results"), ) ) assert prefs.tolist() == [1.0] - assert "structured_outputs_json" not in captured["make_model"]["kwargs"] - assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 + assert "structured_outputs_json" not in captured["make_model"] + assert captured["make_model"]["thinking_token_budget"] == 512 -def test_main_passes_thinking_budget_to_vllm_judge_when_explanation_requested( +def test_main_does_not_pass_thinking_budget_to_non_reasoning_vllm_judge( tmp_path, monkeypatch ): instructions = pd.DataFrame( @@ -195,7 +172,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs generate_and_evaluate, "judge_and_parse_prefs", lambda **kwargs: ( - [{"judge_completion": "Explanation: ok\nscore_A: 1\nscore_B: 2"}], + [{"judge_completion": "score_A: 1\nscore_B: 2"}], None, pd.Series([1.0]), ), @@ -206,16 +183,14 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs dataset="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", - judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", n_instructions=1, - provide_explanation=True, result_folder=str(tmp_path / "results"), ) ) assert prefs.tolist() == [1.0] - assert "structured_outputs_json" not in captured["make_model"] - assert captured["make_model"]["thinking_token_budget"] == 512 + assert "thinking_token_budget" not in captured["make_model"] def test_annotate_battles_warns_when_judge_completions_are_truncated( diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 7925070..7f454a7 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -204,7 +204,7 @@ def test_parse_fastchat_verdict_marks_non_bracketed_outputs_as_error(): def test_pair_v2_system_prompt_matches_original_fastchat_contract(): - rendered = fastchat_compat._PAIR_V2.render_system_prompt(provide_explanation=True) + rendered = fastchat_compat._PAIR_V2.system_prompt assert "provide a short explanation" in rendered assert "valid JSON" not in rendered @@ -317,4 +317,4 @@ def fake_run_mt_bench_fastchat(**kwargs): "thinking_token_budget": 512, } assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" - assert captured["run_mt_bench_fastchat"]["provide_explanation"] is True + assert "provide_explanation" not in captured["run_mt_bench_fastchat"] diff --git a/tests/test_slurm_costs.py b/tests/test_slurm_costs.py deleted file mode 100644 index 94f173d..0000000 --- a/tests/test_slurm_costs.py +++ /dev/null @@ -1,77 +0,0 @@ -import json - -import pytest - -from judgearena.slurm_costs import ( - JobSource, - _tres_quantity, - load_job_source_from_path, - parse_elapsed_seconds, - parse_sacct_allocations, - parse_tres_map, - resolve_job_sources, - summarize_job_costs, - total_summary, -) - - -def test_parse_tres_map_preserves_entries_and_extracts_numeric_quantities(): - tres_map = parse_tres_map("billing=2,cpu=2,gres/gpu=1,mem=125G,node=1") - - assert tres_map["mem"] == "125G" - assert _tres_quantity(tres_map, "billing") == 2.0 - assert _tres_quantity(tres_map, "gres/gpu") == 1.0 - - -def test_parse_elapsed_seconds_supports_day_prefix(): - assert parse_elapsed_seconds("1-02:03:04") == 93784 - - -def test_load_job_source_from_path_uses_metadata_jobname(tmp_path): - job_dir = tmp_path / "bench" / "alpaca-eval-2026-04-06-16-25-10" - job_dir.mkdir(parents=True) - (job_dir / "jobid.json").write_text(json.dumps({"jobid": 28707665})) - (job_dir / "metadata.json").write_text( - json.dumps({"jobname": "bench/alpaca-eval-2026-04-06-16-25-10"}) - ) - - source = load_job_source_from_path(job_dir) - resolved = resolve_job_sources(job_ids=[28707665], job_paths=[job_dir]) - - assert source == JobSource( - job_id=28707665, - label="bench/alpaca-eval-2026-04-06-16-25-10", - ) - assert resolved == [source] - - -def test_summarize_job_costs_aggregates_job_arrays_and_rate_conversion(): - sacct_output = "\n".join( - [ - "28707665_0|bench/alpaca-eval|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|60|00:01:00|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", - "28707665_1|bench/alpaca-eval|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|90|00:01:30|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", - "28708344_0|bench/arena-hard|mldlc2_gpu-l40s|ml-dlc2|COMPLETED|120|00:02:00|2|1|billing=2,cpu=2,gres/gpu=1,node=1|billing=2,cpu=2,gres/gpu=1,node=1", - ] - ) - allocations = parse_sacct_allocations(sacct_output) - sources = [ - JobSource(job_id=28707665, label="bench/alpaca-eval"), - JobSource(job_id=28708344, label="bench/arena-hard"), - ] - - summaries = summarize_job_costs( - sources, - allocations, - rate_metric="gpu_hours", - hourly_rate=3.5, - ) - total = total_summary(summaries, hourly_rate=3.5) - - assert summaries[0].allocation_count == 2 - assert summaries[0].gpu_hours == pytest.approx(150 / 3600) - assert summaries[0].billing_hours == pytest.approx(300 / 3600) - assert summaries[0].estimated_cost == pytest.approx((150 / 3600) * 3.5) - assert summaries[1].gpu_hours == pytest.approx(120 / 3600) - assert total.gpu_hours == pytest.approx(270 / 3600) - assert total.cpu_hours == pytest.approx(540 / 3600) - assert total.estimated_cost == pytest.approx((270 / 3600) * 3.5) From c063f3dca27bb485cc2da92fbc0daea1c0a20c37 Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Fri, 17 Apr 2026 14:18:41 +0200 Subject: [PATCH 13/28] delete slurmpilot script --- .../launch_kislurm_qwen35_smoke.py | 65 ------------------- 1 file changed, 65 deletions(-) delete mode 100644 slurmpilot_scripts/launch_kislurm_qwen35_smoke.py diff --git a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py b/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py deleted file mode 100644 index 86704c4..0000000 --- a/slurmpilot_scripts/launch_kislurm_qwen35_smoke.py +++ /dev/null @@ -1,65 +0,0 @@ -from pathlib import Path - -from slurmpilot import JobCreationInfo, SlurmPilot, unify - -CLUSTER = "kislurm" -REMOTE_PROJECT_ROOT = Path("/work/dlclarge1/lushtake-hiwi/JudgeArena") -LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent -PYTHON_BINARY = REMOTE_PROJECT_ROOT / ".venv" / "bin" / "python" -ENTRYPOINT = "generate_and_evaluate.py" -SRC_DIR = str(LOCAL_PROJECT_ROOT / "judgearena") - -# Use L40S partitions from the all_dlc / ml_dlc families. -PARTITION_ALL_DLC_L40S = "testdlc2_gpu-l40s" -PARTITION_ML_DLC_L40S = "mldlc2_gpu-l40s" - -# Same weights as `VLLM/Qwen/Qwen3.5-27B-FP8`; repo-id loading fails offline in vLLM -# without a resolved revision — point at the HF hub snapshot dir under `HF_HOME`. -QWEN35_27B_FP8_SNAPSHOT = ( - "/work/dlclarge1/lushtake-hiwi/.cache/huggingface/hub/" - "models--Qwen--Qwen3.5-27B-FP8/snapshots/" - "2e1b21350ce589fcaafbb3c7d7eac526a7aed582" -) -JUDGE_MODEL = f"VLLM//{QWEN35_27B_FP8_SNAPSHOT.lstrip('/')}" - - -def submit_smoke_job(partition: str = PARTITION_ALL_DLC_L40S) -> tuple[str, str, int]: - slurm = SlurmPilot(clusters=[CLUSTER]) - dataset = "alpaca-eval" - jobname = unify("qwen3.5-smoke/judgearena-canonical", method="date") - - job_info = JobCreationInfo( - cluster=CLUSTER, - partition=partition, - jobname=jobname, - entrypoint=ENTRYPOINT, - python_binary=str(PYTHON_BINARY), - python_args={ - "dataset": dataset, - "model_A": "Dummy/no_answer", - "model_B": "Dummy/open_answer", - "judge_model": JUDGE_MODEL, - "n_instructions": 1, - "max_out_tokens_judge": 64, - }, - src_dir=SRC_DIR, - n_cpus=1, - max_runtime_minutes=20, - env={ - "HF_HUB_OFFLINE": "1", - # Ensure Hugging Face uses the shared cache location that - # already contains the Qwen3.5 FP8 checkpoint. - "HF_HOME": "/work/dlclarge1/lushtake-hiwi", - "JUDGEARENA_DATA": "/work/dlclarge1/lushtake-hiwi/judgearena-data", - }, - ) - job_id = slurm.schedule_job(job_info) - print(f"Submitted {dataset}: jobname={job_info.jobname}, job_id={job_id}") - return dataset, job_info.jobname, job_id - - -if __name__ == "__main__": - print(f"Using LOCAL_PROJECT_ROOT={LOCAL_PROJECT_ROOT}") - print(f"Using REMOTE_PROJECT_ROOT={REMOTE_PROJECT_ROOT}") - print(f"Using PYTHON_BINARY={PYTHON_BINARY}") - submit_smoke_job(partition=PARTITION_ALL_DLC_L40S) From ec7fc95d08d2ea601ba5c78aa118b9184d45970c Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Fri, 17 Apr 2026 14:24:03 +0200 Subject: [PATCH 14/28] Revert comment removal --- judgearena/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/judgearena/utils.py b/judgearena/utils.py index d0a98af..6471990 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -501,7 +501,7 @@ def __init__( stacklevel=2, ) else: - self.chat_template = None + self.chat_template = None # let vLLM use the tokenizer's own self._use_generate = False print(f"ChatVLLM: using tokenizer's chat template for '{model}'") From 20ca9a54b029aa27f328ca418f2294d18d4021df Mon Sep 17 00:00:00 2001 From: Erlis Lushtaku <59629249+ErlisLushtaku@users.noreply.github.com> Date: Fri, 17 Apr 2026 16:19:01 +0200 Subject: [PATCH 15/28] simplify and revert unnecessary changes --- judgearena/generate_and_evaluate.py | 18 +- judgearena/mt_bench/fastchat_compat.py | 5 - judgearena/mt_bench/mt_bench_utils.py | 18 +- judgearena/prompts/mt_bench/system-base.txt | 2 +- judgearena/utils.py | 218 ++++++-------------- tests/test_chat_vllm.py | 61 ++++-- tests/test_local_completion_loading.py | 58 ++++++ tests/test_mt_bench_downloads.py | 3 + 8 files changed, 176 insertions(+), 207 deletions(-) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 16aca13..a392a09 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -28,14 +28,13 @@ ) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( - DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, + build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, data_root, download_hf, make_model, read_df, - should_default_thinking_token_budget, ) @@ -279,18 +278,9 @@ def main(args: CliArgs): print(completions_B.values[0]) print(f"Evaluating completions with judge {args.judge_model}.") - judge_model_kwargs = dict(args.engine_kwargs) - if args.judge_model.split("/")[0] == "VLLM": - judge_model_name = "/".join(args.judge_model.split("/")[1:]) - if ( - "thinking_token_budget" not in judge_model_kwargs - and should_default_thinking_token_budget( - judge_model_name, judge_model_kwargs - ) - ): - judge_model_kwargs["thinking_token_budget"] = ( - DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET - ) + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, args.engine_kwargs + ) judge_chat_model = make_model( model=args.judge_model, diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 5bf23f5..254d108 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -77,11 +77,6 @@ def _build_system_prompt( task_description=task_description, focus_line=focus_segment, begin_instruction=begin_instruction, - output_format_instruction=( - "After providing your explanation, output your final verdict by strictly following this format: " - '"[[A]]" if assistant A is better, "[[B]]" if assistant B is better, ' - 'and "[[C]]" for a tie.' - ), ) diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index e86310c..106812c 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -30,11 +30,10 @@ ) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( - DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET, + build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, make_model, - should_default_thinking_token_budget, ) if TYPE_CHECKING: @@ -269,18 +268,9 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN} for the judge." ) args.max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN - judge_model_kwargs = dict(args.engine_kwargs) - if args.judge_model.split("/")[0] == "VLLM": - judge_model_name = "/".join(args.judge_model.split("/")[1:]) - if ( - "thinking_token_budget" not in judge_model_kwargs - and should_default_thinking_token_budget( - judge_model_name, judge_model_kwargs - ) - ): - judge_model_kwargs["thinking_token_budget"] = ( - DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET - ) + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, args.engine_kwargs + ) judge_chat_model = make_model( model=args.judge_model, diff --git a/judgearena/prompts/mt_bench/system-base.txt b/judgearena/prompts/mt_bench/system-base.txt index 7cf5120..b4aff2e 100644 --- a/judgearena/prompts/mt_bench/system-base.txt +++ b/judgearena/prompts/mt_bench/system-base.txt @@ -1 +1 @@ -Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. {output_format_instruction} +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user {user_subject}. {task_description} {focus_line}Begin your evaluation by {begin_instruction}. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie. diff --git a/judgearena/utils.py b/judgearena/utils.py index 6471990..f544a4d 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -4,7 +4,6 @@ import time import warnings from collections.abc import Callable -from dataclasses import dataclass from pathlib import Path import pandas as pd @@ -38,129 +37,46 @@ def _data_root_path() -> Path: ) -@dataclass(frozen=True) -class ReasoningModelDefaults: - reasoning_parser: str - reasoning_config_kwargs: dict[str, str] | None = None - enabled_chat_template_kwargs: dict[str, object] | None = None - disabled_chat_template_kwargs: dict[str, object] | None = None - - -_REASONING_MODEL_DEFAULTS: tuple[ - tuple[tuple[str, ...], ReasoningModelDefaults], ... -] = ( - ( - ("qwen3",), - ReasoningModelDefaults( - reasoning_parser="qwen3", - reasoning_config_kwargs={ - "reasoning_start_str": VLLM_QWEN_REASONING_START_STR, - "reasoning_end_str": VLLM_QWEN_REASONING_END_STR, - }, - disabled_chat_template_kwargs={"enable_thinking": False}, - ), - ), - (("qwq-32b",), ReasoningModelDefaults(reasoning_parser="deepseek_r1")), - ( - ("deepseek-r1", "r1-distill"), - ReasoningModelDefaults(reasoning_parser="deepseek_r1"), - ), - ( - ("deepseek-v3.1",), - ReasoningModelDefaults( - reasoning_parser="deepseek_v3", - enabled_chat_template_kwargs={"thinking": True}, - disabled_chat_template_kwargs={"thinking": False}, - ), - ), - ( - ("ernie-4.5", "ernie4.5"), - ReasoningModelDefaults(reasoning_parser="ernie45"), - ), - (("glm-4.5", "glm4.5"), ReasoningModelDefaults(reasoning_parser="glm45")), - ( - ("holo2",), - ReasoningModelDefaults( - reasoning_parser="holo2", - disabled_chat_template_kwargs={"thinking": False}, - ), - ), - ( - ("hunyuan-a13b",), - ReasoningModelDefaults(reasoning_parser="hunyuan_a13b"), - ), - ( - ("granite-3.2",), - ReasoningModelDefaults( - reasoning_parser="granite", - enabled_chat_template_kwargs={"thinking": True}, - disabled_chat_template_kwargs={"thinking": False}, - ), - ), - ( - ("minimax-m2",), - ReasoningModelDefaults(reasoning_parser="minimax_m2_append_think"), - ), -) +def _split_model_spec(model_spec: str) -> tuple[str, str]: + provider, sep, model_name = model_spec.partition("/") + if not sep: + return model_spec, "" + return provider, model_name -def get_reasoning_model_defaults(model_name: str) -> ReasoningModelDefaults | None: - """Return JudgeArena's explicit reasoning defaults for known model families.""" - normalized = model_name.lower() - for markers, defaults in _REASONING_MODEL_DEFAULTS: - if any(marker in normalized for marker in markers): - return defaults - return None +def is_qwen_reasoning_model(model_name: str) -> bool: + return "qwen3" in model_name.lower() -def should_default_thinking_token_budget( - model_name: str, vllm_kwargs: dict[str, object] -) -> bool: - """Return True when JudgeArena should auto-apply a thinking-token budget.""" - return ( - get_reasoning_model_defaults(model_name) is not None - or "reasoning_parser" in vllm_kwargs - or "reasoning_config" in vllm_kwargs - ) +def build_default_judge_model_kwargs( + judge_model: str, engine_kwargs: dict[str, object] +) -> dict[str, object]: + """Copy judge engine kwargs and add supported built-in defaults.""" + judge_model_kwargs = dict(engine_kwargs) + provider, model_name = _split_model_spec(judge_model) + if ( + provider == "VLLM" + and "thinking_token_budget" not in judge_model_kwargs + and is_qwen_reasoning_model(model_name) + ): + judge_model_kwargs["thinking_token_budget"] = ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) + return judge_model_kwargs def _resolve_chat_template_kwargs( *, explicit_chat_template_kwargs: dict[str, object] | None, - reasoning_defaults: ReasoningModelDefaults | None, - enable_reasoning: bool, disable_thinking: bool, ) -> dict[str, object] | None: chat_template_kwargs = dict(explicit_chat_template_kwargs or {}) - explicit_keys = set(chat_template_kwargs) - - if enable_reasoning and not disable_thinking and reasoning_defaults is not None: - for key, value in ( - reasoning_defaults.enabled_chat_template_kwargs or {} - ).items(): - chat_template_kwargs.setdefault(key, value) - - if disable_thinking: - disabled_defaults = ( - reasoning_defaults.disabled_chat_template_kwargs - if reasoning_defaults is not None - else {"enable_thinking": False} - ) - for key, value in disabled_defaults.items(): - if key not in explicit_keys: - chat_template_kwargs[key] = value + if disable_thinking and "enable_thinking" not in chat_template_kwargs: + chat_template_kwargs["enable_thinking"] = False return chat_template_kwargs or None -def _attach_provider_metadata(model_instance: object, provider_name: str) -> object: - try: - model_instance._judgearena_provider = provider_name - except Exception: - pass - return model_instance - - def set_langchain_cache(): set_llm_cache(SQLiteCache(database_path=str(data_root / ".langchain.db"))) @@ -405,17 +321,11 @@ def __init__( disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) thinking_token_budget = vllm_kwargs.pop("thinking_token_budget", None) explicit_chat_template_kwargs = vllm_kwargs.pop("chat_template_kwargs", None) - reasoning_defaults = get_reasoning_model_defaults(model) explicit_reasoning_settings = ( "reasoning_parser" in vllm_kwargs or "reasoning_config" in vllm_kwargs ) - enable_reasoning = ( - explicit_reasoning_settings or thinking_token_budget is not None - ) self._chat_template_kwargs = _resolve_chat_template_kwargs( explicit_chat_template_kwargs=explicit_chat_template_kwargs, - reasoning_defaults=reasoning_defaults, - enable_reasoning=enable_reasoning, disable_thinking=disable_thinking, ) @@ -449,28 +359,29 @@ def __init__( "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } if thinking_token_budget is not None: - if reasoning_defaults is None and not explicit_reasoning_settings: + if explicit_reasoning_settings: + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) + elif is_qwen_reasoning_model(model): + vllm_kwargs.setdefault( + "reasoning_config", + ReasoningConfig( + reasoning_start_str=VLLM_QWEN_REASONING_START_STR, + reasoning_end_str=VLLM_QWEN_REASONING_END_STR, + ), + ) + vllm_kwargs.setdefault("reasoning_parser", "qwen3") + self._sampling_params_kwargs["thinking_token_budget"] = int( + thinking_token_budget + ) + else: warnings.warn( - f"Model '{model}' is not in JudgeArena's supported reasoning-family map. " + f"Model '{model}' is not in JudgeArena's built-in Qwen reasoning defaults. " "Ignoring thinking_token_budget unless reasoning_parser or " "reasoning_config is provided explicitly.", stacklevel=2, ) - else: - if reasoning_defaults is not None: - vllm_kwargs.setdefault( - "reasoning_parser", reasoning_defaults.reasoning_parser - ) - if reasoning_defaults.reasoning_config_kwargs is not None: - vllm_kwargs.setdefault( - "reasoning_config", - ReasoningConfig( - **reasoning_defaults.reasoning_config_kwargs - ), - ) - self._sampling_params_kwargs["thinking_token_budget"] = int( - thinking_token_budget - ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) @@ -647,12 +558,11 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): # Dedicated arguments like max_tokens always win over engine_kwargs. engine_kwargs["max_tokens"] = max_tokens or 8192 - model_provider = model.split("/")[0] + model_provider, model_name = _split_model_spec(model) if model_provider == "Dummy": - return _attach_provider_metadata(DummyModel(model), model_provider) + return DummyModel(model) - model_name = "/".join(model.split("/")[1:]) print(f"Loading {model_provider}(model={model_name})") # Use our custom ChatVLLM wrapper which properly applies chat templates @@ -660,24 +570,18 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): engine_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} engine_kwargs["chat_template"] = engine_kwargs.get("chat_template", None) - return _attach_provider_metadata( - ChatVLLM( - model=model_name, - **engine_kwargs, - ), - model_provider, + return ChatVLLM( + model=model_name, + **engine_kwargs, ) if model_provider == "OpenRouter": # Special case we need to override API url and key - return _attach_provider_metadata( - ChatOpenAI( - api_key=os.getenv("OPENROUTER_API_KEY"), - base_url="https://openrouter.ai/api/v1", - model=model_name, - **engine_kwargs, - ), - model_provider, + return ChatOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1", + model=model_name, + **engine_kwargs, ) else: model_classes = [ @@ -705,24 +609,20 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): assert model_provider in model_cls_dict, ( f"{model_provider} not available, choose among {list(model_cls_dict.keys())}" ) - return _attach_provider_metadata( - model_cls_dict[model_provider](**engine_kwargs), model_provider - ) + return model_cls_dict[model_provider](**engine_kwargs) def infer_model_spec_from_instance(model: object) -> str | None: if isinstance(model, DummyModel): return model.name - provider_name = getattr(model, "_judgearena_provider", None) - model_path = getattr(model, "model_path", None) - if isinstance(model_path, str): - if isinstance(provider_name, str): - return f"{provider_name}/{model_path}" - return model_path + if isinstance(model, ChatVLLM): + return f"VLLM/{model.model_path}" + if isinstance(model, LlamaCpp): + model_path = getattr(model, "model_path", None) + if isinstance(model_path, str): + return f"LlamaCpp/{model_path}" model_name = getattr(model, "model_name", None) or getattr(model, "model", None) if isinstance(model_name, str): - if isinstance(provider_name, str): - return f"{provider_name}/{model_name}" return f"{model.__class__.__name__}/{model_name}" return None diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index 6621ee3..f2f23f0 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -87,27 +87,51 @@ def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch) } -def test_chat_vllm_enables_family_specific_chat_template_kwargs(monkeypatch): +def test_build_default_judge_model_kwargs_only_defaults_qwen_judges(): + assert utils.build_default_judge_model_kwargs( + "VLLM/Qwen/Qwen3.5-27B-FP8", + {"gpu_memory_utilization": 0.7}, + ) == { + "gpu_memory_utilization": 0.7, + "thinking_token_budget": 512, + } + assert utils.build_default_judge_model_kwargs( + "VLLM/meta-llama/Llama-3.3-70B-Instruct", + {"gpu_memory_utilization": 0.7}, + ) == {"gpu_memory_utilization": 0.7} + assert ( + utils.build_default_judge_model_kwargs( + "OpenRouter/qwen/qwen3-32b", + {}, + ) + == {} + ) + + +def test_chat_vllm_preserves_explicit_reasoning_settings_for_non_qwen(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) - chat_model = utils.ChatVLLM( - model="deepseek-ai/DeepSeek-V3.1", + explicit_reasoning_config = object() + + utils.ChatVLLM( + model="meta-llama/Llama-3.3-70B-Instruct", max_tokens=16, thinking_token_budget=32, + reasoning_parser="custom-parser", + reasoning_config=explicit_reasoning_config, gpu_memory_utilization=0.7, ) - outputs = chat_model.batch(["hello"]) - - assert outputs == ["ok"] assert captured["sampling_kwargs"]["thinking_token_budget"] == 32 - assert captured["llm_init"]["kwargs"]["reasoning_parser"] == "deepseek_v3" - assert captured["chat_call"]["kwargs"]["chat_template_kwargs"] == {"thinking": True} + assert captured["llm_init"]["kwargs"]["reasoning_parser"] == "custom-parser" + assert ( + captured["llm_init"]["kwargs"]["reasoning_config"] is explicit_reasoning_config + ) def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) - with pytest.warns(UserWarning, match="supported reasoning-family map"): + with pytest.warns(UserWarning, match="built-in Qwen reasoning defaults"): utils.ChatVLLM( model="meta-llama/Llama-3.3-70B-Instruct", max_tokens=32, @@ -120,10 +144,19 @@ def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): assert "reasoning_config" not in captured["llm_init"]["kwargs"] -def test_infer_model_spec_uses_attached_provider_name(): - model = SimpleNamespace( - _judgearena_provider="LlamaCpp", - model_path="./models/model.gguf", - ) +def test_infer_model_spec_uses_type_based_vllm_fallback(): + model = object.__new__(utils.ChatVLLM) + model.model_path = "Qwen/Qwen3.5-27B-FP8" + + assert utils.infer_model_spec_from_instance(model) == "VLLM/Qwen/Qwen3.5-27B-FP8" + + +def test_infer_model_spec_uses_type_based_llamacpp_fallback(monkeypatch): + class FakeLlamaCpp: + def __init__(self, model_path: str): + self.model_path = model_path + + monkeypatch.setattr(utils, "LlamaCpp", FakeLlamaCpp) + model = FakeLlamaCpp("./models/model.gguf") assert utils.infer_model_spec_from_instance(model) == "LlamaCpp/./models/model.gguf" diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index 91f87dc..d3641b5 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -193,6 +193,64 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs assert "thinking_token_budget" not in captured["make_model"] +def test_main_preserves_explicit_reasoning_engine_kwargs_for_non_qwen_vllm_judge( + tmp_path, monkeypatch +): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + completions_df = pd.DataFrame( + {"instruction_index": [1], "completion": ["Loaded answer"]} + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: completions_df, + ) + + def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): + captured["make_model"] = kwargs + return object() + + monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + lambda **kwargs: ( + [{"judge_completion": "score_A: 1\nscore_B: 2"}], + None, + pd.Series([1.0]), + ), + ) + + prefs = main_generate_and_eval( + CliArgs( + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", + n_instructions=1, + result_folder=str(tmp_path / "results"), + engine_kwargs={ + "reasoning_parser": "custom-parser", + "thinking_token_budget": 2048, + }, + ) + ) + + assert prefs.tolist() == [1.0] + assert captured["make_model"]["reasoning_parser"] == "custom-parser" + assert captured["make_model"]["thinking_token_budget"] == 2048 + + def test_annotate_battles_warns_when_judge_completions_are_truncated( monkeypatch, capsys ): diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 7f454a7..c6349be 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -211,6 +211,9 @@ def test_pair_v2_system_prompt_matches_original_fastchat_contract(): assert '"[[A]]"' in rendered assert '"[[B]]"' in rendered assert '"[[C]]"' in rendered + assert rendered.endswith( + 'After providing your explanation, output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie.\n' + ) def test_conservative_winner_marks_one_sided_parse_failures_as_error(): From 217dc8d144e1ad43a2f786354e73c223c60b9347 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Sat, 18 Apr 2026 00:05:42 +0200 Subject: [PATCH 16/28] Support Skywork --- judgearena/cli_common.py | 58 +++- judgearena/evaluate.py | 255 +++++++++----- judgearena/generate.py | 267 ++++++++++++--- judgearena/generate_and_evaluate.py | 179 ++++++---- judgearena/judge_prompt_presets.py | 104 ++++++ judgearena/mt_bench/common.py | 78 +++-- judgearena/mt_bench/fastchat_compat.py | 247 +++++++++++++- judgearena/mt_bench/mt_bench_utils.py | 84 ++++- judgearena/prompts/prompt.txt | 8 +- .../skywork-prompt-with-explanation.txt | 14 + judgearena/prompts/skywork-prompt.txt | 14 + judgearena/utils.py | 236 ++++++++++++- tests/test_chat_vllm.py | 17 +- tests/test_generate_and_evaluate.py | 22 ++ tests/test_local_completion_loading.py | 318 ++++++++++++++---- tests/test_mt_bench_downloads.py | 98 +++++- tests/test_regexp.py | 38 ++- 17 files changed, 1717 insertions(+), 320 deletions(-) create mode 100644 judgearena/judge_prompt_presets.py create mode 100644 judgearena/prompts/skywork-prompt-with-explanation.txt create mode 100644 judgearena/prompts/skywork-prompt.txt diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index cbaa696..118464d 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -11,6 +11,8 @@ import json from dataclasses import dataclass, field +from judgearena.judge_prompt_presets import JUDGE_PROMPT_PRESETS + @dataclass class BaseCliArgs: @@ -22,6 +24,9 @@ class BaseCliArgs: provide_explanation: bool = False swap_mode: str = "fixed" ignore_cache: bool = False + judge_prompt_preset: str = "default" + battle_thinking_token_budget: int | None = None + strip_thinking_before_judging: bool = False truncate_all_input_chars: int = 8192 max_out_tokens_models: int = 32768 max_out_tokens_judge: int = 32768 @@ -37,6 +42,17 @@ def __post_init__(self): ) +def parse_optional_bool(raw: str | None) -> bool: + if raw is None: + return True + normalized = raw.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{raw}'.") + + def add_common_arguments(parser: argparse.ArgumentParser) -> None: """Register the CLI flags shared by all judgearena entrypoints.""" parser.add_argument( @@ -58,7 +74,10 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: ) parser.add_argument( "--provide_explanation", - action="store_true", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, help=( "If specified, judge will provide explanation before making a " "judgement. Does not necessarily improve the accuracy of the judge " @@ -79,9 +98,44 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: ) parser.add_argument( "--ignore_cache", - action="store_true", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, help="If specified, ignore cache of previous completions.", ) + parser.add_argument( + "--judge_prompt_preset", + type=str, + choices=JUDGE_PROMPT_PRESETS, + default="default", + help=( + "Judge prompt preset to use. 'default' preserves the existing score-first " + "JudgeArena prompts, while 'skywork' enables an optional Skywork-style " + "verdict-first preset." + ), + ) + parser.add_argument( + "--battle_thinking_token_budget", + type=int, + required=False, + default=None, + help=( + "Optional reasoning-token sub-budget for battle-model generation. " + "This stays inside --max_out_tokens_models." + ), + ) + parser.add_argument( + "--strip_thinking_before_judging", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, + help=( + "If specified, strip visible reasoning traces from model completions " + "before sending them to the judge." + ), + ) parser.add_argument( "--result_folder", type=str, diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 395aa13..0063623 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -14,6 +14,11 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + ResolvedJudgePrompt, + resolve_pairwise_judge_prompt, +) from judgearena.openrouter_reference_pricing import ( OpenRouterReferencePricingTracker, build_openrouter_reference_pricing_summary, @@ -21,6 +26,7 @@ ) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + LimitEventTracker, compute_pref_summary, data_root, do_inference, @@ -28,14 +34,16 @@ infer_model_spec_from_instance, read_df, strip_thinking_tags, - truncate, + strip_thinking_tags_with_metadata, + truncate_with_metadata, ) class PairScore: - def __init__(self): + def __init__(self, *, parser_mode: str = "score"): super(PairScore).__init__() self.temperature = 0.3 + self.parser_mode = parser_mode def preference_from_scores(self, score_a: float, score_b: float) -> float: return 1 - np.exp(self.temperature * score_a) / ( @@ -44,25 +52,30 @@ def preference_from_scores(self, score_a: float, score_b: float) -> float: def parse_model_raw(self, judge_completion: str) -> float | None: judge_completion = strip_thinking_tags(judge_completion) - # lower case to avoid confusion, e.g. when "a" is used instead of "A" - score_a = self.get_regexp_match( - judge_completion.lower(), r'score.*?a[": *\n]*(-?\d+)' - ) - score_b = self.get_regexp_match( - judge_completion.lower(), r'score.*?b[": *\n]*(-?\d+)' - ) + if self.parser_mode == "verdict": + return self._parse_bracketed_verdict(judge_completion) + if self.parser_mode == "score": + return self._parse_numeric_scores(judge_completion) + raise ValueError(f"Unsupported parser_mode '{self.parser_mode}'.") + + def _parse_numeric_scores(self, judge_completion: str) -> float | None: + lowered = judge_completion.lower() + score_a = self.get_regexp_match(lowered, r'score.*?a[": *\n]*(-?\d+)') + score_b = self.get_regexp_match(lowered, r'score.*?b[": *\n]*(-?\d+)') if score_a is None or score_b is None: - verdict_match = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion) - if verdict_match is None: - return None - bracketed_verdict = verdict_match.group(1).lower() - return { - "a": 0.0, - "b": 1.0, - "c": 0.5, - }[bracketed_verdict] - else: - return float(self.preference_from_scores(score_a, score_b)) + return None + return float(self.preference_from_scores(score_a, score_b)) + + def _parse_bracketed_verdict(self, judge_completion: str) -> float | None: + verdict_match = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion) + if verdict_match is None: + return None + bracketed_verdict = verdict_match.group(1).lower() + return { + "a": 0.0, + "b": 1.0, + "c": 0.5, + }[bracketed_verdict] def get_regexp_match(self, s: str, regex: str, group_index: int = 1): m = re.search(re.compile(regex), s) @@ -72,53 +85,32 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): return float(m.group(group_index).strip(" ")) -_COMPLETION_LABEL_SINGLE = "Answer" -_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" -_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" -_SCORE_FENCE = "\n```" - - def load_judge_system_and_user_prompt( provide_explanation: bool = True, multi_turn: bool = False, ) -> tuple[str, str]: - prompts_dir = Path(__file__).parent / "prompts" - system_prompt = (prompts_dir / "system-prompt.txt").read_text() - - prompt_filename = ( - "prompt-with-explanation.txt" if provide_explanation else "prompt.txt" - ) - user_prompt_template = (prompts_dir / prompt_filename).read_text() - user_prompt_template = user_prompt_template.replace( - "{completion_label}", - _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, - ) - user_prompt_template = user_prompt_template.replace( - "{explanation_suffix}", - _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + resolved = resolve_pairwise_judge_prompt( + prompt_preset=DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation=provide_explanation, + multi_turn=multi_turn, ) - return system_prompt, user_prompt_template + return resolved.system_prompt or "", resolved.user_prompt_template def resolve_judge_prompts( *, provide_explanation: bool, multi_turn: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, system_prompt: str | None = None, user_prompt_template: str | None = None, -) -> tuple[str, str]: - default_system_prompt, default_user_prompt_template = ( - load_judge_system_and_user_prompt( - provide_explanation=provide_explanation, multi_turn=multi_turn - ) - ) - return ( - system_prompt if system_prompt is not None else default_system_prompt, - ( - user_prompt_template - if user_prompt_template is not None - else default_user_prompt_template - ), +) -> ResolvedJudgePrompt: + return resolve_pairwise_judge_prompt( + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + multi_turn=multi_turn, + system_prompt=system_prompt, + user_prompt_template=user_prompt_template, ) @@ -131,6 +123,8 @@ def evaluate_completions( use_tqdm: bool = False, truncate_input_chars: int | None = 8192, provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, ): """ :param dataset: @@ -200,40 +194,49 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): judge_chat_model = Together(model="meta-llama/Llama-3.3-70B-Instruct-Turbo") judge_model_spec = infer_model_spec_from_instance(judge_chat_model) usage_tracker = OpenRouterReferencePricingTracker() + limit_event_tracker = LimitEventTracker() unique_string = dataset + "-" + datetime.now().strftime("%Y%m%d_%H%M%S") output_folder = data_root / "judge-evals" / unique_string print(f"Saving results in {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) - ( - judge_system_prompt, - judge_user_prompt_template, - ) = resolve_judge_prompts(provide_explanation=provide_explanation) + resolved_prompt = resolve_judge_prompts( + provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + ) annotations = annotate_battles( judge_chat_model=judge_chat_model, instructions=instructions.tolist(), completions_A=completions_A.loc[instructions.index].tolist(), completions_B=completions_B.loc[instructions.index].tolist(), - system_prompt=judge_system_prompt, - user_prompt_template=judge_user_prompt_template, + case_ids=instructions.index.tolist(), + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=resolved_prompt.user_prompt_template, + prompt_preset=resolved_prompt.preset_name, use_tqdm=use_tqdm, truncate_input_chars=truncate_input_chars, provide_explanation=provide_explanation, + strip_thinking_before_judging=strip_thinking_before_judging, usage_tracker=usage_tracker, usage_phase="judge", usage_model_spec=judge_model_spec, + limit_event_tracker=limit_event_tracker, ) # Pairwise judge results - score_parser = PairScore() + score_parser = PairScore(parser_mode=resolved_prompt.parser_mode) prefs = pd.Series( [ score_parser.parse_model_raw(annotation.judge_completion) for annotation in annotations ] ) - results = {**compute_pref_summary(prefs)} + results = { + **compute_pref_summary(prefs), + "judge_prompt_preset": resolved_prompt.preset_name, + "limit_events": limit_event_tracker.build_summary(), + } pd.DataFrame(annotations).to_csv(output_folder / "annotations.csv", index=False) print(f"{method_A} against {method_B}:\n{results}") @@ -256,6 +259,8 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): "use_tqdm": use_tqdm, "truncate_input_chars": truncate_input_chars, "provide_explanation": provide_explanation, + "judge_prompt_preset": resolved_prompt.preset_name, + "strip_thinking_before_judging": strip_thinking_before_judging, } try: @@ -270,8 +275,8 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str): "completions_A": completions_A.loc[instructions.index].tolist(), "completions_B": completions_B.loc[instructions.index].tolist(), }, - judge_system_prompt=judge_system_prompt, - judge_user_prompt_template=judge_user_prompt_template, + judge_system_prompt=resolved_prompt.system_prompt, + judge_user_prompt_template=resolved_prompt.user_prompt_template, started_at_utc=run_started_at, pricing_reference=pricing_reference, ) @@ -286,6 +291,12 @@ class JudgeAnnotation: completion_B: str # completion of the second model judge_completion: str # output of the judge judge_input: str | None = None # input that was passed to the judge + completion_A_for_judge: str | None = None + completion_B_for_judge: str | None = None + completion_A_reasoning_stripped: bool = False + completion_B_reasoning_stripped: bool = False + completion_A_truncated_for_judge: bool = False + completion_B_truncated_for_judge: bool = False def annotate_battles( @@ -293,14 +304,18 @@ def annotate_battles( instructions: list[str], completions_A: list[str], completions_B: list[str], + case_ids: list[object] | None = None, system_prompt: str | None = None, user_prompt_template: str = None, truncate_input_chars: int | None = 8192, use_tqdm: bool = False, provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, usage_model_spec: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, ) -> list[JudgeAnnotation]: """ Directly evaluate from list of instructions and completions @@ -332,25 +347,79 @@ def annotate_battles( """ # alternatively pass list of tuples assert len(instructions) == len(completions_A) == len(completions_B) + if case_ids is None: + case_ids = [None] * len(instructions) + assert len(case_ids) == len(instructions) - system_prompt, user_prompt_template = resolve_judge_prompts( + resolved_prompt = resolve_judge_prompts( provide_explanation=provide_explanation, + prompt_preset=prompt_preset, system_prompt=system_prompt, user_prompt_template=user_prompt_template, ) + message_templates: list[tuple[str, str]] = [] + if resolved_prompt.system_prompt is not None: + message_templates.append(("system", resolved_prompt.system_prompt)) + message_templates.append(("user", resolved_prompt.user_prompt_template)) - prompt_template = ChatPromptTemplate.from_messages( - [("system", system_prompt), ("user", user_prompt_template)] - ) + prompt_template = ChatPromptTemplate.from_messages(message_templates) truncated_completion_count = 0 input_payloads = [] - for user_prompt, completion_A, completion_B in zip( - instructions, completions_A, completions_B, strict=True + annotation_input_metadata: list[dict[str, object]] = [] + for case_id, user_prompt, completion_A, completion_B in zip( + case_ids, instructions, completions_A, completions_B, strict=True ): - truncated_completion_A = truncate(completion_A, max_len=truncate_input_chars) - truncated_completion_B = truncate(completion_B, max_len=truncate_input_chars) - truncated_completion_count += int(truncated_completion_A != completion_A) - truncated_completion_count += int(truncated_completion_B != completion_B) + raw_completion_A = completion_A if isinstance(completion_A, str) else "" + raw_completion_B = completion_B if isinstance(completion_B, str) else "" + completion_A_for_judge = raw_completion_A + completion_B_for_judge = raw_completion_B + stripped_A = False + stripped_B = False + if strip_thinking_before_judging: + completion_A_for_judge, stripped_A = strip_thinking_tags_with_metadata( + completion_A_for_judge + ) + completion_B_for_judge, stripped_B = strip_thinking_tags_with_metadata( + completion_B_for_judge + ) + if stripped_A and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field="completion_A", + case_id=case_id, + original_length=len(raw_completion_A), + final_length=len(completion_A_for_judge), + ) + if stripped_B and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field="completion_B", + case_id=case_id, + original_length=len(raw_completion_B), + final_length=len(completion_B_for_judge), + ) + truncated_completion_A, truncated_A = truncate_with_metadata( + completion_A_for_judge, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="judge_input_char_truncation", + stage="judge_input", + field="completion_A", + case_id=case_id, + ) + truncated_completion_B, truncated_B = truncate_with_metadata( + completion_B_for_judge, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="judge_input_char_truncation", + stage="judge_input", + field="completion_B", + case_id=case_id, + ) + truncated_completion_count += int(truncated_A) + truncated_completion_count += int(truncated_B) input_payloads.append( { "user_prompt": user_prompt, @@ -358,10 +427,20 @@ def annotate_battles( "completion_B": truncated_completion_B, } ) + annotation_input_metadata.append( + { + "completion_A_for_judge": truncated_completion_A, + "completion_B_for_judge": truncated_completion_B, + "completion_A_reasoning_stripped": stripped_A, + "completion_B_reasoning_stripped": stripped_B, + "completion_A_truncated_for_judge": truncated_A, + "completion_B_truncated_for_judge": truncated_B, + } + ) if truncated_completion_count: print( "Warning: truncated " - f"{truncated_completion_count} judge completions to " + f"{truncated_completion_count} judge inputs to " f"{truncate_input_chars} characters before evaluation." ) inputs = prompt_template.batch(input_payloads) @@ -377,12 +456,20 @@ def annotate_battles( ) annotations = [] - for judge_input, judge_completion, instruction, completion_A, completion_B in zip( + for ( + judge_input, + judge_completion, + instruction, + completion_A, + completion_B, + annotation_input_metadata_row, + ) in zip( inputs, judge_completions, instructions, completions_A, completions_B, + annotation_input_metadata, strict=True, ): annotations.append( @@ -392,6 +479,7 @@ def annotate_battles( instruction=instruction, completion_A=completion_A, completion_B=completion_B, + **annotation_input_metadata_row, ) ) return annotations @@ -402,8 +490,12 @@ def judge_and_parse_prefs( instructions: list[str], completions_A: list[str], completions_B: list[str], + case_ids: list[object] | None = None, swap_mode: str = "fixed", provide_explanation: bool = False, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + parser_mode: str = "score", + strip_thinking_before_judging: bool = False, system_prompt: str | None = None, user_prompt_template: str | None = None, truncate_input_chars: int = 8192, @@ -411,6 +503,7 @@ def judge_and_parse_prefs( usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, usage_model_spec: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, ) -> tuple[list[JudgeAnnotation], list[JudgeAnnotation] | None, pd.Series]: """Run judge annotation and parse preferences, handling swap_mode='both'. @@ -431,7 +524,10 @@ def judge_and_parse_prefs( instructions=instructions, completions_A=completions_A, completions_B=completions_B, + case_ids=case_ids, provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, system_prompt=system_prompt, user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, @@ -439,6 +535,7 @@ def judge_and_parse_prefs( usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=usage_model_spec, + limit_event_tracker=limit_event_tracker, ) annotations_reversed = None @@ -448,7 +545,10 @@ def judge_and_parse_prefs( instructions=instructions, completions_A=completions_B, completions_B=completions_A, + case_ids=case_ids, provide_explanation=provide_explanation, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, system_prompt=system_prompt, user_prompt_template=user_prompt_template, truncate_input_chars=truncate_input_chars, @@ -456,12 +556,13 @@ def judge_and_parse_prefs( usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=usage_model_spec, + limit_event_tracker=limit_event_tracker, ) def _none_to_nan(x): return float("nan") if x is None else x - score_parser = PairScore() + score_parser = PairScore(parser_mode=parser_mode) prefs = pd.Series( [score_parser.parse_model_raw(a.judge_completion) for a in annotations] ) diff --git a/judgearena/generate.py b/judgearena/generate.py index 9720fad..3a1c65b 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -1,13 +1,41 @@ +from typing import Any + import pandas as pd from langchain_core.prompts import ChatPromptTemplate from judgearena.utils import ( + LimitEventTracker, do_inference, make_model, - truncate, + truncate_with_metadata, ) +def _record_generation_output_limit_events( + *, + metadata: list[dict[str, Any]], + case_ids: list[object], + field: str, + model_spec: str, + limit_event_tracker: LimitEventTracker | None, +) -> list[bool]: + hit_token_limit: list[bool] = [] + for case_id, metadata_row in zip(case_ids, metadata, strict=True): + finish_reason = str((metadata_row or {}).get("finish_reason") or "").lower() + reached_limit = finish_reason == "length" + hit_token_limit.append(reached_limit) + if reached_limit and limit_event_tracker is not None: + limit_event_tracker.record( + "generation_output_token_limit", + stage="generation_output", + field=field, + case_id=case_id, + model_spec=model_spec, + note=finish_reason, + ) + return hit_token_limit + + def generate_instructions( instructions: pd.Series, model: str, @@ -17,9 +45,17 @@ def generate_instructions( system_prompt: str | None = None, usage_tracker=None, usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, **engine_kwargs, ) -> pd.DataFrame: - chat_model = make_model(model, max_tokens=max_tokens, **engine_kwargs) + chat_model = make_model( + model, + max_tokens=max_tokens, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **engine_kwargs, + ) # TODO improve prompt to generate instructions if system_prompt is None: @@ -30,27 +66,50 @@ def generate_instructions( [("system", system_prompt), ("user", "{user_prompt}")] ) - inputs = prompt_template.batch( - [ - { - "user_prompt": truncate(user_prompt, max_len=truncate_input_chars), - } - for user_prompt in instructions - ] - ) + prompt_truncated: list[bool] = [] + input_payloads = [] + case_ids = instructions.index.tolist() + for case_id, user_prompt in zip(case_ids, instructions, strict=True): + truncated_user_prompt, was_truncated = truncate_with_metadata( + user_prompt, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="user_prompt", + case_id=case_id, + model_spec=model, + ) + prompt_truncated.append(was_truncated) + input_payloads.append({"user_prompt": truncated_user_prompt}) + inputs = prompt_template.batch(input_payloads) - completions = do_inference( + completions, completion_metadata = do_inference( chat_model=chat_model, inputs=inputs, use_tqdm=use_tqdm, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=model, + return_metadata=True, + ) + hit_token_limit = _record_generation_output_limit_events( + metadata=completion_metadata, + case_ids=case_ids, + field="completion", + model_spec=model, + limit_event_tracker=limit_event_tracker, ) df_outputs = pd.DataFrame( data={ "completion": completions, - "instruction_index": instructions.index.tolist(), + "instruction_index": case_ids, + "generation_prompt_truncated": prompt_truncated, + "generation_output_finish_reason": [ + metadata_row.get("finish_reason") + for metadata_row in completion_metadata + ], + "generation_output_hit_token_limit": hit_token_limit, }, ) return df_outputs @@ -76,8 +135,9 @@ def _infer_grouped_by_temperature( use_tqdm: bool, usage_tracker=None, usage_phase: str | None = None, -) -> list[str]: +) -> tuple[list[str], list[dict[str, Any]]]: outputs: list[str] = [""] * len(inputs) + outputs_metadata: list[dict[str, Any]] = [{} for _ in inputs] groups: dict[float, list[int]] = {} for idx, temp in enumerate(temperatures): groups.setdefault(float(temp), []).append(idx) @@ -94,18 +154,20 @@ def _infer_grouped_by_temperature( model_spec, max_tokens=max_tokens, temperature=temp, **model_kwargs ) - group_outs = do_inference( + group_outs, group_metadata = do_inference( chat_model=group_model, inputs=group_inputs, use_tqdm=use_tqdm, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=model_spec, + return_metadata=True, ) - for i, out in zip(idxs, group_outs, strict=True): + for i, out, metadata_row in zip(idxs, group_outs, group_metadata, strict=True): outputs[i] = out + outputs_metadata[i] = metadata_row - return outputs + return outputs, outputs_metadata def generate_multiturn( @@ -117,6 +179,7 @@ def generate_multiturn( temperature_config: dict[str, float] | None = None, usage_tracker=None, usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, **model_kwargs, ) -> pd.DataFrame: """Generate two-turn completions for MT-Bench style questions.""" @@ -126,10 +189,23 @@ def generate_multiturn( if use_category_temperatures and local_provider: chat_model = make_model( - model, max_tokens=max_tokens, temperature=0.0, **model_kwargs + model, + max_tokens=max_tokens, + temperature=0.0, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **model_kwargs, ) else: - chat_model = make_model(model, max_tokens=max_tokens, **model_kwargs) + chat_model = make_model( + model, + max_tokens=max_tokens, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model, + **model_kwargs, + ) system_prompt = "You are a helpful assistant." idxs = questions.index.tolist() @@ -143,15 +219,25 @@ def generate_multiturn( turn1_template = ChatPromptTemplate.from_messages( [("system", system_prompt), ("user", "{user_prompt}")] ) - turn1_inputs = turn1_template.batch( - [ - {"user_prompt": truncate(row["turn_1"], max_len=truncate_input_chars)} - for _, row in questions.iterrows() - ] - ) + turn1_prompt_truncated: list[bool] = [] + turn1_payloads = [] + for question_id, row in questions.iterrows(): + truncated_turn_1, was_truncated = truncate_with_metadata( + row["turn_1"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1", + case_id=question_id, + model_spec=model, + ) + turn1_prompt_truncated.append(was_truncated) + turn1_payloads.append({"user_prompt": truncated_turn_1}) + turn1_inputs = turn1_template.batch(turn1_payloads) if use_category_temperatures: - completions_turn_1 = _infer_grouped_by_temperature( + completions_turn_1, turn1_metadata = _infer_grouped_by_temperature( model_spec=model, provider=provider, max_tokens=max_tokens, @@ -164,20 +250,34 @@ def generate_multiturn( usage_phase=usage_phase, ) else: - completions_turn_1 = do_inference( + completions_turn_1, turn1_metadata = do_inference( chat_model=chat_model, inputs=turn1_inputs, use_tqdm=use_tqdm, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=model, + return_metadata=True, ) + turn1_hit_token_limit = _record_generation_output_limit_events( + metadata=turn1_metadata, + case_ids=idxs, + field="completion_turn_1", + model_spec=model, + limit_event_tracker=limit_event_tracker, + ) turn2_inputs = [] - for (_, row), t1_answer in zip( + turn2_turn1_truncated: list[bool] = [] + turn2_answer_truncated: list[bool] = [] + turn2_prompt_truncated: list[bool] = [] + for (question_id, row), t1_answer in zip( questions.iterrows(), completions_turn_1, strict=True ): if row["turn_2"] is None: + turn2_turn1_truncated.append(False) + turn2_answer_truncated.append(False) + turn2_prompt_truncated.append(False) turn2_inputs.append( turn1_template.invoke({"user_prompt": "No follow-up question."}) ) @@ -190,20 +290,51 @@ def generate_multiturn( ("user", "{turn_2}"), ] ) + truncated_turn_1, turn1_was_truncated = truncate_with_metadata( + row["turn_1"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1_for_turn_2", + case_id=question_id, + model_spec=model, + ) + truncated_turn_1_answer, answer_was_truncated = truncate_with_metadata( + str(t1_answer), + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_1_answer", + case_id=question_id, + model_spec=model, + ) + truncated_turn_2, turn2_was_truncated = truncate_with_metadata( + row["turn_2"], + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="turn_2", + case_id=question_id, + model_spec=model, + ) + turn2_turn1_truncated.append(turn1_was_truncated) + turn2_answer_truncated.append(answer_was_truncated) + turn2_prompt_truncated.append(turn2_was_truncated) turn2_inputs.append( multi_turn_template.invoke( { - "turn_1": truncate(row["turn_1"], max_len=truncate_input_chars), - "turn_1_answer": truncate( - str(t1_answer), max_len=truncate_input_chars - ), - "turn_2": truncate(row["turn_2"], max_len=truncate_input_chars), + "turn_1": truncated_turn_1, + "turn_1_answer": truncated_turn_1_answer, + "turn_2": truncated_turn_2, } ) ) if use_category_temperatures: - completions_turn_2 = _infer_grouped_by_temperature( + completions_turn_2, turn2_metadata = _infer_grouped_by_temperature( model_spec=model, provider=provider, max_tokens=max_tokens, @@ -216,20 +347,40 @@ def generate_multiturn( usage_phase=usage_phase, ) else: - completions_turn_2 = do_inference( + completions_turn_2, turn2_metadata = do_inference( chat_model=chat_model, inputs=turn2_inputs, use_tqdm=use_tqdm, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=model, + return_metadata=True, ) + turn2_hit_token_limit = _record_generation_output_limit_events( + metadata=turn2_metadata, + case_ids=idxs, + field="completion_turn_2", + model_spec=model, + limit_event_tracker=limit_event_tracker, + ) return pd.DataFrame( data={ "instruction_index": idxs, "completion_turn_1": completions_turn_1, "completion_turn_2": completions_turn_2, + "generation_turn_1_prompt_truncated": turn1_prompt_truncated, + "generation_turn_1_finish_reason": [ + metadata_row.get("finish_reason") for metadata_row in turn1_metadata + ], + "generation_turn_1_hit_token_limit": turn1_hit_token_limit, + "generation_turn_2_turn_1_prompt_truncated": turn2_turn1_truncated, + "generation_turn_2_turn_1_answer_truncated": turn2_answer_truncated, + "generation_turn_2_prompt_truncated": turn2_prompt_truncated, + "generation_turn_2_finish_reason": [ + metadata_row.get("finish_reason") for metadata_row in turn2_metadata + ], + "generation_turn_2_hit_token_limit": turn2_hit_token_limit, }, ) @@ -242,29 +393,63 @@ def generate_base( use_tqdm: bool = False, usage_tracker=None, usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, **engine_kwargs, ) -> pd.DataFrame: model_spec = model - model = make_model(model_spec, max_tokens=max_tokens, **engine_kwargs) + model = make_model( + model_spec, + max_tokens=max_tokens, + limit_event_tracker=limit_event_tracker, + limit_event_stage="generation_model_init", + limit_event_model_spec=model_spec, + **engine_kwargs, + ) - inputs = [ - truncate(instruction, max_len=truncate_input_chars) - for instruction in instructions - ] + prompt_truncated: list[bool] = [] + case_ids = instructions.index.tolist() + inputs = [] + for case_id, instruction in zip(case_ids, instructions, strict=True): + truncated_instruction, was_truncated = truncate_with_metadata( + instruction, + max_len=truncate_input_chars, + tracker=limit_event_tracker, + kind="generation_input_char_truncation", + stage="generation_input", + field="instruction", + case_id=case_id, + model_spec=model_spec, + ) + prompt_truncated.append(was_truncated) + inputs.append(truncated_instruction) - completions = do_inference( + completions, completion_metadata = do_inference( chat_model=model, inputs=inputs, use_tqdm=use_tqdm, usage_tracker=usage_tracker, usage_phase=usage_phase, usage_model_spec=model_spec, + return_metadata=True, + ) + hit_token_limit = _record_generation_output_limit_events( + metadata=completion_metadata, + case_ids=case_ids, + field="completion", + model_spec=model_spec, + limit_event_tracker=limit_event_tracker, ) df_outputs = pd.DataFrame( data={ "completion": completions, - "instruction_index": instructions.index.tolist(), + "instruction_index": case_ids, + "generation_prompt_truncated": prompt_truncated, + "generation_output_finish_reason": [ + metadata_row.get("finish_reason") + for metadata_row in completion_metadata + ], + "generation_output_hit_token_limit": hit_token_limit, }, ) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index a392a09..51edaf8 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -4,15 +4,20 @@ """ import argparse +import hashlib import json from dataclasses import asdict, dataclass from datetime import UTC, datetime -from functools import partial from pathlib import Path import pandas as pd -from judgearena.cli_common import BaseCliArgs, add_common_arguments, parse_engine_kwargs +from judgearena.cli_common import ( + BaseCliArgs, + add_common_arguments, + parse_engine_kwargs, + parse_optional_bool, +) from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions @@ -20,6 +25,7 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.mt_bench.mt_bench_utils import run_mt_bench from judgearena.openrouter_reference_pricing import ( OpenRouterReferencePricingTracker, @@ -28,11 +34,13 @@ ) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + LimitEventTracker, build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, data_root, download_hf, + is_qwen_reasoning_model, make_model, read_df, ) @@ -110,7 +118,10 @@ def parse_args(cls): ) parser.add_argument( "--use_tqdm", - action="store_true", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, help="If specified, use tqdm, does not work with all model providers, vLLM in particular.", ) add_common_arguments(parser) @@ -126,6 +137,9 @@ def parse_args(cls): provide_explanation=args.provide_explanation, swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, + judge_prompt_preset=args.judge_prompt_preset, + battle_thinking_token_budget=args.battle_thinking_token_budget, + strip_thinking_before_judging=args.strip_thinking_before_judging, truncate_all_input_chars=args.truncate_all_input_chars, max_out_tokens_models=args.max_out_tokens_models, max_out_tokens_judge=args.max_out_tokens_judge, @@ -141,6 +155,53 @@ def load_contexts(dataset: str) -> pd.Series: return pd.read_csv(path).loc[:, "instruction"] +def _build_generation_model_kwargs( + *, args: CliArgs, model_spec: str +) -> dict[str, object]: + generation_model_kwargs = dict(args.engine_kwargs) + provider, _, model_name = model_spec.partition("/") + if ( + args.battle_thinking_token_budget is not None + and provider == "VLLM" + and is_qwen_reasoning_model(model_name) + ): + generation_model_kwargs["thinking_token_budget"] = min( + int(args.battle_thinking_token_budget), + int(args.max_out_tokens_models), + ) + return generation_model_kwargs + + +def _build_judge_model_kwargs( + *, args: CliArgs, limit_event_tracker: LimitEventTracker | None +) -> dict[str, object]: + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, args.engine_kwargs + ) + if limit_event_tracker is not None: + judge_model_kwargs["limit_event_tracker"] = limit_event_tracker + judge_model_kwargs["limit_event_stage"] = "judge_model_init" + judge_model_kwargs["limit_event_model_spec"] = args.judge_model + return judge_model_kwargs + + +def _generation_cache_name(args: CliArgs, *, model_spec: str) -> str: + generation_config = { + "truncate_all_input_chars": args.truncate_all_input_chars, + "max_out_tokens_models": args.max_out_tokens_models, + "max_model_len": args.max_model_len, + "chat_template": args.chat_template, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "engine_kwargs": _build_generation_model_kwargs( + args=args, model_spec=model_spec + ), + } + generation_config_hash = hashlib.sha256( + json.dumps(generation_config, sort_keys=True, default=str).encode("utf-8") + ).hexdigest()[:12] + return f"{args.dataset}_{model_spec}_{args.n_instructions}_{generation_config_hash}" + + def print_results(results): """Print battle results in a nice formatted way""" @@ -174,6 +235,7 @@ def main(args: CliArgs): run_started_at = datetime.now(UTC) usage_tracker = OpenRouterReferencePricingTracker() + limit_event_tracker = LimitEventTracker() print( f"Using dataset {args.dataset} and evaluating models {args.model_A} and {args.model_B}." ) @@ -207,69 +269,59 @@ def main(args: CliArgs): f"{args.model_B} (or loading them directly if present)" ) - # TODO currently we just support base models for fluency, we could also support instruction-tuned models - gen_fun = ( - partial( - generate_base, - truncate_input_chars=args.truncate_all_input_chars, - max_tokens=args.max_out_tokens_models, - max_model_len=args.max_model_len, - chat_template=args.chat_template, - use_tqdm=args.use_tqdm, - **args.engine_kwargs, - ) - if is_fluency_task - else partial( - generate_instructions, + generation_function = generate_base if is_fluency_task else generate_instructions + + def _run_generation(model_spec: str, usage_phase: str) -> pd.DataFrame: + return generation_function( + instructions=instructions, + model=model_spec, truncate_input_chars=args.truncate_all_input_chars, max_tokens=args.max_out_tokens_models, max_model_len=args.max_model_len, chat_template=args.chat_template, use_tqdm=args.use_tqdm, - **args.engine_kwargs, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + limit_event_tracker=limit_event_tracker, + **_build_generation_model_kwargs(args=args, model_spec=model_spec), ) - ) + + def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: + return df.set_index("instruction_index").loc[instructions.index].reset_index() + dataset_completions_A = try_load_dataset_completions( args.dataset, args.model_A, n_instructions ) if dataset_completions_A is not None: - completions_A = dataset_completions_A.set_index("instruction_index").loc[ - instructions.index, "completion" - ] + completions_A_df = _align_completion_dataframe(dataset_completions_A) else: - completions_A = cache_function_dataframe( - lambda: gen_fun( - instructions=instructions, - model=args.model_A, - use_tqdm=args.use_tqdm, - usage_tracker=usage_tracker, - usage_phase="generation_model_A", - ), - ignore_cache=ignore_cache, - cache_name=f"{args.dataset}_{args.model_A}_{args.n_instructions}", - ).set_index("instruction_index") - completions_A = completions_A.loc[:, "completion"] + completions_A_df = _align_completion_dataframe( + cache_function_dataframe( + lambda: _run_generation(args.model_A, "generation_model_A"), + ignore_cache=ignore_cache, + cache_name=_generation_cache_name(args, model_spec=args.model_A), + ) + ) + completions_A = completions_A_df.set_index("instruction_index").loc[ + instructions.index, "completion" + ] dataset_completions_B = try_load_dataset_completions( args.dataset, args.model_B, n_instructions ) if dataset_completions_B is not None: - completions_B = dataset_completions_B.set_index("instruction_index").loc[ - instructions.index, "completion" - ] + completions_B_df = _align_completion_dataframe(dataset_completions_B) else: - completions_B = cache_function_dataframe( - lambda: gen_fun( - instructions=instructions, - model=args.model_B, - use_tqdm=args.use_tqdm, - usage_tracker=usage_tracker, - usage_phase="generation_model_B", - ), - ignore_cache=ignore_cache, - cache_name=f"{args.dataset}_{args.model_B}_{args.n_instructions}", - ).set_index("instruction_index") - completions_B = completions_B.loc[:, "completion"] + completions_B_df = _align_completion_dataframe( + cache_function_dataframe( + lambda: _run_generation(args.model_B, "generation_model_B"), + ignore_cache=ignore_cache, + cache_name=_generation_cache_name(args, model_spec=args.model_B), + ) + ) + completions_B = completions_B_df.set_index("instruction_index").loc[ + instructions.index, "completion" + ] print(f"\nFirst instruction/context: {instructions.values[0]}") print(f"\nFirst completion of {args.model_A}") @@ -278,16 +330,12 @@ def main(args: CliArgs): print(completions_B.values[0]) print(f"Evaluating completions with judge {args.judge_model}.") - judge_model_kwargs = build_default_judge_model_kwargs( - args.judge_model, args.engine_kwargs - ) - judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, max_model_len=args.max_model_len, chat_template=args.chat_template, - **judge_model_kwargs, + **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) name = f"{args.dataset}-{args.model_A}-{args.model_B}-{args.judge_model}" @@ -311,11 +359,9 @@ def main(args: CliArgs): else: # the default system prompt of annotate is to compare instruction tuned models. system_prompt = None - ( - effective_judge_system_prompt, - judge_user_prompt_template, - ) = resolve_judge_prompts( + resolved_prompt = resolve_judge_prompts( provide_explanation=args.provide_explanation, + prompt_preset=args.judge_prompt_preset or DEFAULT_JUDGE_PROMPT_PRESET, system_prompt=system_prompt, ) @@ -324,15 +370,20 @@ def main(args: CliArgs): instructions=instructions.head(n_instructions).tolist(), completions_A=completions_A.head(n_instructions).tolist(), completions_B=completions_B.head(n_instructions).tolist(), + case_ids=instructions.head(n_instructions).index.tolist(), swap_mode=args.swap_mode, provide_explanation=args.provide_explanation, - system_prompt=effective_judge_system_prompt, - user_prompt_template=judge_user_prompt_template, + prompt_preset=resolved_prompt.preset_name, + parser_mode=resolved_prompt.parser_mode, + strip_thinking_before_judging=args.strip_thinking_before_judging, + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=resolved_prompt.user_prompt_template, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, usage_tracker=usage_tracker, usage_phase="judge", usage_model_spec=args.judge_model, + limit_event_tracker=limit_event_tracker, ) df = pd.DataFrame(annotations) @@ -361,7 +412,11 @@ def main(args: CliArgs): "model_A": args.model_A, "model_B": args.model_B, "judge_model": args.judge_model, + "judge_prompt_preset": resolved_prompt.preset_name, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, **summary, + "limit_events": limit_event_tracker.build_summary(), "preferences": prefs.tolist(), } print(f"{args.model_A} vs {args.model_B} judged by {args.judge_model}") @@ -396,8 +451,8 @@ def main(args: CliArgs): "completions_A": eval_completions_A, "completions_B": eval_completions_B, }, - judge_system_prompt=effective_judge_system_prompt, - judge_user_prompt_template=judge_user_prompt_template, + judge_system_prompt=resolved_prompt.system_prompt, + judge_user_prompt_template=resolved_prompt.user_prompt_template, started_at_utc=run_started_at, pricing_reference=pricing_reference, ) diff --git a/judgearena/judge_prompt_presets.py b/judgearena/judge_prompt_presets.py new file mode 100644 index 0000000..7a14313 --- /dev/null +++ b/judgearena/judge_prompt_presets.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +JudgeParserMode = Literal["score", "verdict"] + +DEFAULT_JUDGE_PROMPT_PRESET = "default" +SKYWORK_JUDGE_PROMPT_PRESET = "skywork" +JUDGE_PROMPT_PRESETS = ( + DEFAULT_JUDGE_PROMPT_PRESET, + SKYWORK_JUDGE_PROMPT_PRESET, +) + +_PROMPTS_DIR = Path(__file__).resolve().parent / "prompts" +_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" +_SCORE_FENCE = "\n```" + + +@dataclass(frozen=True) +class PairwiseJudgePromptPreset: + name: str + parser_mode: JudgeParserMode + system_prompt_filename: str | None + user_prompt_filename: str + user_prompt_with_explanation_filename: str + + +@dataclass(frozen=True) +class ResolvedJudgePrompt: + preset_name: str + parser_mode: JudgeParserMode + system_prompt: str | None + user_prompt_template: str + + +_PAIRWISE_PROMPT_PRESETS: dict[str, PairwiseJudgePromptPreset] = { + DEFAULT_JUDGE_PROMPT_PRESET: PairwiseJudgePromptPreset( + name=DEFAULT_JUDGE_PROMPT_PRESET, + parser_mode="score", + system_prompt_filename="system-prompt.txt", + user_prompt_filename="prompt.txt", + user_prompt_with_explanation_filename="prompt-with-explanation.txt", + ), + SKYWORK_JUDGE_PROMPT_PRESET: PairwiseJudgePromptPreset( + name=SKYWORK_JUDGE_PROMPT_PRESET, + parser_mode="verdict", + system_prompt_filename=None, + user_prompt_filename="skywork-prompt.txt", + user_prompt_with_explanation_filename="skywork-prompt-with-explanation.txt", + ), +} + + +def _render_user_prompt_template( + raw_template: str, *, provide_explanation: bool +) -> str: + template = raw_template.replace( + "{explanation_suffix}", + _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + ) + return template + + +def resolve_pairwise_judge_prompt( + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool, + multi_turn: bool = False, + system_prompt: str | None = None, + user_prompt_template: str | None = None, +) -> ResolvedJudgePrompt: + preset = _PAIRWISE_PROMPT_PRESETS.get(prompt_preset) + if preset is None: + supported = ", ".join(sorted(_PAIRWISE_PROMPT_PRESETS)) + raise ValueError( + f"Unsupported judge prompt preset '{prompt_preset}'. Choose from: {supported}." + ) + + prompt_filename = ( + preset.user_prompt_with_explanation_filename + if provide_explanation + else preset.user_prompt_filename + ) + default_system_prompt = ( + (_PROMPTS_DIR / preset.system_prompt_filename).read_text(encoding="utf-8") + if preset.system_prompt_filename is not None + else None + ) + default_user_prompt_template = _render_user_prompt_template( + (_PROMPTS_DIR / prompt_filename).read_text(encoding="utf-8"), + provide_explanation=provide_explanation, + ) + return ResolvedJudgePrompt( + preset_name=preset.name, + parser_mode=preset.parser_mode, + system_prompt=system_prompt + if system_prompt is not None + else default_system_prompt, + user_prompt_template=user_prompt_template + if user_prompt_template is not None + else default_user_prompt_template, + ) diff --git a/judgearena/mt_bench/common.py b/judgearena/mt_bench/common.py index d676e05..51b0963 100644 --- a/judgearena/mt_bench/common.py +++ b/judgearena/mt_bench/common.py @@ -5,7 +5,7 @@ import pandas as pd -from judgearena.utils import safe_text +from judgearena.utils import safe_text_with_metadata @dataclass(frozen=True) @@ -20,6 +20,14 @@ class MTBenchPairwiseRow: answer_b_2: str ref_1: str ref_2: str + turn_1_question_truncated: bool = False + turn_2_question_truncated: bool = False + answer_a_1_truncated: bool = False + answer_a_2_truncated: bool = False + answer_b_1_truncated: bool = False + answer_b_2_truncated: bool = False + ref_1_truncated: bool = False + ref_2_truncated: bool = False def iter_mt_bench_pairwise_rows( @@ -41,27 +49,55 @@ def iter_mt_bench_pairwise_rows( if question_id in completions_b.index else completions_b.iloc[0] ) + turn_1_question, turn_1_question_truncated = safe_text_with_metadata( + row.get("turn_1"), + truncate_input_chars, + ) + turn_2_question, turn_2_question_truncated = safe_text_with_metadata( + row.get("turn_2"), + truncate_input_chars, + ) + answer_a_1, answer_a_1_truncated = safe_text_with_metadata( + comp_a_row.get("completion_turn_1", ""), + truncate_input_chars, + ) + answer_a_2, answer_a_2_truncated = safe_text_with_metadata( + comp_a_row.get("completion_turn_2", ""), + truncate_input_chars, + ) + answer_b_1, answer_b_1_truncated = safe_text_with_metadata( + comp_b_row.get("completion_turn_1", ""), + truncate_input_chars, + ) + answer_b_2, answer_b_2_truncated = safe_text_with_metadata( + comp_b_row.get("completion_turn_2", ""), + truncate_input_chars, + ) + ref_1, ref_1_truncated = safe_text_with_metadata( + row.get("reference_turn_1"), + truncate_input_chars, + ) + ref_2, ref_2_truncated = safe_text_with_metadata( + row.get("reference_turn_2"), + truncate_input_chars, + ) yield MTBenchPairwiseRow( question_id=question_id, category=row.get("category"), - turn_1_question=safe_text(row.get("turn_1"), truncate_input_chars), - turn_2_question=safe_text(row.get("turn_2"), truncate_input_chars), - answer_a_1=safe_text( - comp_a_row.get("completion_turn_1", ""), - truncate_input_chars, - ), - answer_a_2=safe_text( - comp_a_row.get("completion_turn_2", ""), - truncate_input_chars, - ), - answer_b_1=safe_text( - comp_b_row.get("completion_turn_1", ""), - truncate_input_chars, - ), - answer_b_2=safe_text( - comp_b_row.get("completion_turn_2", ""), - truncate_input_chars, - ), - ref_1=safe_text(row.get("reference_turn_1"), truncate_input_chars), - ref_2=safe_text(row.get("reference_turn_2"), truncate_input_chars), + turn_1_question=turn_1_question, + turn_2_question=turn_2_question, + answer_a_1=answer_a_1, + answer_a_2=answer_a_2, + answer_b_1=answer_b_1, + answer_b_2=answer_b_2, + ref_1=ref_1, + ref_2=ref_2, + turn_1_question_truncated=turn_1_question_truncated, + turn_2_question_truncated=turn_2_question_truncated, + answer_a_1_truncated=answer_a_1_truncated, + answer_a_2_truncated=answer_a_2_truncated, + answer_b_1_truncated=answer_b_1_truncated, + answer_b_2_truncated=answer_b_2_truncated, + ref_1_truncated=ref_1_truncated, + ref_2_truncated=ref_2_truncated, ) diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index 254d108..dcaec19 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -10,9 +10,18 @@ import pandas as pd from langchain_core.prompts import ChatPromptTemplate +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + SKYWORK_JUDGE_PROMPT_PRESET, +) from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker -from judgearena.utils import do_inference, strip_thinking_tags +from judgearena.utils import ( + LimitEventTracker, + do_inference, + strip_thinking_tags, + strip_thinking_tags_with_metadata, +) FASTCHAT_TEMPERATURE_CONFIG: dict[str, float] = { "writing": 0.7, @@ -180,6 +189,79 @@ def _load_pairwise_prompt( ) +_SKYWORK_PAIR_V2 = _load_pairwise_prompt( + name="skywork-pair-v2", + multi_turn=False, + ref_based=False, + system_user_subject="prompt displayed below", + system_task_description=( + "You should choose the assistant that follows the user's instructions and " + "answers the user's prompt better. Your evaluation should consider factors " + "such as helpfulness, relevance, accuracy, depth, creativity, and level " + "of detail of the responses." + ), + system_begin_instruction="carefully comparing the two responses", +) + +_SKYWORK_PAIR_V2_MULTI = _load_pairwise_prompt( + name="skywork-pair-v2-multi-turn", + multi_turn=True, + ref_based=False, + system_user_subject="questions", + system_task_description=( + "You should choose the assistant that follows the user's instructions and " + "answers the user's questions better. Your evaluation should consider " + "factors such as helpfulness, relevance, accuracy, depth, creativity, and " + "level of detail of the responses." + ), + system_focus_line=( + "You should focus on which assistant better answers the second user question." + ), + system_begin_instruction="carefully comparing the two conversations", +) + +_SKYWORK_PAIR_MATH_V1 = _load_pairwise_prompt( + name="skywork-pair-math-v1", + multi_turn=False, + ref_based=True, + system_user_subject="prompt displayed below", + system_task_description=( + "You will be given a reference answer, assistant A's answer, and " + "assistant B's answer. Your evaluation should focus on correctness and " + "helpfulness while deciding which assistant is better." + ), + system_begin_instruction="carefully comparing both assistants' answers with the reference answer", +) + +_SKYWORK_PAIR_MATH_V1_MULTI = _load_pairwise_prompt( + name="skywork-pair-math-v1-multi-turn", + multi_turn=True, + ref_based=True, + system_user_subject="questions", + system_task_description=( + "You will be given reference answers together with assistant A's and " + "assistant B's answers. Your evaluation should focus on correctness and " + "helpfulness while deciding which assistant better answers the second user question." + ), + system_begin_instruction="carefully comparing both assistants' answers with the reference answers", +) + +_FASTCHAT_PROMPT_PRESET_REGISTRY: dict[str, dict[str, FastChatPairwisePrompt]] = { + DEFAULT_JUDGE_PROMPT_PRESET: { + "single": _PAIR_V2, + "multi": _PAIR_V2_MULTI, + "single_ref": _PAIR_MATH_V1, + "multi_ref": _PAIR_MATH_V1_MULTI, + }, + SKYWORK_JUDGE_PROMPT_PRESET: { + "single": _SKYWORK_PAIR_V2, + "multi": _SKYWORK_PAIR_V2_MULTI, + "single_ref": _SKYWORK_PAIR_MATH_V1, + "multi_ref": _SKYWORK_PAIR_MATH_V1_MULTI, + }, +} + + def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: stripped = strip_thinking_tags(judgment).strip() if "[[A]]" in stripped: @@ -227,15 +309,26 @@ def _winner_to_preference(winner: PairwiseWinner) -> float: return math.nan -def _select_prompt(category: str | None, multi_turn: bool) -> FastChatPairwisePrompt: +def _select_prompt( + category: str | None, + multi_turn: bool, + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, +) -> FastChatPairwisePrompt: + prompt_variants = _FASTCHAT_PROMPT_PRESET_REGISTRY.get(prompt_preset) + if prompt_variants is None: + supported = ", ".join(sorted(_FASTCHAT_PROMPT_PRESET_REGISTRY)) + raise ValueError( + f"Unsupported MT-Bench prompt preset '{prompt_preset}'. Choose from: {supported}." + ) needs_ref = (category or "") in FASTCHAT_NEED_REF_CATS if needs_ref and multi_turn: - return _PAIR_MATH_V1_MULTI + return prompt_variants["multi_ref"] if needs_ref: - return _PAIR_MATH_V1 + return prompt_variants["single_ref"] if multi_turn: - return _PAIR_V2_MULTI - return _PAIR_V2 + return prompt_variants["multi"] + return prompt_variants["single"] def _group_indices_by_prompt( @@ -312,8 +405,38 @@ def _build_fastchat_judge_items( eval_single: bool, eval_multi: bool, truncate_input_chars: int | None, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, + limit_event_tracker: LimitEventTracker | None = None, ) -> list[dict[str, Any]]: items: list[dict[str, Any]] = [] + + def _record_mt_bench_truncation( + *, case_id: str, field: str, truncated: bool + ) -> None: + if truncated and limit_event_tracker is not None: + limit_event_tracker.record( + "mt_bench_field_char_truncation", + stage="judge_input", + field=field, + case_id=case_id, + ) + + def _prepare_answer(answer: str, *, case_id: str, field: str) -> tuple[str, bool]: + if not strip_thinking_before_judging: + return answer, False + stripped_answer, stripped = strip_thinking_tags_with_metadata(answer) + if stripped and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field=field, + case_id=case_id, + original_length=len(answer), + final_length=len(stripped_answer), + ) + return stripped_answer, stripped + for pair_row in iter_mt_bench_pairwise_rows( questions=questions, completions_a=completions_a, @@ -322,14 +445,51 @@ def _build_fastchat_judge_items( ): category = pair_row.category if eval_single: - prompt = _select_prompt(category, multi_turn=False) + case_id = f"{pair_row.question_id}:turn1" + prompt = _select_prompt( + category, multi_turn=False, prompt_preset=prompt_preset + ) + answer_a, answer_a_stripped = _prepare_answer( + pair_row.answer_a_1, case_id=case_id, field="answer_a_1" + ) + answer_b, answer_b_stripped = _prepare_answer( + pair_row.answer_b_1, case_id=case_id, field="answer_b_1" + ) + _record_mt_bench_truncation( + case_id=case_id, + field="turn_1_question", + truncated=pair_row.turn_1_question_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_a_1", + truncated=pair_row.answer_a_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_b_1", + truncated=pair_row.answer_b_1_truncated, + ) kwargs: dict[str, str] = { "question": pair_row.turn_1_question, - "answer_a": pair_row.answer_a_1, - "answer_b": pair_row.answer_b_1, + "answer_a": answer_a, + "answer_b": answer_b, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_a_1_reasoning_stripped": answer_a_stripped, + "answer_b_1_reasoning_stripped": answer_b_stripped, } if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) kwargs["ref_answer_1"] = pair_row.ref_1 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated items.append( { "question_id": pair_row.question_id, @@ -338,22 +498,73 @@ def _build_fastchat_judge_items( "prompt": prompt, "prompt_name": prompt.name, "prompt_kwargs": kwargs, + "limit_flags": limit_flags, } ) if eval_multi and pair_row.turn_2_question: - prompt = _select_prompt(category, multi_turn=True) + case_id = f"{pair_row.question_id}:turn2" + prompt = _select_prompt( + category, multi_turn=True, prompt_preset=prompt_preset + ) + answer_a_1, answer_a_1_stripped = _prepare_answer( + pair_row.answer_a_1, case_id=case_id, field="answer_a_1" + ) + answer_a_2, answer_a_2_stripped = _prepare_answer( + pair_row.answer_a_2, case_id=case_id, field="answer_a_2" + ) + answer_b_1, answer_b_1_stripped = _prepare_answer( + pair_row.answer_b_1, case_id=case_id, field="answer_b_1" + ) + answer_b_2, answer_b_2_stripped = _prepare_answer( + pair_row.answer_b_2, case_id=case_id, field="answer_b_2" + ) + for field, truncated in ( + ("turn_1_question", pair_row.turn_1_question_truncated), + ("turn_2_question", pair_row.turn_2_question_truncated), + ("answer_a_1", pair_row.answer_a_1_truncated), + ("answer_a_2", pair_row.answer_a_2_truncated), + ("answer_b_1", pair_row.answer_b_1_truncated), + ("answer_b_2", pair_row.answer_b_2_truncated), + ): + _record_mt_bench_truncation( + case_id=case_id, field=field, truncated=truncated + ) kwargs = { "question_1": pair_row.turn_1_question, "question_2": pair_row.turn_2_question, - "answer_a_1": pair_row.answer_a_1, - "answer_a_2": pair_row.answer_a_2, - "answer_b_1": pair_row.answer_b_1, - "answer_b_2": pair_row.answer_b_2, + "answer_a_1": answer_a_1, + "answer_a_2": answer_a_2, + "answer_b_1": answer_b_1, + "answer_b_2": answer_b_2, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "turn_2_question_truncated": pair_row.turn_2_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_a_2_truncated": pair_row.answer_a_2_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_b_2_truncated": pair_row.answer_b_2_truncated, + "answer_a_1_reasoning_stripped": answer_a_1_stripped, + "answer_a_2_reasoning_stripped": answer_a_2_stripped, + "answer_b_1_reasoning_stripped": answer_b_1_stripped, + "answer_b_2_reasoning_stripped": answer_b_2_stripped, } if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="ref_2", + truncated=pair_row.ref_2_truncated, + ) kwargs["ref_answer_1"] = pair_row.ref_1 kwargs["ref_answer_2"] = pair_row.ref_2 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + limit_flags["ref_2_truncated"] = pair_row.ref_2_truncated items.append( { "question_id": pair_row.question_id, @@ -362,6 +573,7 @@ def _build_fastchat_judge_items( "prompt": prompt, "prompt_name": prompt.name, "prompt_kwargs": kwargs, + "limit_flags": limit_flags, } ) return items @@ -398,6 +610,7 @@ def _resolve_fastchat_item_result( "g1_verdict": g1_verdict, "g1_winner": g1_winner, } + annotation_row.update(item.get("limit_flags", {})) if g2_raw is not None: g2_verdict = _parse_fastchat_verdict(g2_raw) @@ -442,8 +655,11 @@ def judge_mt_bench_pairwise_fastchat( swap_mode: str, truncate_input_chars: int | None, use_tqdm: bool, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + strip_thinking_before_judging: bool = False, usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, ) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: """Run FastChat-style MT-Bench pairwise judging with bracketed verdict outputs.""" assert turns_mode in ("both", "single", "multi") @@ -459,6 +675,9 @@ def judge_mt_bench_pairwise_fastchat( eval_single=eval_single, eval_multi=eval_multi, truncate_input_chars=truncate_input_chars, + prompt_preset=prompt_preset, + strip_thinking_before_judging=strip_thinking_before_judging, + limit_event_tracker=limit_event_tracker, ) g1_judgments = _infer_by_prompt_groups( diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 106812c..00f380a 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +import hashlib import json import os from dataclasses import asdict @@ -19,6 +20,7 @@ from judgearena.generate import generate_multiturn from judgearena.instruction_dataset import load_instructions from judgearena.instruction_dataset.mt_bench import load_mt_bench_model_answers +from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.mt_bench.fastchat_compat import ( FASTCHAT_TEMPERATURE_CONFIG, judge_mt_bench_pairwise_fastchat, @@ -30,9 +32,11 @@ ) from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( + LimitEventTracker, build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, + is_qwen_reasoning_model, make_model, ) @@ -46,6 +50,53 @@ _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN = 28672 +def _build_mt_bench_generation_kwargs( + *, args: CliArgs, model_spec: str +) -> dict[str, object]: + generation_model_kwargs = dict(args.engine_kwargs) + provider, _, model_name = model_spec.partition("/") + if ( + args.battle_thinking_token_budget is not None + and provider == "VLLM" + and is_qwen_reasoning_model(model_name) + ): + generation_model_kwargs["thinking_token_budget"] = min( + int(args.battle_thinking_token_budget), + int(args.max_out_tokens_models), + ) + return generation_model_kwargs + + +def _build_mt_bench_judge_model_kwargs( + *, args: CliArgs, limit_event_tracker: LimitEventTracker | None +) -> dict[str, object]: + judge_model_kwargs = build_default_judge_model_kwargs( + args.judge_model, args.engine_kwargs + ) + if limit_event_tracker is not None: + judge_model_kwargs["limit_event_tracker"] = limit_event_tracker + judge_model_kwargs["limit_event_stage"] = "judge_model_init" + judge_model_kwargs["limit_event_model_spec"] = args.judge_model + return judge_model_kwargs + + +def _mt_bench_generation_cache_name(args: CliArgs, *, model_name: str) -> str: + generation_config = { + "truncate_all_input_chars": args.truncate_all_input_chars, + "max_out_tokens_models": args.max_out_tokens_models, + "max_model_len": args.max_model_len, + "chat_template": args.chat_template, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "engine_kwargs": _build_mt_bench_generation_kwargs( + args=args, model_spec=model_name + ), + } + generation_config_hash = hashlib.sha256( + json.dumps(generation_config, sort_keys=True, default=str).encode("utf-8") + ).hexdigest()[:12] + return f"mt-bench_{model_name}_{args.n_instructions}_{generation_config_hash}" + + def _align_mt_bench_completions( *, questions_df: pd.DataFrame, completions: pd.DataFrame, model_name: str ) -> pd.DataFrame: @@ -66,9 +117,9 @@ def _generate_mt_bench_completions( questions_df: pd.DataFrame, ignore_cache: bool, usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Load baseline MT-Bench answers or generate fresh multi-turn outputs.""" - cache_prefix = "mt-bench" def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: return generate_multiturn( @@ -82,7 +133,8 @@ def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: temperature_config=FASTCHAT_TEMPERATURE_CONFIG, usage_tracker=usage_tracker, usage_phase=usage_phase, - **args.engine_kwargs, + limit_event_tracker=limit_event_tracker, + **_build_mt_bench_generation_kwargs(args=args, model_spec=model_name), ) def _load_or_generate(model_name: str, usage_phase: str) -> pd.DataFrame: @@ -99,7 +151,7 @@ def _load_or_generate(model_name: str, usage_phase: str) -> pd.DataFrame: generated_answers = cache_function_dataframe( lambda: _run_generation(model_name, usage_phase), ignore_cache=ignore_cache, - cache_name=f"{cache_prefix}_{model_name}_{args.n_instructions}", + cache_name=_mt_bench_generation_cache_name(args, model_name=model_name), ) return _align_mt_bench_completions( questions_df=questions_df, @@ -166,7 +218,9 @@ def _run_mt_bench_fastchat( completions_a: pd.DataFrame, completions_b: pd.DataFrame, judge_chat_model, + prompt_preset: str, usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, started_at_utc: datetime, ) -> pd.Series: """Run FastChat-style MT-Bench judging and save the resulting artifacts.""" @@ -183,8 +237,11 @@ def _run_mt_bench_fastchat( swap_mode=args.swap_mode, truncate_input_chars=args.truncate_all_input_chars, use_tqdm=args.use_tqdm, + prompt_preset=prompt_preset, + strip_thinking_before_judging=args.strip_thinking_before_judging, usage_tracker=usage_tracker, usage_phase="judge", + limit_event_tracker=limit_event_tracker, ) ) @@ -194,8 +251,14 @@ def _run_mt_bench_fastchat( "model_A": args.model_A, "model_B": args.model_B, "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, "num_inconsistent": num_inconsistent, **stats, + "limit_events": limit_event_tracker.build_summary() + if limit_event_tracker is not None + else {}, "per_category": _compute_grouped_stats(prefs, combined_metadata, "category"), "per_turn": _compute_grouped_stats(prefs, combined_metadata, "turn"), "preferences": prefs.tolist(), @@ -228,7 +291,9 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): """MT-Bench pipeline with FastChat-compatible pairwise judging.""" run_started_at = datetime.now(UTC) usage_tracker = OpenRouterReferencePricingTracker() - if not args.provide_explanation: + limit_event_tracker = LimitEventTracker() + prompt_preset = args.judge_prompt_preset or DEFAULT_JUDGE_PROMPT_PRESET + if prompt_preset == DEFAULT_JUDGE_PROMPT_PRESET and not args.provide_explanation: print( "MT-Bench ignores provide_explanation=False and keeps the original " "FastChat-style explanation-plus-verdict prompt." @@ -256,6 +321,7 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): questions_df=questions_df, ignore_cache=ignore_cache, usage_tracker=usage_tracker, + limit_event_tracker=limit_event_tracker, ) if ( args.max_model_len is not None @@ -268,17 +334,15 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN} for the judge." ) args.max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN - judge_model_kwargs = build_default_judge_model_kwargs( - args.judge_model, args.engine_kwargs - ) - judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, temperature=0.0, max_model_len=args.max_model_len, chat_template=args.chat_template, - **judge_model_kwargs, + **_build_mt_bench_judge_model_kwargs( + args=args, limit_event_tracker=limit_event_tracker + ), ) return _run_mt_bench_fastchat( args=args, @@ -286,6 +350,8 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): completions_a=completions_a, completions_b=completions_b, judge_chat_model=judge_chat_model, + prompt_preset=prompt_preset, usage_tracker=usage_tracker, + limit_event_tracker=limit_event_tracker, started_at_utc=run_started_at, ) diff --git a/judgearena/prompts/prompt.txt b/judgearena/prompts/prompt.txt index 38021e6..1b93858 100644 --- a/judgearena/prompts/prompt.txt +++ b/judgearena/prompts/prompt.txt @@ -1,13 +1,13 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's {completion_label}|> +<|The Start of Assistant A's Answer|> {completion_A} -<|The End of Assistant A's {completion_label}|> +<|The End of Assistant A's Answer|> -<|The Start of Assistant B's {completion_label}|> +<|The Start of Assistant B's Answer|> {completion_B} -<|The End of Assistant B's {completion_label}|> +<|The End of Assistant B's Answer|> # Your output diff --git a/judgearena/prompts/skywork-prompt-with-explanation.txt b/judgearena/prompts/skywork-prompt-with-explanation.txt new file mode 100644 index 0000000..e1f9250 --- /dev/null +++ b/judgearena/prompts/skywork-prompt-with-explanation.txt @@ -0,0 +1,14 @@ +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. +Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +Please briefly explain your reasoning first, then directly output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better. + +[User Question] +{user_prompt} + +[The Start of Assistant A's Answer] +{completion_A} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{completion_B} +[The End of Assistant B's Answer] diff --git a/judgearena/prompts/skywork-prompt.txt b/judgearena/prompts/skywork-prompt.txt new file mode 100644 index 0000000..97ad3c8 --- /dev/null +++ b/judgearena/prompts/skywork-prompt.txt @@ -0,0 +1,14 @@ +Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. +Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +Please directly output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better. + +[User Question] +{user_prompt} + +[The Start of Assistant A's Answer] +{completion_A} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{completion_B} +[The End of Assistant B's Answer] diff --git a/judgearena/utils.py b/judgearena/utils.py index f544a4d..7bd0809 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -3,8 +3,11 @@ import re import time import warnings +from collections import Counter from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path +from typing import Any import pandas as pd from huggingface_hub import snapshot_download @@ -77,6 +80,81 @@ def _resolve_chat_template_kwargs( return chat_template_kwargs or None +@dataclass(frozen=True) +class LimitEvent: + kind: str + stage: str + field: str | None = None + case_id: str | None = None + model_spec: str | None = None + original_length: int | None = None + final_length: int | None = None + note: str | None = None + + +class LimitEventTracker: + def __init__(self) -> None: + self.events: list[LimitEvent] = [] + + def record( + self, + kind: str, + *, + stage: str, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, + original_length: int | None = None, + final_length: int | None = None, + note: str | None = None, + ) -> None: + self.events.append( + LimitEvent( + kind=kind, + stage=stage, + field=field, + case_id=None if case_id is None else str(case_id), + model_spec=model_spec, + original_length=original_length, + final_length=final_length, + note=note, + ) + ) + + def build_summary(self) -> dict[str, Any]: + counts_by_kind: Counter[str] = Counter() + counts_by_stage: Counter[str] = Counter() + counts_by_kind_and_field: dict[str, Counter[str]] = {} + affected_cases_total: set[str] = set() + affected_cases_by_kind: dict[str, set[str]] = {} + + for event in self.events: + counts_by_kind[event.kind] += 1 + counts_by_stage[event.stage] += 1 + field_key = event.field or "_all" + counts_by_kind_and_field.setdefault(event.kind, Counter())[field_key] += 1 + if event.case_id is None: + continue + case_key = f"{event.stage}:{event.case_id}" + affected_cases_total.add(case_key) + affected_cases_by_kind.setdefault(event.kind, set()).add(case_key) + + return { + "total_events": len(self.events), + "counts_by_kind": dict(sorted(counts_by_kind.items())), + "counts_by_stage": dict(sorted(counts_by_stage.items())), + "counts_by_kind_and_field": { + kind: dict(sorted(counter.items())) + for kind, counter in sorted(counts_by_kind_and_field.items()) + }, + "affected_cases_total": len(affected_cases_total), + "affected_cases_by_kind": { + kind: len(case_ids) + for kind, case_ids in sorted(affected_cases_by_kind.items()) + }, + } + + def set_langchain_cache(): set_llm_cache(SQLiteCache(database_path=str(data_root / ".langchain.db"))) @@ -157,6 +235,33 @@ def truncate(s: str, max_len: int | None = None) -> str: return s +def truncate_with_metadata( + s: str | None, + max_len: int | None = None, + *, + tracker: LimitEventTracker | None = None, + kind: str | None = None, + stage: str | None = None, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, +) -> tuple[str, bool]: + original = s if isinstance(s, str) else "" + truncated = truncate(original, max_len=max_len) + was_truncated = truncated != original + if was_truncated and tracker is not None and kind is not None and stage is not None: + tracker.record( + kind, + stage=stage, + field=field, + case_id=case_id, + model_spec=model_spec, + original_length=len(original), + final_length=len(truncated), + ) + return truncated, was_truncated + + def safe_text(value: object, truncate_chars: int | None) -> str: """Coerce *value* to a string and optionally truncate. @@ -171,14 +276,62 @@ def safe_text(value: object, truncate_chars: int | None) -> str: return truncate(str(value), max_len=truncate_chars) +def safe_text_with_metadata( + value: object, + truncate_chars: int | None, + *, + tracker: LimitEventTracker | None = None, + kind: str | None = None, + stage: str | None = None, + field: str | None = None, + case_id: object | None = None, + model_spec: str | None = None, +) -> tuple[str, bool]: + if value is None: + return "", False + is_missing = pd.isna(value) + if isinstance(is_missing, bool) and is_missing: + return "", False + return truncate_with_metadata( + str(value), + max_len=truncate_chars, + tracker=tracker, + kind=kind, + stage=stage, + field=field, + case_id=case_id, + model_spec=model_spec, + ) + + _THINK_BLOCK_RE = re.compile(r".*?", re.IGNORECASE | re.DOTALL) def strip_thinking_tags(text: str | None) -> str: """Remove full `...` blocks from raw model output.""" + return strip_thinking_tags_with_metadata(text)[0] + + +def strip_thinking_tags_with_metadata(text: str | None) -> tuple[str, bool]: + """Remove visible reasoning spans from raw model output.""" if not isinstance(text, str): - return "" - return _THINK_BLOCK_RE.sub("", text) + return "", False + + cleaned = _THINK_BLOCK_RE.sub("", text) + if cleaned != text: + return cleaned.lstrip(), True + + lowered = text.lower() + closing_tag = "" + closing_idx = lowered.find(closing_tag) + if closing_idx != -1 and "" not in lowered[:closing_idx]: + return text[closing_idx + len(closing_tag) :].lstrip(), True + + qwen_end_idx = text.find(VLLM_QWEN_REASONING_END_STR) + if qwen_end_idx != -1: + return text[qwen_end_idx + len(VLLM_QWEN_REASONING_END_STR) :].lstrip(), True + + return text, False def do_inference( @@ -188,6 +341,7 @@ def do_inference( usage_tracker: OpenRouterReferencePricingTracker | None = None, usage_phase: str | None = None, usage_model_spec: str | None = None, + return_metadata: bool = False, ): # Retries on rate-limit/server errors with exponential backoff. # Async path retries individual calls; batch path splits into 4^attempt chunks on failure. @@ -195,6 +349,7 @@ def do_inference( # "stop": ["```"], # "max_tokens": 100, } + metadata: list[dict[str, Any]] | None = None if use_tqdm: # perform inference asynchronously to be able to update tqdm, chat_model.batch does not work as it blocks until # all requests are received @@ -224,6 +379,8 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): chat_model=chat_model, inputs=inputs, pbar=pbar ) ) + if return_metadata: + metadata = [{} for _ in res] else: def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): @@ -236,9 +393,24 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): ] try: results = [] + results_metadata = [] for chunk in chunks: - results.extend(chat_model.batch(inputs=chunk, **invoke_kwargs)) - return results + if return_metadata and hasattr( + chat_model, "batch_with_metadata" + ): + chunk_results, chunk_metadata = ( + chat_model.batch_with_metadata( + inputs=chunk, **invoke_kwargs + ) + ) + else: + chunk_results = chat_model.batch( + inputs=chunk, **invoke_kwargs + ) + chunk_metadata = [{} for _ in chunk_results] + results.extend(chunk_results) + results_metadata.extend(chunk_metadata) + return results, results_metadata except Exception as e: if attempt == max_retries - 1 or not _is_retryable_error(e): raise @@ -249,7 +421,7 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): ) time.sleep(delay) - res = batch_with_retry(inputs) + res, metadata = batch_with_retry(inputs) # Not sure why the API of Langchain returns sometime a string and sometimes an AIMessage object # is it because of using Chat and barebones models? @@ -273,6 +445,8 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): f"Warning: failed to record token usage for phase " f"'{usage_phase}' ({usage_model_spec}): {e}" ) + if return_metadata: + return res, (metadata or [{} for _ in res]) return res @@ -318,6 +492,13 @@ def __init__( self.model_path = model self.max_tokens = max_tokens + limit_event_tracker: LimitEventTracker | None = vllm_kwargs.pop( + "limit_event_tracker", None + ) + limit_event_stage = str(vllm_kwargs.pop("limit_event_stage", "model_init")) + limit_event_model_spec = str( + vllm_kwargs.pop("limit_event_model_spec", f"VLLM/{model}") + ) disable_thinking = bool(vllm_kwargs.pop("disable_thinking", False)) thinking_token_budget = vllm_kwargs.pop("thinking_token_budget", None) explicit_chat_template_kwargs = vllm_kwargs.pop("chat_template_kwargs", None) @@ -339,6 +520,15 @@ def __init__( config = AutoConfig.from_pretrained(model, trust_remote_code=True) model_max_pos = getattr(config, "max_position_embeddings", None) if model_max_pos is not None and max_model_len > model_max_pos: + if limit_event_tracker is not None: + limit_event_tracker.record( + "max_model_len_clamped", + stage=limit_event_stage, + field="max_model_len", + model_spec=limit_event_model_spec, + original_length=int(max_model_len), + final_length=int(model_max_pos), + ) warnings.warn( f"Capping max_model_len from {max_model_len} to " f"{model_max_pos} (max_position_embeddings) for '{model}'.", @@ -359,6 +549,8 @@ def __init__( "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } if thinking_token_budget is not None: + if max_tokens is not None: + thinking_token_budget = min(int(thinking_token_budget), int(max_tokens)) if explicit_reasoning_settings: self._sampling_params_kwargs["thinking_token_budget"] = int( thinking_token_budget @@ -473,7 +665,7 @@ def _to_raw_text(self, input_item) -> str: return "\n".join(msg["content"] for msg in input_item) raise ValueError(f"Cannot extract raw text from: {type(input_item)}") - def batch(self, inputs: list, **invoke_kwargs) -> list[str]: + def _run_raw_batch(self, inputs: list): """Process a batch of inputs using vllm.LLM.chat() or llm.generate(). Uses ``llm.chat()`` when a chat template is available (instruct models), @@ -491,7 +683,28 @@ def batch(self, inputs: list, **invoke_kwargs) -> list[str]: chat_template=self.chat_template, chat_template_kwargs=self._chat_template_kwargs, ) - return [out.outputs[0].text for out in outputs] + return outputs + + def batch_with_metadata( + self, inputs: list, **invoke_kwargs + ) -> tuple[list[str], list[dict[str, Any]]]: + outputs = self._run_raw_batch(inputs) + texts: list[str] = [] + metadata: list[dict[str, Any]] = [] + for out in outputs: + first_output = out.outputs[0] + texts.append(first_output.text) + metadata.append( + { + "finish_reason": getattr(first_output, "finish_reason", None), + "stop_reason": getattr(first_output, "stop_reason", None), + } + ) + return texts, metadata + + def batch(self, inputs: list, **invoke_kwargs) -> list[str]: + texts, _metadata = self.batch_with_metadata(inputs, **invoke_kwargs) + return texts def _count_chat_prompt_tokens(self, messages: list[dict]) -> int: tokenizer_kwargs: dict[str, object] = { @@ -554,6 +767,9 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): # NOTE: this is a shallow copy since we are not modifying any # mutable objects in the dictionary. engine_kwargs = engine_kwargs.copy() + limit_event_tracker = engine_kwargs.pop("limit_event_tracker", None) + limit_event_stage = engine_kwargs.pop("limit_event_stage", None) + limit_event_model_spec = engine_kwargs.pop("limit_event_model_spec", None) # Dedicated arguments like max_tokens always win over engine_kwargs. engine_kwargs["max_tokens"] = max_tokens or 8192 @@ -569,6 +785,12 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): if model_provider == "VLLM": engine_kwargs = {k: v for k, v in engine_kwargs.items() if v is not None} engine_kwargs["chat_template"] = engine_kwargs.get("chat_template", None) + if limit_event_tracker is not None: + engine_kwargs["limit_event_tracker"] = limit_event_tracker + if limit_event_stage is not None: + engine_kwargs["limit_event_stage"] = limit_event_stage + if limit_event_model_spec is not None: + engine_kwargs["limit_event_model_spec"] = limit_event_model_spec return ChatVLLM( model=model_name, diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index f2f23f0..1c0eaaf 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -54,7 +54,7 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc utils.ChatVLLM( model="Qwen/Qwen3.5-27B-FP8", - max_tokens=32, + max_tokens=128, thinking_token_budget=64, gpu_memory_utilization=0.7, ) @@ -70,6 +70,19 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) +def test_chat_vllm_clamps_thinking_budget_to_total_max_tokens(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="Qwen/Qwen3.5-27B-FP8", + max_tokens=32, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 32 + + def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) chat_model = utils.ChatVLLM( @@ -121,7 +134,7 @@ def test_chat_vllm_preserves_explicit_reasoning_settings_for_non_qwen(monkeypatc gpu_memory_utilization=0.7, ) - assert captured["sampling_kwargs"]["thinking_token_budget"] == 32 + assert captured["sampling_kwargs"]["thinking_token_budget"] == 16 assert captured["llm_init"]["kwargs"]["reasoning_parser"] == "custom-parser" assert ( captured["llm_init"]["kwargs"]["reasoning_config"] is explicit_reasoning_config diff --git a/tests/test_generate_and_evaluate.py b/tests/test_generate_and_evaluate.py index f01e916..ada7ebe 100644 --- a/tests/test_generate_and_evaluate.py +++ b/tests/test_generate_and_evaluate.py @@ -96,3 +96,25 @@ def test_generate_and_evaluate_correct_order_bias(tmp_path): avg_pref = sum(prefs) / len(prefs) assert avg_pref == 0.5 + + +def test_cli_args_parse_optional_boolean_flags(monkeypatch): + monkeypatch.setattr( + "sys.argv", + [ + "generate_and_evaluate.py", + "--dataset=alpaca-eval", + "--model_A=Dummy/A", + "--model_B=Dummy/B", + "--judge_model=Dummy/Judge", + "--use_tqdm=True", + "--ignore_cache=True", + "--strip_thinking_before_judging=False", + ], + ) + + args = CliArgs.parse_args() + + assert args.use_tqdm is True + assert args.ignore_cache is True + assert args.strip_thinking_before_judging is False diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index d3641b5..c9b22e2 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -2,8 +2,10 @@ import judgearena.evaluate as evaluate import judgearena.generate_and_evaluate as generate_and_evaluate +from judgearena.cli_common import parse_optional_bool from judgearena.generate_and_evaluate import CliArgs from judgearena.generate_and_evaluate import main as main_generate_and_eval +from judgearena.judge_prompt_presets import SKYWORK_JUDGE_PROMPT_PRESET def test_load_judge_prompt_without_explanation_uses_freeform_scores(): @@ -14,6 +16,7 @@ def test_load_judge_prompt_without_explanation_uses_freeform_scores(): assert "valid JSON" not in user_prompt assert "score_A:" in user_prompt assert "score_B:" in user_prompt + assert "Assistant A's Answer" in user_prompt def test_load_judge_prompt_with_explanation_uses_freeform_scores(): @@ -25,15 +28,21 @@ def test_load_judge_prompt_with_explanation_uses_freeform_scores(): assert "first starts with an explanation of your judgement" in user_prompt assert "score_A:" in user_prompt assert "score_B:" in user_prompt + assert "Assistant B's Answer" in user_prompt -def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): +def test_parse_optional_bool_accepts_explicit_true_false_values(): + assert parse_optional_bool(None) is True + assert parse_optional_bool("true") is True + assert parse_optional_bool("False") is False + + +def test_main_passes_qwen_defaults_and_aligns_dataset_completions( + tmp_path, monkeypatch +): instructions = pd.DataFrame( - {"instruction": ["Instruction A"]}, - index=pd.Index([1], name="instruction_index"), - ) - completions_df = pd.DataFrame( - {"instruction_index": [1], "completion": ["Loaded answer"]} + {"instruction": ["Instruction B", "Instruction A"]}, + index=pd.Index(["b", "a"], name="instruction_index"), ) captured = {} @@ -42,10 +51,26 @@ def test_main_passes_thinking_budget_to_vllm_judge(tmp_path, monkeypatch): "load_instructions", lambda dataset, n_instructions=None: instructions, ) + + def fake_try_load_dataset_completions(dataset, model, n_instructions): + if model == "Dummy/model-a": + return pd.DataFrame( + { + "instruction_index": ["a", "b"], + "completion": ["Answer A", "no answer"], + } + ) + return pd.DataFrame( + { + "instruction_index": ["a", "b"], + "completion": ["Answer B", "Answer C"], + } + ) + monkeypatch.setattr( generate_and_evaluate, "try_load_dataset_completions", - lambda dataset, model, n_instructions: completions_df, + fake_try_load_dataset_completions, ) def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): @@ -58,69 +83,18 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs } return object() - monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) - monkeypatch.setattr( - generate_and_evaluate, - "judge_and_parse_prefs", - lambda **kwargs: ( - [{"judge_completion": "score_A: 1\nscore_B: 2"}], - None, - pd.Series([1.0]), - ), - ) - - prefs = main_generate_and_eval( - CliArgs( - dataset="alpaca-eval", - model_A="Dummy/model-a", - model_B="Dummy/model-b", - judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", - n_instructions=1, - result_folder=str(tmp_path / "results"), + def fake_judge_and_parse_prefs(**kwargs): + captured["judge_kwargs"] = kwargs + annotations = [{"judge_completion": "score_A: 0\nscore_B: 10"}] * len( + kwargs["instructions"] ) - ) - - assert prefs.tolist() == [1.0] - assert "structured_outputs_json" not in captured["make_model"]["kwargs"] - assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 - - -def test_main_passes_thinking_budget_to_vllm_judge_when_explanation_requested( - tmp_path, monkeypatch -): - instructions = pd.DataFrame( - {"instruction": ["Instruction A"]}, - index=pd.Index([1], name="instruction_index"), - ) - completions_df = pd.DataFrame( - {"instruction_index": [1], "completion": ["Loaded answer"]} - ) - captured = {} - - monkeypatch.setattr( - generate_and_evaluate, - "load_instructions", - lambda dataset, n_instructions=None: instructions, - ) - monkeypatch.setattr( - generate_and_evaluate, - "try_load_dataset_completions", - lambda dataset, model, n_instructions: completions_df, - ) - - def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs): - captured["make_model"] = kwargs - return object() + return annotations, None, pd.Series([1.0] * len(kwargs["instructions"])) monkeypatch.setattr(generate_and_evaluate, "make_model", fake_make_model) monkeypatch.setattr( generate_and_evaluate, "judge_and_parse_prefs", - lambda **kwargs: ( - [{"judge_completion": "Explanation: ok\nscore_A: 1\nscore_B: 2"}], - None, - pd.Series([1.0]), - ), + fake_judge_and_parse_prefs, ) prefs = main_generate_and_eval( @@ -129,15 +103,27 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", - n_instructions=1, - provide_explanation=True, + n_instructions=2, result_folder=str(tmp_path / "results"), ) ) - assert prefs.tolist() == [1.0] - assert "structured_outputs_json" not in captured["make_model"] - assert captured["make_model"]["thinking_token_budget"] == 512 + assert prefs.tolist() == [1.0, 1.0] + assert captured["make_model"]["kwargs"]["thinking_token_budget"] == 512 + assert captured["make_model"]["kwargs"]["limit_event_stage"] == "judge_model_init" + assert captured["make_model"]["kwargs"]["limit_event_model_spec"] == ( + "VLLM/Qwen/Qwen3.5-27B-FP8" + ) + assert captured["judge_kwargs"]["instructions"] == [ + "Instruction B", + "Instruction A", + ] + assert captured["judge_kwargs"]["completions_A"] == ["no answer", "Answer A"] + assert captured["judge_kwargs"]["completions_B"] == ["Answer C", "Answer B"] + assert captured["judge_kwargs"]["case_ids"] == ["b", "a"] + assert captured["judge_kwargs"]["prompt_preset"] == "default" + assert captured["judge_kwargs"]["parser_mode"] == "score" + assert captured["judge_kwargs"]["strip_thinking_before_judging"] is False def test_main_does_not_pass_thinking_budget_to_non_reasoning_vllm_judge( @@ -191,6 +177,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs assert prefs.tolist() == [1.0] assert "thinking_token_budget" not in captured["make_model"] + assert captured["make_model"]["limit_event_stage"] == "judge_model_init" def test_main_preserves_explicit_reasoning_engine_kwargs_for_non_qwen_vllm_judge( @@ -251,9 +238,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs assert captured["make_model"]["thinking_token_budget"] == 2048 -def test_annotate_battles_warns_when_judge_completions_are_truncated( - monkeypatch, capsys -): +def test_annotate_battles_warns_when_judge_inputs_are_truncated(monkeypatch, capsys): captured = {} def fake_do_inference( @@ -280,8 +265,7 @@ def fake_do_inference( stdout = capsys.readouterr().out assert ( - "Warning: truncated 2 judge completions to 3 characters before evaluation." - in stdout + "Warning: truncated 2 judge inputs to 3 characters before evaluation." in stdout ) assert "Ans" in captured["judge_prompt"] assert "Answer A" not in captured["judge_prompt"] @@ -290,3 +274,189 @@ def fake_do_inference( assert "score_A:" in captured["judge_prompt"] assert annotations[0].completion_A == "Answer A" assert annotations[0].completion_B == "Answer B" + + +def test_resolve_judge_prompts_supports_optional_skywork_preset(): + resolved = evaluate.resolve_judge_prompts( + provide_explanation=False, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + ) + + assert resolved.preset_name == SKYWORK_JUDGE_PROMPT_PRESET + assert resolved.parser_mode == "verdict" + assert resolved.system_prompt is None + assert "[[A]]" in resolved.user_prompt_template + assert "score_A:" not in resolved.user_prompt_template + assert "[User Question]" in resolved.user_prompt_template + assert "Assistant A's Answer" in resolved.user_prompt_template + + +def test_resolve_judge_prompts_skywork_explanation_prompt_has_fixed_answer_labels(): + resolved = evaluate.resolve_judge_prompts( + provide_explanation=True, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + ) + + assert ( + "Please briefly explain your reasoning first" in resolved.user_prompt_template + ) + assert "Assistant B's Answer" in resolved.user_prompt_template + + +def test_annotate_battles_records_limit_events_for_stripping_and_truncation( + monkeypatch, +): + tracker = evaluate.LimitEventTracker() + + monkeypatch.setattr( + evaluate, + "do_inference", + lambda **kwargs: ["[[A]]"], + ) + + evaluate.annotate_battles( + judge_chat_model=object(), + instructions=["Instruction"], + completions_A=["hiddenVisible answer"], + completions_B=["Short"], + case_ids=["case-1"], + truncate_input_chars=5, + strip_thinking_before_judging=True, + limit_event_tracker=tracker, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + ) + + summary = tracker.build_summary() + + assert summary["counts_by_kind"]["thinking_trace_stripped_before_judging"] == 1 + assert summary["counts_by_kind"]["judge_input_char_truncation"] == 1 + + +def test_main_passes_qwen_only_battle_budget_and_prompt_preset(tmp_path, monkeypatch): + instructions = pd.DataFrame( + {"instruction": ["Instruction A"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + generate_and_evaluate, + "load_instructions", + lambda dataset, n_instructions=None: instructions, + ) + monkeypatch.setattr( + generate_and_evaluate, + "try_load_dataset_completions", + lambda dataset, model, n_instructions: None, + ) + monkeypatch.setattr( + generate_and_evaluate, + "cache_function_dataframe", + lambda fun, **_kwargs: fun(), + ) + + def fake_generate_instructions( + *, + instructions, + model, + truncate_input_chars, + max_tokens, + max_model_len, + chat_template, + use_tqdm, + usage_tracker, + usage_phase, + limit_event_tracker, + **engine_kwargs, + ): + captured.setdefault("generation_calls", []).append( + { + "model": model, + "max_tokens": max_tokens, + "engine_kwargs": engine_kwargs, + } + ) + return pd.DataFrame( + { + "instruction_index": [1], + "completion": [f"{model}-answer"], + "generation_prompt_truncated": [False], + "generation_output_finish_reason": [None], + "generation_output_hit_token_limit": [False], + } + ) + + monkeypatch.setattr( + generate_and_evaluate, "generate_instructions", fake_generate_instructions + ) + + def fake_judge_and_parse_prefs(**kwargs): + captured["judge_kwargs"] = kwargs + return [{"judge_completion": "[[A]]"}], None, pd.Series([0.0]) + + monkeypatch.setattr( + generate_and_evaluate, + "judge_and_parse_prefs", + fake_judge_and_parse_prefs, + ) + monkeypatch.setattr( + generate_and_evaluate, + "make_model", + lambda **kwargs: object(), + ) + + prefs = main_generate_and_eval( + CliArgs( + dataset="alpaca-eval", + model_A="VLLM/Qwen/Qwen3.5-27B-FP8", + model_B="VLLM/allenai/Olmo-3-7B-Instruct", + judge_model="Dummy/judge", + n_instructions=1, + judge_prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + battle_thinking_token_budget=512, + strip_thinking_before_judging=True, + result_folder=str(tmp_path / "results"), + ) + ) + + assert prefs.tolist() == [0.0] + assert len(captured["generation_calls"]) == 2 + assert ( + captured["generation_calls"][0]["engine_kwargs"]["thinking_token_budget"] == 512 + ) + assert ( + "thinking_token_budget" not in captured["generation_calls"][1]["engine_kwargs"] + ) + assert captured["judge_kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET + assert captured["judge_kwargs"]["parser_mode"] == "verdict" + assert captured["judge_kwargs"]["strip_thinking_before_judging"] is True + + +def test_generation_cache_name_changes_with_generation_settings(): + args = CliArgs( + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="Dummy/judge", + n_instructions=1, + max_out_tokens_models=1024, + battle_thinking_token_budget=256, + ) + changed_args = CliArgs( + dataset="alpaca-eval", + model_A="Dummy/model-a", + model_B="Dummy/model-b", + judge_model="Dummy/judge", + n_instructions=1, + max_out_tokens_models=4096, + battle_thinking_token_budget=512, + ) + + cache_name = generate_and_evaluate._generation_cache_name( + args, model_spec="VLLM/Qwen/Qwen3.5-27B-FP8" + ) + changed_cache_name = generate_and_evaluate._generation_cache_name( + changed_args, model_spec="VLLM/Qwen/Qwen3.5-27B-FP8" + ) + + assert cache_name != changed_cache_name diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index c6349be..db78355 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -6,6 +6,7 @@ import judgearena.mt_bench.fastchat_compat as fastchat_compat import judgearena.mt_bench.mt_bench_utils as mt_bench_utils import judgearena.utils as utils +from judgearena.judge_prompt_presets import SKYWORK_JUDGE_PROMPT_PRESET def test_download_mt_bench_skips_question_download_if_cached(tmp_path, monkeypatch): @@ -130,6 +131,7 @@ def fake_generate_multiturn( temperature_config, usage_tracker, usage_phase, + limit_event_tracker, **engine_kwargs, ): generated_models.append(model) @@ -168,6 +170,7 @@ def fake_generate_multiturn( use_tqdm=False, max_model_len=16384, chat_template=None, + battle_thinking_token_budget=None, engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, ) @@ -176,6 +179,7 @@ def fake_generate_multiturn( questions_df=questions_df, ignore_cache=False, usage_tracker=object(), + limit_event_tracker=None, ) assert generated_models == ["VLLM/example/model-a"] @@ -244,7 +248,7 @@ def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): monkeypatch.setattr( mt_bench_utils, "_generate_mt_bench_completions", - lambda args, questions_df, ignore_cache, usage_tracker: ( + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( pd.DataFrame( { "completion_turn_1": ["A1"], @@ -304,6 +308,9 @@ def fake_run_mt_bench_fastchat(**kwargs): chat_template=None, provide_explanation=False, swap_mode="fixed", + judge_prompt_preset="default", + battle_thinking_token_budget=None, + strip_thinking_before_judging=False, engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, ) @@ -318,6 +325,93 @@ def fake_run_mt_bench_fastchat(**kwargs): "gpu_memory_utilization": 0.7, "language_model_only": True, "thinking_token_budget": 512, + "limit_event_stage": "judge_model_init", + "limit_event_model_spec": "VLLM/Qwen/Qwen3.5-27B-FP8", + "limit_event_tracker": captured["make_model"]["kwargs"]["limit_event_tracker"], } assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" - assert "provide_explanation" not in captured["run_mt_bench_fastchat"] + assert captured["run_mt_bench_fastchat"]["prompt_preset"] == "default" + assert ( + captured["run_mt_bench_fastchat"]["args"].strip_thinking_before_judging is False + ) + + +def test_select_prompt_supports_optional_skywork_mt_bench_preset(): + prompt = fastchat_compat._select_prompt( + "writing", + multi_turn=False, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + ) + + assert prompt.name == "skywork-pair-v2" + assert prompt.ref_based is False + + +def test_run_mt_bench_keeps_skywork_prompt_preset(monkeypatch): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + monkeypatch.setattr(mt_bench_utils, "make_model", lambda **kwargs: object()) + + def fake_run_mt_bench_fastchat(**kwargs): + captured["kwargs"] = kwargs + return pd.Series([0.0], dtype=float) + + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_fastchat", + fake_run_mt_bench_fastchat, + ) + + args = SimpleNamespace( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Skywork/Skywork-Critic-Llama-3.1-8B", + n_instructions=1, + truncate_all_input_chars=8192, + max_out_tokens_models=1024, + max_out_tokens_judge=256, + use_tqdm=False, + max_model_len=16384, + chat_template=None, + provide_explanation=False, + swap_mode="both", + judge_prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + battle_thinking_token_budget=512, + strip_thinking_before_judging=True, + engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, + ) + + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert captured["kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET + assert captured["kwargs"]["args"].strip_thinking_before_judging is True diff --git a/tests/test_regexp.py b/tests/test_regexp.py index 23af4d5..f11c1cd 100644 --- a/tests/test_regexp.py +++ b/tests/test_regexp.py @@ -1,4 +1,5 @@ from judgearena.evaluate import PairScore +from judgearena.utils import strip_thinking_tags def test_pair_score(): @@ -59,15 +60,15 @@ def test_pair_score_ignores_scores_inside_thinking_tags(): assert pref == 0.9525741268224333 -def test_pair_score_falls_back_to_bracketed_verdicts(): +def test_pair_score_score_mode_does_not_parse_bracketed_verdicts(): scorer = PairScore() - assert scorer.parse_model_raw("Explanation: ok\n[[A]]") == 0.0 - assert scorer.parse_model_raw("Explanation: ok\n[[B]]") == 1.0 - assert scorer.parse_model_raw("Explanation: ok\n[[C]]") == 0.5 + assert scorer.parse_model_raw("Explanation: ok\n[[A]]") is None + assert scorer.parse_model_raw("Explanation: ok\n[[B]]") is None + assert scorer.parse_model_raw("Explanation: ok\n[[C]]") is None -def test_pair_score_ignores_thinking_tags_before_bracketed_verdict(): +def test_pair_score_score_mode_ignores_bracketed_verdict_after_thinking(): raw_text = """ score_A: 0 @@ -79,4 +80,31 @@ def test_pair_score_ignores_thinking_tags_before_bracketed_verdict(): scorer = PairScore() + assert scorer.parse_model_raw(raw_text) is None + + +def test_strip_thinking_tags_handles_closing_tag_without_opening_tag(): + raw_text = ( + "Reasoning that started implicitly and kept going.\n" + "Still reasoning.\n" + "\n" + "Final answer." + ) + + assert strip_thinking_tags(raw_text) == "Final answer." + + +def test_pair_score_verdict_mode_uses_bracketed_verdicts(): + raw_text = "score_A: 10\nscore_B: 0\n[[B]]" + + scorer = PairScore(parser_mode="verdict") + assert scorer.parse_model_raw(raw_text) == 1.0 + + +def test_pair_score_verdict_mode_does_not_parse_score_only_outputs(): + raw_text = "score_A: 10\nscore_B: 0" + + scorer = PairScore(parser_mode="verdict") + + assert scorer.parse_model_raw(raw_text) is None From 8087c15ea1acd7083c98016b19c933d0cf9b3124 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Mon, 20 Apr 2026 16:41:51 +0200 Subject: [PATCH 17/28] Add judge input character truncation and model length configurations so that we have more customizability - Introduced `truncate_judge_input_chars` and `max_judge_model_len` to `BaseCliArgs` for better control over judge-side input limits. --- judgearena/cli_common.py | 78 ++++++++++-- judgearena/estimate_elo_ratings.py | 10 +- judgearena/generate.py | 16 ++- judgearena/generate_and_evaluate.py | 15 ++- judgearena/mt_bench/mt_bench_utils.py | 24 ++-- judgearena/utils.py | 132 +++++++++++++++----- tests/test_chat_vllm.py | 168 ++++++++++++++++++++++++-- tests/test_mt_bench_downloads.py | 40 +++++- tests/test_utils.py | 81 +++++++++++++ 9 files changed, 491 insertions(+), 73 deletions(-) diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 118464d..5327302 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -28,12 +28,15 @@ class BaseCliArgs: battle_thinking_token_budget: int | None = None strip_thinking_before_judging: bool = False truncate_all_input_chars: int = 8192 + truncate_judge_input_chars: int | None = None max_out_tokens_models: int = 32768 max_out_tokens_judge: int = 32768 max_model_len: int | None = None + max_judge_model_len: int | None = None chat_template: str | None = None result_folder: str = "results" engine_kwargs: dict = field(default_factory=dict) + judge_engine_kwargs: dict = field(default_factory=dict) def __post_init__(self): supported_modes = ["fixed", "both"] @@ -41,6 +44,26 @@ def __post_init__(self): f"Only {supported_modes} modes are supported but got {self.swap_mode}." ) + def effective_judge_truncation(self) -> int: + """Character cap applied to judge-side inputs (completions, reference, etc.). + + Falls back to the generation-side ``truncate_all_input_chars`` when a + dedicated judge cap is not configured. + """ + if self.truncate_judge_input_chars is not None: + return int(self.truncate_judge_input_chars) + return int(self.truncate_all_input_chars) + + def effective_judge_max_model_len(self) -> int | None: + """Total context window for the judge vLLM instance. + + Falls back to the generation-side ``max_model_len`` when a dedicated + judge context window is not configured. + """ + if self.max_judge_model_len is not None: + return int(self.max_judge_model_len) + return self.max_model_len + def parse_optional_bool(raw: str | None) -> bool: if raw is None: @@ -152,9 +175,23 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: required=False, default=8192, help=( - "Character-level truncation applied before tokenization: truncates " - "each instruction before model A/B generation and truncates each " - "completion before judge evaluation." + "Character-level truncation applied to generation-side inputs: " + "truncates each instruction before model A/B generation. When " + "--truncate_judge_input_chars is not set, this value also caps the " + "judge-side inputs (completions, reference, etc.)." + ), + ) + parser.add_argument( + "--truncate_judge_input_chars", + type=int, + required=False, + default=None, + help=( + "Character cap applied to judge-side inputs (completions, " + "reference, instruction) before judge evaluation. Falls back to " + "--truncate_all_input_chars when not specified. Set much higher " + "than the generation cap to avoid cutting model completions before " + "they reach the judge." ), ) parser.add_argument( @@ -183,10 +220,24 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: required=False, default=None, help=( - "Optional total context window for VLLM models (prompt + generation). " - "This is independent from --max_out_tokens_models/--max_out_tokens_judge, " - "which only cap generated tokens. This is useful on smaller GPUs to " - "avoid OOM." + "Optional total context window for the battle-generation VLLM " + "instances (prompt + generation). Independent from " + "--max_out_tokens_models/--max_out_tokens_judge, which only cap " + "generated tokens. When --max_judge_model_len is not set, this " + "value also sizes the judge instance." + ), + ) + parser.add_argument( + "--max_judge_model_len", + type=int, + required=False, + default=None, + help=( + "Optional total context window for the judge VLLM instance. Falls " + "back to --max_model_len when not specified. Set higher than the " + "battle model_len when the judge needs to see longer prompts " + "(e.g. long completions from both A and B) than the battle " + "generator can fit." ), ) parser.add_argument( @@ -211,6 +262,19 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: '\'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.' ), ) + parser.add_argument( + "--judge_engine_kwargs", + type=str, + required=False, + default="{}", + help=( + "Optional JSON dict of engine-specific kwargs that override " + "``--engine_kwargs`` only for the judge model. Useful when the " + "judge needs a different tensor-parallel or quantization config " + "than the battle models, e.g. a 70B judge on TP=2 while the " + "battle models run on TP=1 to dodge compile-time deadlocks." + ), + ) def parse_engine_kwargs(raw: str) -> dict: diff --git a/judgearena/estimate_elo_ratings.py b/judgearena/estimate_elo_ratings.py index d7dfbd7..0e95aa3 100644 --- a/judgearena/estimate_elo_ratings.py +++ b/judgearena/estimate_elo_ratings.py @@ -100,12 +100,15 @@ def parse_args(cls): swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, truncate_all_input_chars=args.truncate_all_input_chars, + truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, max_out_tokens_judge=args.max_out_tokens_judge, max_model_len=args.max_model_len, + max_judge_model_len=args.max_judge_model_len, chat_template=args.chat_template, result_folder=args.result_folder, engine_kwargs=parse_engine_kwargs(args.engine_kwargs), + judge_engine_kwargs=parse_engine_kwargs(args.judge_engine_kwargs), ) @@ -343,8 +346,9 @@ def replace_slash(s: str) -> str: ] judge_extra_kwargs = {} - if args.max_model_len is not None: - judge_extra_kwargs["max_model_len"] = args.max_model_len + effective_judge_max_model_len = args.effective_judge_max_model_len() + if effective_judge_max_model_len is not None: + judge_extra_kwargs["max_model_len"] = effective_judge_max_model_len if args.chat_template is not None: judge_extra_kwargs["chat_template"] = args.chat_template @@ -361,7 +365,7 @@ def run_judge() -> pd.DataFrame: completions_B=completions_B, swap_mode=args.swap_mode, provide_explanation=args.provide_explanation, - truncate_input_chars=args.truncate_all_input_chars, + truncate_input_chars=args.effective_judge_truncation(), use_tqdm=use_tqdm, ) return pd.DataFrame( diff --git a/judgearena/generate.py b/judgearena/generate.py index 3a1c65b..97254f7 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -21,10 +21,13 @@ def _record_generation_output_limit_events( ) -> list[bool]: hit_token_limit: list[bool] = [] for case_id, metadata_row in zip(case_ids, metadata, strict=True): - finish_reason = str((metadata_row or {}).get("finish_reason") or "").lower() + row = metadata_row or {} + finish_reason = str(row.get("finish_reason") or "").lower() reached_limit = finish_reason == "length" hit_token_limit.append(reached_limit) - if reached_limit and limit_event_tracker is not None: + if limit_event_tracker is None: + continue + if reached_limit: limit_event_tracker.record( "generation_output_token_limit", stage="generation_output", @@ -33,6 +36,15 @@ def _record_generation_output_limit_events( model_spec=model_spec, note=finish_reason, ) + if row.get("thinking_budget_exhausted"): + limit_event_tracker.record( + "generation_thinking_token_budget", + stage="generation_output", + field=field, + case_id=case_id, + model_spec=model_spec, + note=str(row.get("thinking_token_budget")), + ) return hit_token_limit diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 51edaf8..3ddc577 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -40,7 +40,7 @@ compute_pref_summary, data_root, download_hf, - is_qwen_reasoning_model, + is_thinking_model, make_model, read_df, ) @@ -141,12 +141,15 @@ def parse_args(cls): battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, truncate_all_input_chars=args.truncate_all_input_chars, + truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, max_out_tokens_judge=args.max_out_tokens_judge, max_model_len=args.max_model_len, + max_judge_model_len=args.max_judge_model_len, chat_template=args.chat_template, result_folder=args.result_folder, engine_kwargs=parse_engine_kwargs(args.engine_kwargs), + judge_engine_kwargs=parse_engine_kwargs(args.judge_engine_kwargs), ) @@ -163,7 +166,7 @@ def _build_generation_model_kwargs( if ( args.battle_thinking_token_budget is not None and provider == "VLLM" - and is_qwen_reasoning_model(model_name) + and is_thinking_model(model_name) ): generation_model_kwargs["thinking_token_budget"] = min( int(args.battle_thinking_token_budget), @@ -176,7 +179,9 @@ def _build_judge_model_kwargs( *, args: CliArgs, limit_event_tracker: LimitEventTracker | None ) -> dict[str, object]: judge_model_kwargs = build_default_judge_model_kwargs( - args.judge_model, args.engine_kwargs + args.judge_model, + args.engine_kwargs, + judge_engine_kwargs_override=args.judge_engine_kwargs, ) if limit_event_tracker is not None: judge_model_kwargs["limit_event_tracker"] = limit_event_tracker @@ -333,7 +338,7 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, - max_model_len=args.max_model_len, + max_model_len=args.effective_judge_max_model_len(), chat_template=args.chat_template, **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) @@ -378,7 +383,7 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: strip_thinking_before_judging=args.strip_thinking_before_judging, system_prompt=resolved_prompt.system_prompt, user_prompt_template=resolved_prompt.user_prompt_template, - truncate_input_chars=args.truncate_all_input_chars, + truncate_input_chars=args.effective_judge_truncation(), use_tqdm=args.use_tqdm, usage_tracker=usage_tracker, usage_phase="judge", diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 00f380a..864c8a2 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -36,7 +36,7 @@ build_default_judge_model_kwargs, cache_function_dataframe, compute_pref_summary, - is_qwen_reasoning_model, + is_thinking_model, make_model, ) @@ -58,7 +58,7 @@ def _build_mt_bench_generation_kwargs( if ( args.battle_thinking_token_budget is not None and provider == "VLLM" - and is_qwen_reasoning_model(model_name) + and is_thinking_model(model_name) ): generation_model_kwargs["thinking_token_budget"] = min( int(args.battle_thinking_token_budget), @@ -71,7 +71,9 @@ def _build_mt_bench_judge_model_kwargs( *, args: CliArgs, limit_event_tracker: LimitEventTracker | None ) -> dict[str, object]: judge_model_kwargs = build_default_judge_model_kwargs( - args.judge_model, args.engine_kwargs + args.judge_model, + args.engine_kwargs, + judge_engine_kwargs_override=args.judge_engine_kwargs, ) if limit_event_tracker is not None: judge_model_kwargs["limit_event_tracker"] = limit_event_tracker @@ -235,7 +237,7 @@ def _run_mt_bench_fastchat( model_b=args.model_B, turns_mode="both", swap_mode=args.swap_mode, - truncate_input_chars=args.truncate_all_input_chars, + truncate_input_chars=args.effective_judge_truncation(), use_tqdm=args.use_tqdm, prompt_preset=prompt_preset, strip_thinking_before_judging=args.strip_thinking_before_judging, @@ -323,22 +325,24 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): usage_tracker=usage_tracker, limit_event_tracker=limit_event_tracker, ) + effective_judge_max_model_len = args.effective_judge_max_model_len() if ( - args.max_model_len is not None - and args.max_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + effective_judge_max_model_len is not None + and effective_judge_max_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN ): print( "MT-Bench judge prompts require a larger total context window for " "prompt plus completion; " - f"overriding max_model_len from {args.max_model_len} " - f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN} for the judge." + f"overriding judge max_model_len from {effective_judge_max_model_len} " + f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN}." ) - args.max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + args.max_judge_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + effective_judge_max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, temperature=0.0, - max_model_len=args.max_model_len, + max_model_len=effective_judge_max_model_len, chat_template=args.chat_template, **_build_mt_bench_judge_model_kwargs( args=args, limit_event_tracker=limit_event_tracker diff --git a/judgearena/utils.py b/judgearena/utils.py index 7bd0809..5d3e7f4 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -34,10 +34,11 @@ def _data_root_path() -> Path: data_root = _data_root_path() DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET = 512 -VLLM_QWEN_REASONING_START_STR = "" -VLLM_QWEN_REASONING_END_STR = ( +VLLM_REASONING_START_STR = "" +VLLM_REASONING_END_STR = ( "I have to give the solution based on the thinking directly now." ) +_THINKING_MODEL_SUBSTRINGS = ("qwen3", "smollm3") def _split_model_spec(model_spec: str) -> tuple[str, str]: @@ -47,24 +48,49 @@ def _split_model_spec(model_spec: str) -> tuple[str, str]: return provider, model_name -def is_qwen_reasoning_model(model_name: str) -> bool: - return "qwen3" in model_name.lower() +def is_thinking_model(model_name: str) -> bool: + """Return True for reasoning models that emit `...` traces. + + Covers the Qwen3 family (e.g. `Qwen/Qwen3.5-9B`) and SmolLM3 (e.g. + `HuggingFaceTB/SmolLM3-3B`); both share the same ``/`` tag + convention so vLLM's budget enforcement and our tag-stripping apply + uniformly. Matching is case-insensitive to tolerate mixed-case HF repo + ids like `HuggingFaceTB/SmolLM3-3B`. + """ + lowered = model_name.lower() + return any(token in lowered for token in _THINKING_MODEL_SUBSTRINGS) def build_default_judge_model_kwargs( - judge_model: str, engine_kwargs: dict[str, object] + judge_model: str, + engine_kwargs: dict[str, object], + *, + judge_engine_kwargs_override: dict[str, object] | None = None, ) -> dict[str, object]: - """Copy judge engine kwargs and add supported built-in defaults.""" + """Copy judge engine kwargs and add supported built-in defaults. + + ``judge_engine_kwargs_override`` is layered on top of ``engine_kwargs`` + so callers can pin judge-only tweaks (e.g. a higher tensor-parallel size + for a 70B judge) without poisoning the battle-model engine config, which + must often stay on TP=1 to dodge compile-time deadlocks on hybrid models + such as Qwen3.5. + """ judge_model_kwargs = dict(engine_kwargs) + if judge_engine_kwargs_override: + judge_model_kwargs.update(judge_engine_kwargs_override) provider, model_name = _split_model_spec(judge_model) - if ( - provider == "VLLM" - and "thinking_token_budget" not in judge_model_kwargs - and is_qwen_reasoning_model(model_name) - ): - judge_model_kwargs["thinking_token_budget"] = ( - DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET - ) + if provider == "VLLM": + if "thinking_token_budget" not in judge_model_kwargs and is_thinking_model( + model_name + ): + judge_model_kwargs["thinking_token_budget"] = ( + DEFAULT_VLLM_JUDGE_THINKING_TOKEN_BUDGET + ) + # FP8 weights leave little KV headroom on consumer-class GPUs; default + # to FP8 KV cache so judges like Skywork-70B-FP8 fit comfortably on + # 2x L40S at 32k context. Explicit caller overrides still win. + if "kv_cache_dtype" not in judge_model_kwargs and "fp8" in model_name.lower(): + judge_model_kwargs["kv_cache_dtype"] = "fp8" return judge_model_kwargs @@ -327,13 +353,33 @@ def strip_thinking_tags_with_metadata(text: str | None) -> tuple[str, bool]: if closing_idx != -1 and "" not in lowered[:closing_idx]: return text[closing_idx + len(closing_tag) :].lstrip(), True - qwen_end_idx = text.find(VLLM_QWEN_REASONING_END_STR) - if qwen_end_idx != -1: - return text[qwen_end_idx + len(VLLM_QWEN_REASONING_END_STR) :].lstrip(), True + forced_end_idx = text.find(VLLM_REASONING_END_STR) + if forced_end_idx != -1: + return ( + text[forced_end_idx + len(VLLM_REASONING_END_STR) :].lstrip(), + True, + ) return text, False +def _extract_ai_message_metadata(result: object) -> dict[str, Any]: + """Extract finish_reason/stop_reason from a LangChain AIMessage result. + + LangChain chat models (ChatOpenAI for OpenRouter, Anthropic, etc.) return + AIMessage objects with a ``response_metadata`` dict. We propagate the + subset that downstream code consumes (finish_reason is critical: it gates + truncation detection in _record_generation_output_limit_events). + """ + response_metadata = getattr(result, "response_metadata", None) or {} + finish_reason = response_metadata.get("finish_reason") + stop_reason = response_metadata.get("stop_reason") + if finish_reason is None and isinstance(result, dict): + finish_reason = result.get("finish_reason") + stop_reason = result.get("stop_reason", stop_reason) + return {"finish_reason": finish_reason, "stop_reason": stop_reason} + + def do_inference( chat_model, inputs, @@ -380,7 +426,7 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): ) ) if return_metadata: - metadata = [{} for _ in res] + metadata = [_extract_ai_message_metadata(r) for r in res] else: def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): @@ -407,7 +453,9 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): chunk_results = chat_model.batch( inputs=chunk, **invoke_kwargs ) - chunk_metadata = [{} for _ in chunk_results] + chunk_metadata = [ + _extract_ai_message_metadata(r) for r in chunk_results + ] results.extend(chunk_results) results_metadata.extend(chunk_metadata) return results, results_metadata @@ -548,6 +596,8 @@ def __init__( "temperature": float(vllm_kwargs.pop("temperature", 0.6)), "top_p": float(vllm_kwargs.pop("top_p", 0.95)), } + self._thinking_budget_marker: str | None = None + self._thinking_budget_value: int | None = None if thinking_token_budget is not None: if max_tokens is not None: thinking_token_budget = min(int(thinking_token_budget), int(max_tokens)) @@ -555,23 +605,32 @@ def __init__( self._sampling_params_kwargs["thinking_token_budget"] = int( thinking_token_budget ) - elif is_qwen_reasoning_model(model): + self._thinking_budget_marker = VLLM_REASONING_END_STR + self._thinking_budget_value = int(thinking_token_budget) + elif is_thinking_model(model): vllm_kwargs.setdefault( "reasoning_config", ReasoningConfig( - reasoning_start_str=VLLM_QWEN_REASONING_START_STR, - reasoning_end_str=VLLM_QWEN_REASONING_END_STR, + reasoning_start_str=VLLM_REASONING_START_STR, + reasoning_end_str=VLLM_REASONING_END_STR, ), ) + # The `qwen3` reasoning_parser only runs inside vLLM's + # OpenAI-compatible server for `reasoning_content` extraction. + # For offline batch inference via LLM.chat() it is inert, so + # it is safe to reuse for any ``/`` model + # (Qwen3 + SmolLM3). vllm_kwargs.setdefault("reasoning_parser", "qwen3") self._sampling_params_kwargs["thinking_token_budget"] = int( thinking_token_budget ) + self._thinking_budget_marker = VLLM_REASONING_END_STR + self._thinking_budget_value = int(thinking_token_budget) else: warnings.warn( - f"Model '{model}' is not in JudgeArena's built-in Qwen reasoning defaults. " - "Ignoring thinking_token_budget unless reasoning_parser or " - "reasoning_config is provided explicitly.", + f"Model '{model}' is not in JudgeArena's built-in thinking-model " + "defaults (Qwen3/SmolLM3). Ignoring thinking_token_budget unless " + "reasoning_parser or reasoning_config is provided explicitly.", stacklevel=2, ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) @@ -691,15 +750,24 @@ def batch_with_metadata( outputs = self._run_raw_batch(inputs) texts: list[str] = [] metadata: list[dict[str, Any]] = [] + marker = self._thinking_budget_marker for out in outputs: first_output = out.outputs[0] - texts.append(first_output.text) - metadata.append( - { - "finish_reason": getattr(first_output, "finish_reason", None), - "stop_reason": getattr(first_output, "stop_reason", None), - } - ) + text = first_output.text + texts.append(text) + row: dict[str, Any] = { + "finish_reason": getattr(first_output, "finish_reason", None), + "stop_reason": getattr(first_output, "stop_reason", None), + } + if marker is not None: + # vLLM emits the forced reasoning-end marker verbatim when the + # per-request thinking-token budget is exhausted; the marker is + # absent otherwise. Detecting it here gives + # `_record_generation_output_limit_events` a deterministic + # signal to log a `generation_thinking_token_budget` event. + row["thinking_budget_exhausted"] = marker in text + row["thinking_token_budget"] = self._thinking_budget_value + metadata.append(row) return texts, metadata def batch(self, inputs: list, **invoke_kwargs) -> list[str]: diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index 1c0eaaf..a0f04b5 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -53,7 +53,7 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc captured, fake_reasoning_config = _install_fake_vllm(monkeypatch) utils.ChatVLLM( - model="Qwen/Qwen3.5-27B-FP8", + model="Qwen/Qwen3.5-9B", max_tokens=128, thinking_token_budget=64, gpu_memory_utilization=0.7, @@ -62,8 +62,28 @@ def test_chat_vllm_enables_reasoning_support_for_qwen_thinking_budget(monkeypatc assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 assert "structured_outputs" not in captured["sampling_kwargs"] assert captured["reasoning_config_kwargs"] == { - "reasoning_start_str": utils.VLLM_QWEN_REASONING_START_STR, - "reasoning_end_str": utils.VLLM_QWEN_REASONING_END_STR, + "reasoning_start_str": utils.VLLM_REASONING_START_STR, + "reasoning_end_str": utils.VLLM_REASONING_END_STR, + } + llm_kwargs = captured["llm_init"]["kwargs"] + assert llm_kwargs["reasoning_parser"] == "qwen3" + assert isinstance(llm_kwargs["reasoning_config"], fake_reasoning_config) + + +def test_chat_vllm_enables_reasoning_support_for_smollm3_thinking_budget(monkeypatch): + captured, fake_reasoning_config = _install_fake_vllm(monkeypatch) + + utils.ChatVLLM( + model="HuggingFaceTB/SmolLM3-3B", + max_tokens=128, + thinking_token_budget=64, + gpu_memory_utilization=0.7, + ) + + assert captured["sampling_kwargs"]["thinking_token_budget"] == 64 + assert captured["reasoning_config_kwargs"] == { + "reasoning_start_str": utils.VLLM_REASONING_START_STR, + "reasoning_end_str": utils.VLLM_REASONING_END_STR, } llm_kwargs = captured["llm_init"]["kwargs"] assert llm_kwargs["reasoning_parser"] == "qwen3" @@ -74,7 +94,7 @@ def test_chat_vllm_clamps_thinking_budget_to_total_max_tokens(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) utils.ChatVLLM( - model="Qwen/Qwen3.5-27B-FP8", + model="Qwen/Qwen3.5-9B", max_tokens=32, thinking_token_budget=64, gpu_memory_utilization=0.7, @@ -86,7 +106,7 @@ def test_chat_vllm_clamps_thinking_budget_to_total_max_tokens(monkeypatch): def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) chat_model = utils.ChatVLLM( - model="Qwen/Qwen3.5-27B-FP8", + model="Qwen/Qwen3.5-9B", max_tokens=16, disable_thinking=True, gpu_memory_utilization=0.7, @@ -102,7 +122,7 @@ def test_chat_vllm_passes_disable_thinking_via_chat_template_kwargs(monkeypatch) def test_build_default_judge_model_kwargs_only_defaults_qwen_judges(): assert utils.build_default_judge_model_kwargs( - "VLLM/Qwen/Qwen3.5-27B-FP8", + "VLLM/Qwen/Qwen3.5-9B", {"gpu_memory_utilization": 0.7}, ) == { "gpu_memory_utilization": 0.7, @@ -121,6 +141,69 @@ def test_build_default_judge_model_kwargs_only_defaults_qwen_judges(): ) +def test_build_default_judge_model_kwargs_sets_fp8_kv_cache_for_fp8_judges(): + fp8_defaults = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9}, + ) + assert fp8_defaults["kv_cache_dtype"] == "fp8" + # FP8 Skywork judge is not Qwen3/SmolLM3 so no thinking-token default. + assert "thinking_token_budget" not in fp8_defaults + + bf16_defaults = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-8B", + {"gpu_memory_utilization": 0.9}, + ) + assert "kv_cache_dtype" not in bf16_defaults + + explicit_override = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9, "kv_cache_dtype": "bfloat16"}, + ) + assert explicit_override["kv_cache_dtype"] == "bfloat16" + + # Non-VLLM providers never receive the FP8 KV default even if the name + # happens to contain "fp8". + non_vllm = utils.build_default_judge_model_kwargs("OpenRouter/some/Model-fp8", {}) + assert "kv_cache_dtype" not in non_vllm + + +def test_build_default_judge_model_kwargs_overlays_judge_override(): + """Judge-scoped overrides must win over shared ``engine_kwargs`` so the + battle engine can stay on TP=1 while the 70B FP8 judge pins TP=2.""" + merged = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"gpu_memory_utilization": 0.9}, + judge_engine_kwargs_override={"tensor_parallel_size": 2}, + ) + assert merged["tensor_parallel_size"] == 2 + assert merged["gpu_memory_utilization"] == 0.9 + assert merged["kv_cache_dtype"] == "fp8" + + overridden = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"tensor_parallel_size": 1, "gpu_memory_utilization": 0.9}, + judge_engine_kwargs_override={"tensor_parallel_size": 4}, + ) + assert overridden["tensor_parallel_size"] == 4 + + empty_override = utils.build_default_judge_model_kwargs( + "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", + {"tensor_parallel_size": 1}, + judge_engine_kwargs_override={}, + ) + assert empty_override["tensor_parallel_size"] == 1 + + +def test_is_thinking_model_matches_qwen3_and_smollm3_repo_ids(): + assert utils.is_thinking_model("Qwen/Qwen3.5-9B") + assert utils.is_thinking_model("HuggingFaceTB/SmolLM3-3B") + assert utils.is_thinking_model("Qwen/Qwen3-7B") + assert not utils.is_thinking_model("Qwen/Qwen2.5-7B") + assert not utils.is_thinking_model("utter-project/EuroLLM-9B-Instruct") + assert not utils.is_thinking_model("meta-llama/Llama-3.1-8B") + + def test_chat_vllm_preserves_explicit_reasoning_settings_for_non_qwen(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) explicit_reasoning_config = object() @@ -144,7 +227,7 @@ def test_chat_vllm_preserves_explicit_reasoning_settings_for_non_qwen(monkeypatc def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) - with pytest.warns(UserWarning, match="built-in Qwen reasoning defaults"): + with pytest.warns(UserWarning, match="built-in thinking-model"): utils.ChatVLLM( model="meta-llama/Llama-3.3-70B-Instruct", max_tokens=32, @@ -157,11 +240,78 @@ def test_chat_vllm_ignores_thinking_budget_for_unknown_family(monkeypatch): assert "reasoning_config" not in captured["llm_init"]["kwargs"] +def test_chat_vllm_records_thinking_budget_exhaustion_metadata(monkeypatch): + captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + class FakeLLMWithMarker: + def __init__(self, *, model, trust_remote_code, **kwargs): + captured["llm_init"] = {"model": model, "kwargs": kwargs} + + def get_tokenizer(self): + return SimpleNamespace(chat_template="{{ messages }}") + + def chat(self, messages, sampling_params, **kwargs): + return [ + SimpleNamespace( + outputs=[ + SimpleNamespace( + text=f"pre {utils.VLLM_REASONING_END_STR} answer", + finish_reason="stop", + stop_reason=None, + ) + ] + ), + SimpleNamespace( + outputs=[ + SimpleNamespace( + text="clean answer", + finish_reason="stop", + stop_reason=None, + ) + ] + ), + ] + + monkeypatch.setitem( + sys.modules, + "vllm", + SimpleNamespace( + LLM=FakeLLMWithMarker, + SamplingParams=sys.modules["vllm"].SamplingParams, + ), + ) + + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=64, + thinking_token_budget=32, + gpu_memory_utilization=0.7, + ) + _texts, metadata = chat_model.batch_with_metadata(["a", "b"]) + + assert metadata[0]["thinking_budget_exhausted"] is True + assert metadata[0]["thinking_token_budget"] == 32 + assert metadata[1]["thinking_budget_exhausted"] is False + assert metadata[1]["thinking_token_budget"] == 32 + + +def test_chat_vllm_omits_thinking_budget_metadata_without_budget(monkeypatch): + _captured, _fake_reasoning_config = _install_fake_vllm(monkeypatch) + + chat_model = utils.ChatVLLM( + model="Qwen/Qwen3.5-9B", + max_tokens=64, + gpu_memory_utilization=0.7, + ) + assert chat_model._thinking_budget_marker is None + assert chat_model._thinking_budget_value is None + + def test_infer_model_spec_uses_type_based_vllm_fallback(): model = object.__new__(utils.ChatVLLM) - model.model_path = "Qwen/Qwen3.5-27B-FP8" + model.model_path = "Qwen/Qwen3.5-9B" - assert utils.infer_model_spec_from_instance(model) == "VLLM/Qwen/Qwen3.5-27B-FP8" + assert utils.infer_model_spec_from_instance(model) == "VLLM/Qwen/Qwen3.5-9B" def test_infer_model_spec_uses_type_based_llamacpp_fallback(monkeypatch): diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index db78355..9428303 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -6,9 +6,31 @@ import judgearena.mt_bench.fastchat_compat as fastchat_compat import judgearena.mt_bench.mt_bench_utils as mt_bench_utils import judgearena.utils as utils +from judgearena.cli_common import BaseCliArgs from judgearena.judge_prompt_presets import SKYWORK_JUDGE_PROMPT_PRESET +def _mt_bench_args( + *, + dataset: str, + model_A: str, + model_B: str, + use_tqdm: bool = False, + **base_overrides, +) -> BaseCliArgs: + """Construct a ``BaseCliArgs`` with MT-Bench CLI-style extras attached. + + Using the real dataclass here ensures tests exercise the production + ``effective_judge_*`` fallback helpers instead of a duplicate shim. + """ + args = BaseCliArgs(**base_overrides) + args.dataset = dataset + args.model_A = model_A + args.model_B = model_B + args.use_tqdm = use_tqdm + return args + + def test_download_mt_bench_skips_question_download_if_cached(tmp_path, monkeypatch): question_path = tmp_path / "data" / "mt_bench" / "question.jsonl" question_path.parent.mkdir(parents=True, exist_ok=True) @@ -294,7 +316,7 @@ def fake_run_mt_bench_fastchat(**kwargs): fake_run_mt_bench_fastchat, ) - args = SimpleNamespace( + args = _mt_bench_args( dataset="mt-bench", model_A="VLLM/example/model-a", model_B="gpt-4", @@ -303,7 +325,6 @@ def fake_run_mt_bench_fastchat(**kwargs): truncate_all_input_chars=8192, max_out_tokens_models=1024, max_out_tokens_judge=256, - use_tqdm=False, max_model_len=16384, chat_template=None, provide_explanation=False, @@ -318,13 +339,17 @@ def fake_run_mt_bench_fastchat(**kwargs): assert args.swap_mode == "both" assert args.max_out_tokens_judge == 24576 - assert args.max_model_len == 28672 + assert args.max_model_len == 16384 + assert args.max_judge_model_len == 28672 + assert args.effective_judge_max_model_len() == 28672 + assert args.effective_judge_truncation() == 8192 assert captured["make_model"]["max_tokens"] == 24576 assert captured["make_model"]["max_model_len"] == 28672 assert captured["make_model"]["kwargs"] == { "gpu_memory_utilization": 0.7, "language_model_only": True, "thinking_token_budget": 512, + "kv_cache_dtype": "fp8", "limit_event_stage": "judge_model_init", "limit_event_model_spec": "VLLM/Qwen/Qwen3.5-27B-FP8", "limit_event_tracker": captured["make_model"]["kwargs"]["limit_event_tracker"], @@ -391,17 +416,18 @@ def fake_run_mt_bench_fastchat(**kwargs): fake_run_mt_bench_fastchat, ) - args = SimpleNamespace( + args = _mt_bench_args( dataset="mt-bench", model_A="VLLM/example/model-a", model_B="gpt-4", judge_model="VLLM/Skywork/Skywork-Critic-Llama-3.1-8B", n_instructions=1, truncate_all_input_chars=8192, + truncate_judge_input_chars=80000, max_out_tokens_models=1024, max_out_tokens_judge=256, - use_tqdm=False, max_model_len=16384, + max_judge_model_len=65536, chat_template=None, provide_explanation=False, swap_mode="both", @@ -415,3 +441,7 @@ def fake_run_mt_bench_fastchat(**kwargs): assert captured["kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET assert captured["kwargs"]["args"].strip_thinking_before_judging is True + assert args.effective_judge_max_model_len() == 65536 + assert args.effective_judge_truncation() == 80000 + assert captured["kwargs"]["args"].effective_judge_truncation() == 80000 + assert captured["kwargs"]["args"].effective_judge_max_model_len() == 65536 diff --git a/tests/test_utils.py b/tests/test_utils.py index d2a7411..9a96d7d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,87 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + import judgearena.utils as utils +def test_extract_ai_message_metadata_reads_finish_reason(): + ai_message = SimpleNamespace( + content="hi", + response_metadata={"finish_reason": "length", "stop_reason": None}, + ) + md = utils._extract_ai_message_metadata(ai_message) + assert md == {"finish_reason": "length", "stop_reason": None} + + +def test_extract_ai_message_metadata_handles_missing_response_metadata(): + bare_ai_message = SimpleNamespace(content="hello") + md = utils._extract_ai_message_metadata(bare_ai_message) + assert md == {"finish_reason": None, "stop_reason": None} + + +def test_extract_ai_message_metadata_handles_plain_dict_fallback(): + md = utils._extract_ai_message_metadata( + {"finish_reason": "stop", "stop_reason": "eos"} + ) + assert md == {"finish_reason": "stop", "stop_reason": "eos"} + + +def test_do_inference_async_path_propagates_finish_reason(monkeypatch): + async_results = [ + SimpleNamespace( + content="out1", + response_metadata={"finish_reason": "stop"}, + ), + SimpleNamespace( + content="out2", + response_metadata={"finish_reason": "length"}, + ), + ] + + async def fake_ainvoke(_input, **_kwargs): + return async_results.pop(0) + + chat_model = SimpleNamespace(ainvoke=fake_ainvoke) + texts, metadata = utils.do_inference( + chat_model=chat_model, + inputs=["prompt1", "prompt2"], + use_tqdm=True, + return_metadata=True, + ) + assert texts == ["out1", "out2"] + assert metadata == [ + {"finish_reason": "stop", "stop_reason": None}, + {"finish_reason": "length", "stop_reason": None}, + ] + + +def test_do_inference_batch_path_propagates_finish_reason_without_batch_with_metadata(): + batch_results = [ + SimpleNamespace( + content="a", + response_metadata={"finish_reason": "stop"}, + ), + SimpleNamespace( + content="b", + response_metadata={"finish_reason": "length"}, + ), + ] + chat_model = MagicMock() + chat_model.batch = MagicMock(return_value=batch_results) + # Ensure no batch_with_metadata attr so the else branch runs + if hasattr(chat_model, "batch_with_metadata"): + del chat_model.batch_with_metadata + + texts, metadata = utils.do_inference( + chat_model=chat_model, + inputs=["p1", "p2"], + use_tqdm=False, + return_metadata=True, + ) + assert [m["finish_reason"] for m in metadata] == ["stop", "length"] + assert texts == ["a", "b"] + + def test_download_all_dispatches_arena_hard_versions(monkeypatch, tmp_path): calls: list[tuple[str, str, object]] = [] From 91d67ef4d74a2434c062b2cf27991150ebe4de66 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Mon, 20 Apr 2026 16:42:40 +0200 Subject: [PATCH 18/28] add llmcompressor dev dependency for quantization --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 19de68f..1318c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ exclude = ["slurmpilot_scripts*"] [dependency-groups] dev = [ + "llmcompressor>=0.4.0", "pre-commit>=4.5.1", "pytest>=8.4.2", "ruff>=0.11.0", From 5e8efc9f832c13ad29f66a436cfdc4cb7041d8ab Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 21 Apr 2026 21:52:33 +0200 Subject: [PATCH 19/28] Update baseline handling for Arena-Hard datasets - Refactor baseline assignment for Arena-Hard datasets to support different baselines based on category same as original benchmark. --- README.md | 8 +- judgearena/generate_and_evaluate.py | 229 ++++++++++++---- judgearena/instruction_dataset/arena_hard.py | 248 +++++++++++------- tests/test_chat_vllm.py | 4 + .../test_generate_and_evaluate_arena_hard.py | 94 +++++++ tests/test_instruction_dataset.py | 230 +++++++++++++++- 6 files changed, 658 insertions(+), 155 deletions(-) create mode 100644 tests/test_generate_and_evaluate_arena_hard.py diff --git a/README.md b/README.md index 508ac9f..f5050e9 100644 --- a/README.md +++ b/README.md @@ -203,9 +203,11 @@ This override applies to all vLLM models in the run. For remote providers (OpenA | `m-arena-hard-EU` | All EU languages combined | | `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | -For Arena-Hard, JudgeArena resolves baseline metadata by dataset version: -- `arena-hard-v0.1`: `gpt-4-0314` -- `arena-hard-v2.0`: `o3-mini-2025-01-31` (standard prompts) +For Arena-Hard, JudgeArena mirrors the baseline assignment upstream uses in `lmarena-ai/arena-hard-auto`: +- `arena-hard-v0.1`: flat baseline `gpt-4-0314` for all 500 prompts. +- `arena-hard-v2.0`: per-question baseline routed by `category`: + - `o3-mini-2025-01-31` for `hard_prompt`, `coding`, and `math` (500 prompts). + - `gemini-2.0-flash-001` for `creative_writing` (250 prompts). ## 📈 Estimating ELO Ratings diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 3ddc577..6c2b9e3 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -6,6 +6,7 @@ import argparse import hashlib import json +from collections.abc import Mapping from dataclasses import asdict, dataclass from datetime import UTC, datetime from pathlib import Path @@ -22,6 +23,7 @@ from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions from judgearena.instruction_dataset.arena_hard import ( + arena_hard_native_baseline, download_arena_hard, is_arena_hard_dataset, ) @@ -113,8 +115,12 @@ def parse_args(cls): ) parser.add_argument( "--model_B", - required=True, - help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`", + default=None, + help=( + "Name of the baseline LLM for a generation. Optional for Arena-Hard " + "datasets (which ship a dataset-native default per category; see " + "`ARENA_HARD_BASELINES`). Required for every other dataset." + ), ) parser.add_argument( "--use_tqdm", @@ -153,6 +159,89 @@ def parse_args(cls): ) +@dataclass(frozen=True) +class BaselinePlan: + """Row-aligned baseline assignment for `--model_B`. + + Mirrors upstream's `JUDGE_SETTINGS[question["category"]]["baseline"]` lookup + in `arena-hard-auto/gen_judgment.py`: a flat plan assigns one baseline to + every row, a per-row plan assigns a different baseline per category (v2.0 + mixes `o3-mini-2025-01-31` on hard prompts with `gemini-2.0-flash-001` on + creative writing). + """ + + baseline_by_index: pd.Series + + @classmethod + def flat(cls, model: str, *, index: pd.Index) -> "BaselinePlan": + return cls( + baseline_by_index=pd.Series(model, index=index, name="model_B", dtype=str) + ) + + @classmethod + def per_row(cls, series: pd.Series) -> "BaselinePlan": + return cls(baseline_by_index=series.astype(str).rename("model_B")) + + @property + def unique_models(self) -> list[str]: + return sorted(self.baseline_by_index.dropna().unique().tolist()) + + @property + def is_flat(self) -> bool: + return len(self.unique_models) == 1 + + @property + def single_model(self) -> str: + if not self.is_flat: + raise ValueError( + "BaselinePlan is per-row; use baseline_by_index for row-level lookups" + ) + return self.unique_models[0] + + @property + def display_name(self) -> str: + return self.single_model if self.is_flat else "+".join(self.unique_models) + + def aligned_to(self, index: pd.Index) -> pd.Series: + return self.baseline_by_index.loc[index] + + +def _resolve_baseline_plan( + args: CliArgs, instructions_df: pd.DataFrame +) -> BaselinePlan: + """Explicit `--model_B` wins; otherwise fall back to the dataset-native + assignment. Non-arena-hard datasets without an override raise. + """ + if args.model_B is not None: + return BaselinePlan.flat(args.model_B, index=instructions_df.index) + if not is_arena_hard_dataset(args.dataset): + raise ValueError( + f"--model_B is required for dataset '{args.dataset}'; only Arena-Hard " + "datasets ship a dataset-native baseline." + ) + native = arena_hard_native_baseline(args.dataset) + if isinstance(native, str): + return BaselinePlan.flat(native, index=instructions_df.index) + if isinstance(native, Mapping): + if "category" not in instructions_df.columns: + raise ValueError( + f"{args.dataset} requires a 'category' column for per-category " + "baseline routing; re-run dataset download to regenerate the " + "instructions table." + ) + per_row = instructions_df["category"].map(native) + if per_row.isna().any(): + unknown = sorted( + instructions_df.loc[per_row.isna(), "category"].unique().tolist() + ) + raise ValueError( + f"Unknown Arena-Hard categories for {args.dataset}: {unknown}. " + f"Known: {sorted(native.keys())}" + ) + return BaselinePlan.per_row(per_row) + raise ValueError(f"Unsupported baseline shape for dataset '{args.dataset}'.") + + def load_contexts(dataset: str) -> pd.Series: path = data_root / "contexts" / dataset return pd.read_csv(path).loc[:, "instruction"] @@ -241,9 +330,6 @@ def main(args: CliArgs): run_started_at = datetime.now(UTC) usage_tracker = OpenRouterReferencePricingTracker() limit_event_tracker = LimitEventTracker() - print( - f"Using dataset {args.dataset} and evaluating models {args.model_A} and {args.model_B}." - ) # Not working with vllm, not detecting model changes and serving the same cache for two different models... # if not args.ignore_cache: @@ -260,18 +346,29 @@ def main(args: CliArgs): # to match files in https://huggingface.co/datasets/geoalgo/multilingual-contexts-to-be-completed lang = args.dataset.split("-")[-1] instructions = load_contexts(f"{lang}-contexts.csv") + instructions_df = pd.DataFrame({"instruction": instructions.values}) + instructions_df.index = instructions.index else: - instructions = load_instructions( + instructions_df = load_instructions( dataset=args.dataset, n_instructions=args.n_instructions - ).loc[:, "instruction"] + ) + instructions = instructions_df["instruction"] n_instructions = args.n_instructions if args.n_instructions else len(instructions) if args.n_instructions is not None: - instructions = instructions[:n_instructions] + instructions_df = instructions_df.head(n_instructions) + instructions = instructions.head(n_instructions) + baseline_plan = _resolve_baseline_plan(args=args, instructions_df=instructions_df) + + print( + f"Using dataset {args.dataset} and evaluating {args.model_A} vs baseline " + f"{baseline_plan.display_name}." + ) print( - f"Generating completions for dataset {args.dataset} with model {args.model_A} and " - f"{args.model_B} (or loading them directly if present)" + f"Generating completions for dataset {args.dataset} with model {args.model_A} " + f"and baseline {baseline_plan.display_name} " + "(or loading them directly if present)" ) generation_function = generate_base if is_fluency_task else generate_instructions @@ -294,44 +391,54 @@ def _run_generation(model_spec: str, usage_phase: str) -> pd.DataFrame: def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: return df.set_index("instruction_index").loc[instructions.index].reset_index() - dataset_completions_A = try_load_dataset_completions( - args.dataset, args.model_A, n_instructions - ) - if dataset_completions_A is not None: - completions_A_df = _align_completion_dataframe(dataset_completions_A) - else: - completions_A_df = _align_completion_dataframe( - cache_function_dataframe( - lambda: _run_generation(args.model_A, "generation_model_A"), - ignore_cache=ignore_cache, - cache_name=_generation_cache_name(args, model_spec=args.model_A), - ) + def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Series: + preloaded = try_load_dataset_completions( + args.dataset, model_spec, n_instructions ) - completions_A = completions_A_df.set_index("instruction_index").loc[ - instructions.index, "completion" - ] + if preloaded is not None: + aligned = _align_completion_dataframe(preloaded) + else: + aligned = _align_completion_dataframe( + cache_function_dataframe( + lambda: _run_generation(model_spec, usage_phase), + ignore_cache=ignore_cache, + cache_name=_generation_cache_name(args, model_spec=model_spec), + ) + ) + return aligned.set_index("instruction_index").loc[ + instructions.index, "completion" + ] - dataset_completions_B = try_load_dataset_completions( - args.dataset, args.model_B, n_instructions - ) - if dataset_completions_B is not None: - completions_B_df = _align_completion_dataframe(dataset_completions_B) + completions_A = _load_or_generate_completions(args.model_A, "generation_model_A") + + baseline_per_index = baseline_plan.aligned_to(instructions.index) + if baseline_plan.is_flat: + completions_B = _load_or_generate_completions( + baseline_plan.single_model, "generation_model_B" + ) else: - completions_B_df = _align_completion_dataframe( - cache_function_dataframe( - lambda: _run_generation(args.model_B, "generation_model_B"), - ignore_cache=ignore_cache, - cache_name=_generation_cache_name(args, model_spec=args.model_B), + # Per-row plan: fetch one completion set per unique baseline, then stitch + # them together so completions_B[uid] uses the baseline that + # ARENA_HARD_BASELINES routes uid's category to. + per_baseline_completions: dict[str, pd.Series] = {} + for baseline_model in baseline_plan.unique_models: + per_baseline_completions[baseline_model] = _load_or_generate_completions( + baseline_model, f"generation_model_B::{baseline_model}" ) + completions_B = pd.Series( + [ + per_baseline_completions[model].loc[uid] + for uid, model in baseline_per_index.items() + ], + index=instructions.index, + name="completion", ) - completions_B = completions_B_df.set_index("instruction_index").loc[ - instructions.index, "completion" - ] + print(f"\nFirst instruction/context: {instructions.values[0]}") print(f"\nFirst completion of {args.model_A}") print(completions_A.values[0]) - print(f"\nFirst completion of {args.model_B}") + print(f"\nFirst completion of {baseline_plan.display_name}") print(completions_B.values[0]) print(f"Evaluating completions with judge {args.judge_model}.") @@ -343,14 +450,15 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) - name = f"{args.dataset}-{args.model_A}-{args.model_B}-{args.judge_model}" + name = ( + f"{args.dataset}-{args.model_A}-{baseline_plan.display_name}-{args.judge_model}" + ) name += f"-{args.swap_mode}" name = name.replace("/", "_") res_folder = Path(args.result_folder) / name res_folder.mkdir(parents=True, exist_ok=True) - # save argument for results analysis with open(res_folder / f"args-{name}.json", "w") as f: json.dump(asdict(args), f, indent=2) @@ -391,31 +499,33 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: limit_event_tracker=limit_event_tracker, ) + eval_instruction_index = instructions.head(n_instructions).index.tolist() + baseline_per_eval = baseline_per_index.loc[eval_instruction_index] + df = pd.DataFrame(annotations) - df["instruction_index"] = instructions.head(n_instructions).index.tolist() + df["instruction_index"] = eval_instruction_index df["model_A"] = args.model_A - df["model_B"] = args.model_B + df["model_B"] = baseline_per_eval.tolist() df["judge"] = args.judge_model if args.swap_mode == "both": df_reversed = pd.DataFrame(annotations_reversed) - df_reversed["instruction_index"] = instructions.head( - n_instructions - ).index.tolist() - df_reversed["model_A"] = args.model_B + df_reversed["instruction_index"] = eval_instruction_index + df_reversed["model_A"] = baseline_per_eval.tolist() df_reversed["model_B"] = args.model_A df_reversed["judge"] = args.judge_model df = pd.concat([df, df_reversed]) df.to_csv(res_folder / f"{name}-annotations.csv", index=False) - # compute and report statistics summary = compute_pref_summary(prefs) results = { "dataset": args.dataset, "model_A": args.model_A, - "model_B": args.model_B, + "model_B": baseline_plan.display_name, + "baseline_assignment": "per-row" if not baseline_plan.is_flat else "flat", + "baseline_models": baseline_plan.unique_models, "judge_model": args.judge_model, "judge_prompt_preset": resolved_prompt.preset_name, "strip_thinking_before_judging": args.strip_thinking_before_judging, @@ -424,22 +534,28 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: "limit_events": limit_event_tracker.build_summary(), "preferences": prefs.tolist(), } - print(f"{args.model_A} vs {args.model_B} judged by {args.judge_model}") + print( + f"{args.model_A} vs {baseline_plan.display_name} judged by {args.judge_model}" + ) print_results(results) + phase_model_specs: dict[str, str] = { + "generation_model_A": args.model_A, + "judge": args.judge_model, + } + if baseline_plan.is_flat: + phase_model_specs["generation_model_B"] = baseline_plan.single_model + else: + for baseline_model in baseline_plan.unique_models: + phase_model_specs[f"generation_model_B::{baseline_model}"] = baseline_model pricing_reference = build_openrouter_reference_pricing_summary( tracker=usage_tracker, - phase_model_specs={ - "generation_model_A": args.model_A, - "generation_model_B": args.model_B, - "judge": args.judge_model, - }, + phase_model_specs=phase_model_specs, ) print(format_openrouter_reference_pricing_summary(pricing_reference)) with open(res_folder / f"results-{name}.json", "w") as f: json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) - eval_instruction_index = instructions.head(n_instructions).index.tolist() eval_instructions = instructions.head(n_instructions).tolist() eval_completions_A = completions_A.head(n_instructions).tolist() eval_completions_B = completions_B.head(n_instructions).tolist() @@ -455,6 +571,7 @@ def _align_completion_dataframe(df: pd.DataFrame) -> pd.DataFrame: "instructions": eval_instructions, "completions_A": eval_completions_A, "completions_B": eval_completions_B, + "baseline_model_B": baseline_per_eval.tolist(), }, judge_system_prompt=resolved_prompt.system_prompt, judge_user_prompt_template=resolved_prompt.user_prompt_template, diff --git a/judgearena/instruction_dataset/arena_hard.py b/judgearena/instruction_dataset/arena_hard.py index 414d2f8..4dc70a4 100644 --- a/judgearena/instruction_dataset/arena_hard.py +++ b/judgearena/instruction_dataset/arena_hard.py @@ -1,104 +1,96 @@ -from dataclasses import dataclass +from collections.abc import Mapping from pathlib import Path +from typing import Any import pandas as pd -from datasets import Dataset, DatasetDict, IterableDataset, load_dataset +from huggingface_hub import snapshot_download ARENA_HARD_HF_REPO_ID = "lmarena-ai/arena-hard-auto" - -@dataclass(frozen=True) -class ArenaHardSpec: - hf_variant: str - baseline_model: str - - -ARENA_HARD_DATASETS: dict[str, ArenaHardSpec] = { - "arena-hard-v0.1": ArenaHardSpec( - hf_variant="arena-hard-v0.1", - baseline_model="gpt-4-0314", - ), - "arena-hard-v2.0": ArenaHardSpec( - hf_variant="arena-hard-v2.0", - baseline_model="o3-mini-2025-01-31", - ), +# Mirrors upstream's `JUDGE_SETTINGS` baseline assignment in +# `arena-hard-auto/utils/judge_utils.py`: v0.1 has a single flat baseline, +# v2.0 routes per question category. `is_arena_hard_dataset` and the +# dispatcher in `generate_and_evaluate.py` key off this map. +ARENA_HARD_BASELINES: dict[str, str | Mapping[str, str]] = { + "arena-hard-v0.1": "gpt-4-0314", + "arena-hard-v2.0": { + "hard_prompt": "o3-mini-2025-01-31", + "coding": "o3-mini-2025-01-31", + "math": "o3-mini-2025-01-31", + "creative_writing": "gemini-2.0-flash-001", + }, } - -def resolve_arena_hard_spec(dataset: str) -> ArenaHardSpec | None: - return ARENA_HARD_DATASETS.get(dataset) +# Dataset name -> upstream HF `data//` directory. Kept private so the +# public API of this module is just the baseline map and helpers below. +_ARENA_HARD_HF_VARIANTS: dict[str, str] = { + "arena-hard-v0.1": "arena-hard-v0.1", + "arena-hard-v2.0": "arena-hard-v2.0", +} def is_arena_hard_dataset(dataset: str) -> bool: - return resolve_arena_hard_spec(dataset) is not None + return dataset in ARENA_HARD_BASELINES -def arena_hard_baseline_model(dataset: str) -> str | None: - spec = resolve_arena_hard_spec(dataset) - if spec is None: - return None - return spec.baseline_model +def arena_hard_native_baseline( + dataset: str, +) -> str | Mapping[str, str] | None: + """Dataset-native baseline assignment. - -def _load_official_arena_hard_dataset(spec: ArenaHardSpec) -> pd.DataFrame: - data = load_dataset( - path=ARENA_HARD_HF_REPO_ID, - data_dir=f"data/{spec.hf_variant}", - ) - return _dataset_like_to_dataframe(data) - - -def _dataset_like_to_dataframe( - data: Dataset | DatasetDict | IterableDataset, -) -> pd.DataFrame: - if isinstance(data, DatasetDict): - if "train" in data: - return data["train"].to_pandas() - first_split = next(iter(data.keys())) - return data[first_split].to_pandas() - if isinstance(data, Dataset): - return data.to_pandas() - if isinstance(data, IterableDataset): - return pd.DataFrame(list(data)) - raise TypeError(f"Unsupported dataset object type: {type(data)}") + Returns a plain string for flat datasets (v0.1), a `{category: model}` + mapping for per-category datasets (v2.0), or `None` for datasets that + don't ship a native baseline. + """ + return ARENA_HARD_BASELINES.get(dataset) def normalize_official_arena_hard( raw_df: pd.DataFrame, dataset: str ) -> tuple[pd.DataFrame, pd.DataFrame | None]: - spec = resolve_arena_hard_spec(dataset) - if spec is None: + if dataset not in _ARENA_HARD_HF_VARIANTS: raise ValueError(f"Unsupported Arena-Hard dataset: {dataset}") - - instruction_index = _pick_instruction_index(raw_df) - instruction = _pick_instruction(raw_df) - df_instructions = pd.DataFrame( - { - "instruction_index": instruction_index, - "instruction": instruction, - } - ) - df_instructions = df_instructions.dropna( - subset=["instruction_index", "instruction"] - ) - df_instructions = df_instructions.drop_duplicates(subset=["instruction_index"]) - df_instructions = df_instructions.sort_values("instruction_index") - + df_instructions = _build_instructions(raw_df) df_model_outputs = _build_model_outputs(raw_df) return df_instructions, df_model_outputs def download_arena_hard(dataset: str, local_tables_path: Path) -> None: - """Load Arena-Hard from the Hub if instruction and model-output files are missing.""" - spec = resolve_arena_hard_spec(dataset) - if spec is None: + """Populate `{dataset}.csv` and `{dataset}.csv.zip` on disk if missing. + + Pulls the raw jsonl files directly via `snapshot_download` and reads them + with pandas: upstream's per-row `messages[].content` oscillates between + string and dict across answer files, so `datasets.load_dataset` can't + materialize them into a single Arrow schema. + + Re-downloads when the instructions table is stale - currently only v2.0 + detects this, because routing by category requires the `category` column + that older caches were written without. + """ + if dataset not in _ARENA_HARD_HF_VARIANTS: return instructions_path = local_tables_path / "instructions" / f"{dataset}.csv" model_outputs_path = local_tables_path / "model_outputs" / f"{dataset}.csv.zip" - if instructions_path.exists() and model_outputs_path.exists(): + if ( + instructions_path.exists() + and model_outputs_path.exists() + and _instructions_cache_is_fresh(instructions_path, dataset) + ): return - raw_df = _load_official_arena_hard_dataset(spec) + variant = _ARENA_HARD_HF_VARIANTS[dataset] + snapshot_root = snapshot_download( + repo_id=ARENA_HARD_HF_REPO_ID, + repo_type="dataset", + allow_patterns=[ + f"data/{variant}/question.jsonl", + f"data/{variant}/model_answer/*.jsonl", + ], + force_download=False, + ) + raw_df = _read_arena_hard_jsonl_frames( + variant_dir=Path(snapshot_root) / "data" / variant + ) df_instructions, df_model_outputs = normalize_official_arena_hard( raw_df=raw_df, dataset=dataset ) @@ -109,11 +101,100 @@ def download_arena_hard(dataset: str, local_tables_path: Path) -> None: df_model_outputs.to_csv(model_outputs_path, index=False) +def _instructions_cache_is_fresh(instructions_path: Path, dataset: str) -> bool: + """Category-aware datasets need a `category` column; older caches lack it.""" + native = arena_hard_native_baseline(dataset) + if not isinstance(native, Mapping): + return True + cached_columns = pd.read_csv(instructions_path, nrows=0).columns + return "category" in cached_columns + + +def _read_arena_hard_jsonl_frames(variant_dir: Path) -> pd.DataFrame: + frames: list[pd.DataFrame] = [] + question_path = variant_dir / "question.jsonl" + if question_path.exists(): + frames.append(pd.read_json(question_path, lines=True)) + answer_dir = variant_dir / "model_answer" + if answer_dir.exists(): + for jsonl_path in sorted(answer_dir.glob("*.jsonl")): + frames.append(pd.read_json(jsonl_path, lines=True)) + if not frames: + raise FileNotFoundError(f"No Arena-Hard jsonl files found under {variant_dir}") + return pd.concat(frames, ignore_index=True, sort=False) + + +def _build_instructions(raw_df: pd.DataFrame) -> pd.DataFrame: + # Question rows are the ones with a prompt; model-answer rows don't have + # one and must not leak into the instructions table. + if "prompt" in raw_df.columns: + question_rows = raw_df[raw_df["prompt"].notna()].reset_index(drop=True) + else: + question_rows = raw_df.reset_index(drop=True) + + if len(question_rows) == 0: + return pd.DataFrame(columns=["instruction_index", "instruction"]) + + columns: dict[str, pd.Series] = { + "instruction_index": _pick_instruction_index(question_rows), + "instruction": _pick_instruction(question_rows), + } + if "category" in question_rows.columns: + columns["category"] = question_rows["category"] + df = pd.DataFrame(columns) + df = df.dropna(subset=["instruction_index", "instruction"]) + df["instruction"] = df["instruction"].astype(str) + df = df.drop_duplicates(subset=["instruction_index"]) + df = df.sort_values("instruction_index").reset_index(drop=True) + return df + + +def _build_model_outputs(raw_df: pd.DataFrame) -> pd.DataFrame | None: + if "model" not in raw_df.columns: + return None + extracted_output = raw_df.apply(_extract_assistant_output, axis=1) + instruction_index = _pick_instruction_index(raw_df) + df = pd.DataFrame( + { + "instruction_index": instruction_index, + "model": raw_df["model"], + "output": extracted_output, + } + ) + df = df[df["model"].notna() & df["output"].notna()] + df = df.dropna(subset=["instruction_index"]) + if df.empty: + return None + df["instruction_index"] = df["instruction_index"].astype(str) + df["model"] = df["model"].astype(str) + df["output"] = df["output"].astype(str) + return df.reset_index(drop=True) + + +def _extract_assistant_output(row: pd.Series) -> str | None: + """Pull the assistant response out of either a flat `output` column or + upstream's nested `messages[-1].content.answer` shape. + """ + output_value = row.get("output") + if isinstance(output_value, str) and output_value: + return output_value + messages = row.get("messages") + if isinstance(messages, list) and messages: + last = messages[-1] + content = last.get("content") if isinstance(last, dict) else None + if isinstance(content, dict): + answer = content.get("answer") + return answer if isinstance(answer, str) and answer else None + if isinstance(content, str) and content: + return content + return None + + def _pick_instruction_index(raw_df: pd.DataFrame) -> pd.Series: - for col in ["instruction_index", "question_id", "id"]: + for col in ["instruction_index", "uid", "question_id", "id"]: if col in raw_df.columns: return raw_df[col].astype(str) - return pd.Series(range(len(raw_df)), dtype=str) + return pd.Series(range(len(raw_df)), dtype=str, index=raw_df.index) def _pick_instruction(raw_df: pd.DataFrame) -> pd.Series: @@ -121,13 +202,14 @@ def _pick_instruction(raw_df: pd.DataFrame) -> pd.Series: if col in raw_df.columns: if col == "turns": return raw_df[col].apply(_turns_to_text) - return raw_df[col].astype(str) + return raw_df[col] raise ValueError( - f"Unable to infer instruction text column from Arena-Hard data. Available columns: {raw_df.columns.tolist()}" + "Unable to infer instruction text column from Arena-Hard data. " + f"Available columns: {raw_df.columns.tolist()}" ) -def _turns_to_text(turns_value) -> str: +def _turns_to_text(turns_value: Any) -> str: if isinstance(turns_value, list): if not turns_value: return "" @@ -142,17 +224,3 @@ def _turns_to_text(turns_value) -> str: if key in turns_value: return str(turns_value[key]) return str(turns_value) - - -def _build_model_outputs(raw_df: pd.DataFrame) -> pd.DataFrame | None: - if not {"model", "output"}.issubset(raw_df.columns): - return None - instruction_index = _pick_instruction_index(raw_df) - df_outputs = pd.DataFrame( - { - "instruction_index": instruction_index, - "model": raw_df["model"].astype(str), - "output": raw_df["output"].fillna("").astype(str), - } - ) - return df_outputs diff --git a/tests/test_chat_vllm.py b/tests/test_chat_vllm.py index a0f04b5..0cf34b1 100644 --- a/tests/test_chat_vllm.py +++ b/tests/test_chat_vllm.py @@ -186,6 +186,10 @@ def test_build_default_judge_model_kwargs_overlays_judge_override(): judge_engine_kwargs_override={"tensor_parallel_size": 4}, ) assert overridden["tensor_parallel_size"] == 4 + # FP8 weights + FP8 KV cache are a name-driven invariant; the TP override + # must not silently drop `kv_cache_dtype=fp8` because we run the Skywork + # 70B FP8 judge on TP=2 and TP=4 interchangeably depending on the cell. + assert overridden["kv_cache_dtype"] == "fp8" empty_override = utils.build_default_judge_model_kwargs( "VLLM/Skywork/Skywork-Critic-Llama-3.1-70B-FP8", diff --git a/tests/test_generate_and_evaluate_arena_hard.py b/tests/test_generate_and_evaluate_arena_hard.py new file mode 100644 index 0000000..1dd51ff --- /dev/null +++ b/tests/test_generate_and_evaluate_arena_hard.py @@ -0,0 +1,94 @@ +import pandas as pd +import pytest + +from judgearena.generate_and_evaluate import ( + BaselinePlan, + CliArgs, + _resolve_baseline_plan, +) + + +def _make_args(dataset, model_b=None): + return CliArgs( + dataset=dataset, + model_A="A", + model_B=model_b, + judge_model="J", + ) + + +def _instructions(ids, categories=None): + data = {"instruction": list(ids)} + if categories is not None: + data["category"] = list(categories) + return pd.DataFrame(data, index=pd.Index(ids, name="instruction_index")) + + +def test_resolve_plan_v01_flat_default(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v0.1"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "gpt-4-0314" + + +def test_resolve_plan_v20_routes_per_category(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions( + ["qh", "qc"], + categories=["hard_prompt", "creative_writing"], + ), + ) + assert not plan.is_flat + assert plan.baseline_by_index.loc["qh"] == "o3-mini-2025-01-31" + assert plan.baseline_by_index.loc["qc"] == "gemini-2.0-flash-001" + + +def test_resolve_plan_explicit_model_b_overrides_native(): + plan = _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0", model_b="override"), + instructions_df=_instructions( + ["q1", "q2"], categories=["hard_prompt", "creative_writing"] + ), + ) + assert plan.is_flat + assert plan.single_model == "override" + + +def test_resolve_plan_non_arena_hard_requires_model_b(): + with pytest.raises(ValueError, match="model_B"): + _resolve_baseline_plan( + args=_make_args("alpaca-eval"), + instructions_df=_instructions(["q1"]), + ) + + +def test_resolve_plan_v20_missing_category_raises(): + with pytest.raises(ValueError, match="category"): + _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions(["q1"]), + ) + + +def test_resolve_plan_v20_unknown_category_raises(): + with pytest.raises(ValueError, match="brand_new"): + _resolve_baseline_plan( + args=_make_args("arena-hard-v2.0"), + instructions_df=_instructions(["q1"], categories=["brand_new"]), + ) + + +def test_baseline_plan_flat_repeats_model(): + plan = BaselinePlan.flat("b", index=pd.Index(["a", "b"])) + assert plan.is_flat + assert plan.baseline_by_index.tolist() == ["b", "b"] + + +def test_baseline_plan_per_row_preserves_order(): + series = pd.Series(["m1", "m2"], index=["a", "b"], name="model_B") + plan = BaselinePlan.per_row(series) + assert not plan.is_flat + assert plan.unique_models == ["m1", "m2"] diff --git a/tests/test_instruction_dataset.py b/tests/test_instruction_dataset.py index 4daa144..93334f3 100644 --- a/tests/test_instruction_dataset.py +++ b/tests/test_instruction_dataset.py @@ -1,21 +1,51 @@ from pathlib import Path import pandas as pd +import pytest import judgearena.generate_and_evaluate as generate_and_evaluate import judgearena.instruction_dataset as instruction_dataset +import judgearena.utils as judgearena_utils from judgearena.instruction_dataset.arena_hard import ( - arena_hard_baseline_model, + ARENA_HARD_BASELINES, + _build_instructions, + _build_model_outputs, + _extract_assistant_output, + arena_hard_native_baseline, normalize_official_arena_hard, ) -def test_arena_hard_baseline_resolution(): - assert arena_hard_baseline_model("arena-hard-v0.1") == "gpt-4-0314" - assert arena_hard_baseline_model("arena-hard-v2.0") == "o3-mini-2025-01-31" +def test_arena_hard_native_baseline_v01_is_flat_string(): + assert arena_hard_native_baseline("arena-hard-v0.1") == "gpt-4-0314" -def test_normalize_official_arena_hard_v01_shape(): +def test_arena_hard_native_baseline_v20_is_per_category_mapping(): + native = arena_hard_native_baseline("arena-hard-v2.0") + assert isinstance(native, dict) + assert native["hard_prompt"] == "o3-mini-2025-01-31" + assert native["coding"] == "o3-mini-2025-01-31" + assert native["math"] == "o3-mini-2025-01-31" + assert native["creative_writing"] == "gemini-2.0-flash-001" + + +def test_arena_hard_baselines_mapping_matches_upstream(): + """Pin the exact baseline assignment so a silent edit to + ARENA_HARD_BASELINES can't drift away from upstream + (arena-hard-auto/utils/judge_utils.py::JUDGE_SETTINGS). + """ + assert ARENA_HARD_BASELINES == { + "arena-hard-v0.1": "gpt-4-0314", + "arena-hard-v2.0": { + "hard_prompt": "o3-mini-2025-01-31", + "coding": "o3-mini-2025-01-31", + "math": "o3-mini-2025-01-31", + "creative_writing": "gemini-2.0-flash-001", + }, + } + + +def test_normalize_official_arena_hard_v01_drops_no_category(): raw_df = pd.DataFrame( { "question_id": ["q1", "q2"], @@ -35,6 +65,165 @@ def test_normalize_official_arena_hard_v01_shape(): assert set(df_outputs.columns) == {"instruction_index", "model", "output"} +def test_normalize_official_arena_hard_v20_preserves_category(): + raw_df = pd.DataFrame( + { + "question_id": ["q1", "q2", "q1"], + "prompt": ["First prompt", "Second prompt", None], + "category": ["hard_prompt", "creative_writing", None], + "model": [None, None, "o3-mini-2025-01-31"], + "output": [None, None, "answer text"], + } + ) + df_instructions, df_outputs = normalize_official_arena_hard( + raw_df=raw_df, dataset="arena-hard-v2.0" + ) + + assert "category" in df_instructions.columns + assert df_instructions.set_index("instruction_index")["category"].to_dict() == { + "q1": "hard_prompt", + "q2": "creative_writing", + } + assert df_outputs is not None + assert df_outputs["model"].tolist() == ["o3-mini-2025-01-31"] + assert df_outputs["output"].tolist() == ["answer text"] + + +def test_build_model_outputs_extracts_upstream_messages_shape(): + """Upstream's `model_answer/*.jsonl` rows keep the assistant response in + `messages[-1].content.answer` rather than a flat `output` column. Without + this extractor, a fresh `download_arena_hard` clone would silently drop + every baseline answer. + """ + raw_df = pd.DataFrame( + [ + { + "uid": "q1", + "model": "o3-mini-2025-01-31", + "messages": [ + {"role": "user", "content": "Prompt"}, + { + "role": "assistant", + "content": {"answer": "nested answer", "reasoning": "..."}, + }, + ], + }, + { + "uid": "q2", + "model": "gemini-2.0-flash-001", + "messages": [ + {"role": "user", "content": "Prompt"}, + {"role": "assistant", "content": "plain string answer"}, + ], + }, + { + "uid": "q3", + "model": "baseline", + "output": "flat output column", + }, + { + "uid": "q4", + "model": "no-output-model", + "messages": [{"role": "assistant", "content": {"reasoning": "..."}}], + }, + ] + ) + + df_outputs = _build_model_outputs(raw_df) + + assert df_outputs is not None + outputs_by_model = dict(zip(df_outputs["model"], df_outputs["output"], strict=True)) + assert outputs_by_model == { + "o3-mini-2025-01-31": "nested answer", + "gemini-2.0-flash-001": "plain string answer", + "baseline": "flat output column", + } + assert "no-output-model" not in outputs_by_model + + +@pytest.mark.parametrize( + "row, expected", + [ + ({"output": "flat"}, "flat"), + ( + { + "messages": [ + {"role": "user", "content": "p"}, + {"role": "assistant", "content": {"answer": "nested"}}, + ] + }, + "nested", + ), + ( + { + "messages": [ + {"role": "user", "content": "p"}, + {"role": "assistant", "content": "plain"}, + ] + }, + "plain", + ), + ({"output": None, "messages": None}, None), + ( + {"messages": [{"role": "assistant", "content": {"reasoning": "only"}}]}, + None, + ), + ], +) +def test_extract_assistant_output_covers_known_shapes(row, expected): + assert _extract_assistant_output(pd.Series(row)) == expected + + +def test_build_model_outputs_returns_multi_model_rows_per_upstream_zip(): + """The fresh-clone loader must produce one row per (model, uid) so the + flat zip consumed by `try_load_dataset_completions` pivots cleanly. + """ + raw_df = pd.DataFrame( + [ + { + "uid": "q1", + "model": "o3-mini-2025-01-31", + "messages": [{"role": "assistant", "content": {"answer": "o3 q1"}}], + }, + { + "uid": "q2", + "model": "o3-mini-2025-01-31", + "messages": [{"role": "assistant", "content": {"answer": "o3 q2"}}], + }, + { + "uid": "q1", + "model": "gemini-2.0-flash-001", + "messages": [{"role": "assistant", "content": {"answer": "gemini q1"}}], + }, + ] + ) + + df_outputs = _build_model_outputs(raw_df) + + assert df_outputs is not None + assert sorted(df_outputs["model"].unique().tolist()) == [ + "gemini-2.0-flash-001", + "o3-mini-2025-01-31", + ] + assert df_outputs.shape[0] == 3 + + +def test_build_instructions_drops_model_answer_rows(): + """Question rows and model-answer rows share a dataframe on fresh clone; + `_build_instructions` has to keep only the prompt rows so the instruction + table doesn't leak rows with no prompt text. + """ + raw_df = pd.DataFrame( + [ + {"uid": "q1", "prompt": "real prompt", "category": "hard_prompt"}, + {"uid": "q1", "model": "baseline", "output": "answer"}, + ] + ) + df = _build_instructions(raw_df) + assert df["instruction_index"].tolist() == ["q1"] + assert df["instruction"].tolist() == ["real prompt"] + + def test_load_instructions_uses_explicit_version_filename(monkeypatch): captured = {} @@ -52,7 +241,7 @@ def _fake_read_df(path: Path): ) monkeypatch.setattr(instruction_dataset, "download_arena_hard", _fake_ensure) - monkeypatch.setattr(instruction_dataset, "read_df", _fake_read_df) + monkeypatch.setattr(judgearena_utils, "read_df", _fake_read_df) df = instruction_dataset.load_instructions(dataset="arena-hard-v2.0") assert captured["dataset"] == "arena-hard-v2.0" @@ -60,6 +249,35 @@ def _fake_read_df(path: Path): assert df.index.tolist() == ["0", "1"] +def test_load_instructions_surfaces_category_for_v20(monkeypatch): + """The per-category baseline plan in `generate_and_evaluate` keys off + the `category` column, so `load_instructions` must keep it round-tripping + from the cached CSV. + """ + monkeypatch.setattr( + instruction_dataset, + "download_arena_hard", + lambda dataset, local_tables_path: None, + ) + monkeypatch.setattr( + judgearena_utils, + "read_df", + lambda path: pd.DataFrame( + { + "instruction_index": ["q1", "q2"], + "instruction": ["a", "b"], + "category": ["hard_prompt", "creative_writing"], + } + ), + ) + + df = instruction_dataset.load_instructions(dataset="arena-hard-v2.0") + + assert "category" in df.columns + assert df.loc["q1", "category"] == "hard_prompt" + assert df.loc["q2", "category"] == "creative_writing" + + def test_try_load_dataset_completions_uses_dataset_output_file(monkeypatch, tmp_path): tables_dir = tmp_path / "tables" / "model_outputs" tables_dir.mkdir(parents=True, exist_ok=True) From 2af471407dc4d85b378b6112ab2a2f5a5b12bb8e Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 21 Apr 2026 23:18:59 +0200 Subject: [PATCH 20/28] Add m-arenahard-v2.0 --- README.md | 27 ++-- judgearena/evaluate.py | 5 +- judgearena/generate_and_evaluate.py | 5 +- judgearena/instruction_dataset/__init__.py | 49 ++---- judgearena/instruction_dataset/m_arenahard.py | 146 ++++++++++++++---- judgearena/utils.py | 9 +- tests/test_generate_and_evaluate.py | 3 +- tests/test_mt_bench_downloads.py | 3 +- tests/test_utils.py | 7 +- 9 files changed, 170 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index f5050e9..fbf5ff3 100644 --- a/README.md +++ b/README.md @@ -193,15 +193,24 @@ This override applies to all vLLM models in the run. For remote providers (OpenA ## 📊 Supported Datasets -| Dataset | Description | -|-----------------------|------------------------------------------------------------------------------------------------| -| `alpaca-eval` | General instruction-following benchmark | -| `arena-hard-v2.0` | Arena-Hard v2.0 from official `lmarena-ai/arena-hard-auto` source | -| `arena-hard-v0.1` | Legacy Arena-Hard v0.1 from official `lmarena-ai/arena-hard-auto` source | -| `m-arena-hard` | Translated version of Arena-Hard in 23 languages | -| `m-arena-hard-{lang}` | Language-specific variants (e.g., `ar`, `cs`, `de`) | -| `m-arena-hard-EU` | All EU languages combined | -| `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | +| Dataset | Description | +|-----------------------------|----------------------------------------------------------------------------------------------------------------| +| `alpaca-eval` | General instruction-following benchmark | +| `arena-hard-v2.0` | Arena-Hard v2.0 from official `lmarena-ai/arena-hard-auto` source | +| `arena-hard-v0.1` | Legacy Arena-Hard v0.1 from official `lmarena-ai/arena-hard-auto` source | +| `m-arena-hard-v0.1` | `CohereLabs/m-ArenaHard` (500 prompts, Google-Translate) across 23 languages | +| `m-arena-hard-v0.1-{lang}` | Language-specific v0.1 slice (e.g., `ar`, `cs`, `de`, `uk`, `zh`, `pl`) | +| `m-arena-hard-v0.1-EU` | All EU v0.1 languages combined | +| `m-arena-hard-v2.0` | `CohereLabs/m-ArenaHard-v2.0` (498 prompts, in-house translation) across 23 languages | +| `m-arena-hard-v2.0-{lang}` | Language-specific v2.0 slice | +| `m-arena-hard-v2.0-EU` | All EU v2.0 languages combined | +| `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | + +For m-Arena-Hard, we use baseline completions based on the benchmark release: +- `m-arena-hard-v0.1`: Aya Expanse 8B (`CohereLabs/aya-expanse-8b`), ingested + from `CohereLabs/deja-vu-pairwise-evals` (repeat 0) via + [`scripts/multilingual_arena_hard/ingest_deja_vu_aya_references.py`](scripts/multilingual_arena_hard/ingest_deja_vu_aya_references.py). +- `m-arena-hard-v2.0`: We generate our own completions with Gemini 2.5 Flash (`google/gemini-2.5-flash`). For Arena-Hard, JudgeArena mirrors the baseline assignment upstream uses in `lmarena-ai/arena-hard-auto`: - `arena-hard-v0.1`: flat baseline `gpt-4-0314` for all 500 prompts. diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 0063623..205f739 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -150,8 +150,9 @@ def evaluate_completions( dataset=dataset, ).loc[:, "instruction"] - # A bit ugly, only loads if local path exist as we do not have a local path of completion for cases such as - # m-arena-hard. + # Only loads if the per-dataset local path exists; some datasets (e.g. + # language slices of m-arena-hard for which no baseline has been written + # yet) may not ship a local completions file. dataset_output_path = local_path_tables / "model_outputs" / f"{dataset}.csv.zip" if dataset_output_path.exists(): df_outputs = read_df(dataset_output_path) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 6c2b9e3..5484e55 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -105,8 +105,9 @@ def parse_args(cls): parser.add_argument( "--dataset", help="The dataset to use. For instance `alpaca-eval`, `arena-hard-v2.0`, " - "`arena-hard-v0.1`, `m-arena-hard-EU` for instruction " - "tuning cases or `french-contexts`, `spanish-contexts` for base models.", + "`arena-hard-v0.1`, `m-arena-hard-v0.1-EU`, `m-arena-hard-v2.0-uk` for " + "instruction tuning cases or `french-contexts`, `spanish-contexts` for " + "base models.", ) parser.add_argument( "--model_A", diff --git a/judgearena/instruction_dataset/__init__.py b/judgearena/instruction_dataset/__init__.py index 48fccd1..6570ee9 100644 --- a/judgearena/instruction_dataset/__init__.py +++ b/judgearena/instruction_dataset/__init__.py @@ -4,6 +4,10 @@ download_arena_hard, is_arena_hard_dataset, ) +from judgearena.instruction_dataset.m_arenahard import ( + load_m_arenahard, + split_m_arena_hard_dataset, +) def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.DataFrame: @@ -12,45 +16,16 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat df_instructions = load_mt_bench() - elif "m-arena-hard" in dataset: - from judgearena.instruction_dataset.m_arenahard import load_m_arenahard + elif (parsed := split_m_arena_hard_dataset(dataset)) is not None: from judgearena.utils import data_root - if dataset == "m-arena-hard": - language = None - else: - # read the suffix part "m-arena-hard-EU" -> "EU" - language = dataset.split("-")[-1] - assert language in [ - None, - "ar", - "cs", - "de", - "el", - "en", - "es", - "fa", - "fr", - "he", - "hi", - "id", - "it", - "ja", - "ko", - "nl", - "pl", - "pt", - "ro", - "ru", - "tr", - "uk", - "vi", - "zh", - "EU", - ] - print(f"Loading m-arena-hard with language specification set to {language}") - df_instructions = load_m_arenahard(local_path=data_root, language=language) - + version_key, lang_or_subset = parsed + print( + f"Loading {version_key} with language specification set to {lang_or_subset}" + ) + df_instructions = load_m_arenahard( + local_path=data_root, version=version_key, language=lang_or_subset + ) # sort by question_id, then language so that we get multiple languages if we truncate df_instructions.sort_values(["question_id", "lang"], inplace=True) df_instructions.rename( diff --git a/judgearena/instruction_dataset/m_arenahard.py b/judgearena/instruction_dataset/m_arenahard.py index 1c45f33..81b4f2e 100644 --- a/judgearena/instruction_dataset/m_arenahard.py +++ b/judgearena/instruction_dataset/m_arenahard.py @@ -1,37 +1,132 @@ +"""Version-aware m-ArenaHard loader. + +Mirrors ``judgearena/instruction_dataset/arena_hard.py``: each supported +``m-arena-hard-v{X.Y}`` maps to its dataset-native baseline, and a parallel +private dict carries the upstream HF repo id. The dispatcher in +``judgearena/instruction_dataset/__init__.py`` uses +``split_m_arena_hard_dataset`` to parse ``m-arena-hard-v{X.Y}[-{lang}|-EU]`` +and then calls ``load_m_arenahard``. +""" + from pathlib import Path import pandas as pd from huggingface_hub import snapshot_download +EU_LANGUAGES: tuple[str, ...] = ( + "cs", + "de", + "el", + "en", + "es", + "fr", + "it", + "nl", + "pl", + "pt", + "ro", + "uk", +) + +NON_EU_LANGUAGES: tuple[str, ...] = ( + "ar", + "fa", + "he", + "hi", + "id", + "ja", + "ko", + "ru", + "tr", + "vi", + "zh", +) + +ALL_LANGUAGES: tuple[str, ...] = (*EU_LANGUAGES, *NON_EU_LANGUAGES) + +# Dataset name -> dataset-native baseline model. Shape mirrors +# `ARENA_HARD_BASELINES` in `arena_hard.py`. v0.1 uses Aya Expanse 8B (free +# completions from CohereLabs/deja-vu-pairwise-evals); v2.0 uses Gemini 2.5 Flash. +M_ARENA_HARD_BASELINES: dict[str, str] = { + "m-arena-hard-v0.1": "CohereLabs/aya-expanse-8b", + "m-arena-hard-v2.0": "google/gemini-2.5-flash", +} + +# Dataset name -> upstream HF repo id. Kept private; the on-disk cache subdir +# is derived from the repo's short name. +_M_ARENA_HARD_HF_REPOS: dict[str, str] = { + "m-arena-hard-v0.1": "CohereLabs/m-ArenaHard", + "m-arena-hard-v2.0": "CohereLabs/m-ArenaHard-v2.0", +} + + +def is_m_arena_hard_dataset(dataset: str) -> bool: + return split_m_arena_hard_dataset(dataset) is not None + + +def split_m_arena_hard_dataset(dataset: str) -> tuple[str, str | None] | None: + """Parse ``m-arena-hard-v{X.Y}[-{lang}|-EU]`` into ``(version, suffix)``. -def load_m_arenahard(local_path, language: str | None = None): + Returns ``None`` for any name that doesn't match a known version or that + carries an unknown suffix. ``suffix`` is ``None`` for the all-languages + variant, ``"EU"`` for the EU subset, or a 2-letter code in + :data:`ALL_LANGUAGES`. Versioned names only -- the unversioned + ``m-arena-hard`` alias is deliberately not accepted. + """ + for version in M_ARENA_HARD_BASELINES: + if dataset == version: + return version, None + if dataset.startswith(f"{version}-"): + suffix = dataset[len(version) + 1 :] + if suffix == "EU" or suffix in ALL_LANGUAGES: + return version, suffix + return None + return None + + +def m_arena_hard_native_baseline(dataset: str) -> str | None: + """Baseline for a dataset name, or ``None`` if it isn't m-arena-hard.""" + parsed = split_m_arena_hard_dataset(dataset) + if parsed is None: + return None + return M_ARENA_HARD_BASELINES[parsed[0]] + + +def load_m_arenahard( + local_path: Path, + version: str, + language: str | None = None, +) -> pd.DataFrame: + """Load m-ArenaHard prompts for the requested version and language subset. + + ``version`` must be a key in :data:`M_ARENA_HARD_BASELINES`. ``language`` + is ``None`` for the full 23-language union, ``"EU"`` for the EU subset, + or a 2-letter language code for a single-language slice. + + The returned DataFrame carries the upstream columns plus a ``lang`` + column, with ``question_id`` rewritten to ``f"{question_id}-{lang}"`` so + multi-language slices have unique identifiers. + """ + if version not in _M_ARENA_HARD_HF_REPOS: + raise ValueError( + f"Unsupported m-ArenaHard version: {version!r}. " + f"Known versions: {sorted(_M_ARENA_HARD_HF_REPOS)}." + ) + repo_id = _M_ARENA_HARD_HF_REPOS[version] + local_subdir = repo_id.split("/", 1)[1] snapshot_download( - repo_id="CohereLabs/m-ArenaHard", + repo_id=repo_id, repo_type="dataset", allow_patterns="*", - local_dir=local_path / "m-ArenaHard", + local_dir=local_path / local_subdir, force_download=False, ) + m_arena_root = local_path / local_subdir - df_union = [] - m_arena_root = Path(local_path / "m-ArenaHard") - eu_languages = [ - "cs", - "de", - "el", - "en", - "es", - "fr", - "it", - "nl", - "pl", - "pt", - "ro", - "uk", - ] - for path in sorted(Path(m_arena_root).rglob("*.parquet")): + df_union: list[pd.DataFrame] = [] + for path in sorted(m_arena_root.rglob("*.parquet")): lg = path.parent.name - if language == "EU" and lg in eu_languages: + if language == "EU" and lg in EU_LANGUAGES: df = pd.read_parquet(path) df["lang"] = lg df_union.append(df) @@ -40,18 +135,17 @@ def load_m_arenahard(local_path, language: str | None = None): df["lang"] = lg df_union.append(df) - assert len(df_union) > 0, f"Invalid language passed {language}" + assert len(df_union) > 0, ( + f"No parquet matched under {m_arena_root} for language={language!r}." + ) df_res = pd.concat(df_union, ignore_index=True) - - # update index to still be unique by appendix language as a suffix df_res["question_id"] = df_res.apply( lambda row: f"{row['question_id']}-{row['lang']}", axis=1 ) - return df_res if __name__ == "__main__": from judgearena.utils import data_root - load_m_arenahard(local_path=data_root, language="EU") + load_m_arenahard(local_path=data_root, version="m-arena-hard-v0.1", language="EU") diff --git a/judgearena/utils.py b/judgearena/utils.py index 5d3e7f4..de4a518 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -918,14 +918,17 @@ def infer_model_spec_from_instance(model: object) -> str | None: def download_all(): + from judgearena.instruction_dataset.m_arenahard import M_ARENA_HARD_BASELINES + print(f"Downloading all dataset in {data_root}") local_path_tables = data_root / "tables" - for dataset in [ + datasets = [ "alpaca-eval", "arena-hard-v0.1", "arena-hard-v2.0", - "m-arena-hard", - ]: + *M_ARENA_HARD_BASELINES.keys(), + ] + for dataset in datasets: if is_arena_hard_dataset(dataset): download_arena_hard(dataset=dataset, local_tables_path=local_path_tables) else: diff --git a/tests/test_generate_and_evaluate.py b/tests/test_generate_and_evaluate.py index ada7ebe..6df21c5 100644 --- a/tests/test_generate_and_evaluate.py +++ b/tests/test_generate_and_evaluate.py @@ -55,7 +55,8 @@ def _run_without_cache(fun, **_kwargs): "arena-hard-v2.0", "arena-hard-v0.1", "fluency-french", - "m-arena-hard-EU", + "m-arena-hard-v0.1-EU", + "m-arena-hard-v2.0-EU", ], ) def test_generate_and_evaluate_context_completion(dataset: str, tmp_path): diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 9428303..98d146b 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -95,7 +95,8 @@ def _contexts_snapshot_stub(**_kwargs): tables_dir = tmp_path / "tables" assert [name for name, _ in hf_datasets] == [ "alpaca-eval", - "m-arena-hard", + "m-arena-hard-v0.1", + "m-arena-hard-v2.0", ] assert arena_hard_datasets == [ ("arena-hard-v0.1", tables_dir), diff --git a/tests/test_utils.py b/tests/test_utils.py index 9a96d7d..bdd4b8e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -109,13 +109,14 @@ def test_download_all_dispatches_arena_hard_versions(monkeypatch, tmp_path): utils.download_all() tables_dir = tmp_path / "tables" - assert calls[:4] == [ + assert calls[:5] == [ ("hf", "alpaca-eval", tables_dir), ("arena", "arena-hard-v0.1", tables_dir), ("arena", "arena-hard-v2.0", tables_dir), - ("hf", "m-arena-hard", tables_dir), + ("hf", "m-arena-hard-v0.1", tables_dir), + ("hf", "m-arena-hard-v2.0", tables_dir), ] - assert calls[4] == ( + assert calls[5] == ( "snapshot", "geoalgo/multilingual-contexts-to-be-completed", tmp_path / "contexts", From da6818eedb3b007a3e2c05796b66c9f54c9b41e9 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 22 Apr 2026 01:08:31 +0200 Subject: [PATCH 21/28] add default baseline for mt-bench --- README.md | 6 ++++++ judgearena/generate_and_evaluate.py | 3 ++- judgearena/instruction_dataset/mt_bench.py | 17 +++++++++++++++++ judgearena/mt_bench/mt_bench_utils.py | 16 ++++++++++------ judgearena/utils.py | 16 ++++++++++++++-- tests/test_instruction_dataset.py | 14 ++++++++++++++ tests/test_mt_bench_downloads.py | 4 ++-- 7 files changed, 65 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index fbf5ff3..cdf7b21 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,7 @@ This override applies to all vLLM models in the run. For remote providers (OpenA | Dataset | Description | |-----------------------------|----------------------------------------------------------------------------------------------------------------| | `alpaca-eval` | General instruction-following benchmark | +| `mt-bench` | FastChat MT-Bench (80 multi-turn questions) from the `lmsys/mt-bench` HF Space | | `arena-hard-v2.0` | Arena-Hard v2.0 from official `lmarena-ai/arena-hard-auto` source | | `arena-hard-v0.1` | Legacy Arena-Hard v0.1 from official `lmarena-ai/arena-hard-auto` source | | `m-arena-hard-v0.1` | `CohereLabs/m-ArenaHard` (500 prompts, Google-Translate) across 23 languages | @@ -206,6 +207,11 @@ This override applies to all vLLM models in the run. For remote providers (OpenA | `m-arena-hard-v2.0-EU` | All EU v2.0 languages combined | | `fluency-{lang}` | Fluency evaluation for pretrained models (`finnish`, `french`, `german`, `spanish`, `swedish`) | +For MT-Bench, the default pairwise baseline is `gpt-4`. +We diverge from FastChat's own `pairwise-baseline` default (`gpt-3.5-turbo`) to keep +a stronger reference consistent with Arena-Hard v0.1; the `gpt-4.jsonl` completions +ship in the `lmsys/mt-bench` HF Space. Override per run with `--model_B`. + For m-Arena-Hard, we use baseline completions based on the benchmark release: - `m-arena-hard-v0.1`: Aya Expanse 8B (`CohereLabs/aya-expanse-8b`), ingested from `CohereLabs/deja-vu-pairwise-evals` (repeat 0) via diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 5484e55..3ca1904 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -120,7 +120,8 @@ def parse_args(cls): help=( "Name of the baseline LLM for a generation. Optional for Arena-Hard " "datasets (which ship a dataset-native default per category; see " - "`ARENA_HARD_BASELINES`). Required for every other dataset." + "`ARENA_HARD_BASELINES`) and MT-Bench (see `MT_BENCH_BASELINES`, " + "defaults to `gpt-4`). Required for every other dataset." ), ) parser.add_argument( diff --git a/judgearena/instruction_dataset/mt_bench.py b/judgearena/instruction_dataset/mt_bench.py index 23aa0fe..291e13e 100644 --- a/judgearena/instruction_dataset/mt_bench.py +++ b/judgearena/instruction_dataset/mt_bench.py @@ -15,6 +15,23 @@ "fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl" ) +# Mirrors ``ARENA_HARD_BASELINES`` / ``M_ARENA_HARD_BASELINES``: dataset name -> +# dataset-native pairwise baseline. MT-Bench ships only one variant today, and +# ``gpt-4`` is the stronger-reference choice (FastChat's own ``pairwise-baseline`` +# default is ``gpt-3.5-turbo``; we deliberately diverge here). +MT_BENCH_BASELINES: dict[str, str] = { + "mt-bench": "gpt-4", +} + + +def is_mt_bench_dataset(dataset: str) -> bool: + return dataset in MT_BENCH_BASELINES + + +def mt_bench_native_baseline(dataset: str) -> str | None: + """Baseline for a dataset name, or ``None`` if it isn't mt-bench.""" + return MT_BENCH_BASELINES.get(dataset) + def _normalize_question_id(question_id: object) -> object: try: diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 864c8a2..c23cc7a 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -19,7 +19,10 @@ from judgearena.eval_utils import _compute_grouped_stats, print_results from judgearena.generate import generate_multiturn from judgearena.instruction_dataset import load_instructions -from judgearena.instruction_dataset.mt_bench import load_mt_bench_model_answers +from judgearena.instruction_dataset.mt_bench import ( + load_mt_bench_model_answers, + mt_bench_native_baseline, +) from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.mt_bench.fastchat_compat import ( FASTCHAT_TEMPERATURE_CONFIG, @@ -300,12 +303,13 @@ def run_mt_bench(args: CliArgs, ignore_cache: bool): "MT-Bench ignores provide_explanation=False and keeps the original " "FastChat-style explanation-plus-verdict prompt." ) - if args.swap_mode != "both": - print( - "MT-Bench requires swap_mode='both' to match FastChat and correct " - f"for position bias; overriding requested swap_mode='{args.swap_mode}'." + if args.model_B is None: + args.model_B = mt_bench_native_baseline(args.dataset) + if args.model_B is None: + raise ValueError( + f"--model_B is required for dataset '{args.dataset}'; " + "no dataset-native baseline registered." ) - args.swap_mode = "both" if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: print( "MT-Bench judge prompts require room for budgeted thinking, the " diff --git a/judgearena/utils.py b/judgearena/utils.py index de4a518..b1ff1af 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -17,6 +17,10 @@ from langchain_openai import ChatOpenAI from tqdm.asyncio import tqdm +from judgearena.chat_models import ( + OpenRouterGeminiSafetyTolerantChatOpenAI, + is_openrouter_gemini_model, +) from judgearena.instruction_dataset.arena_hard import ( download_arena_hard, is_arena_hard_dataset, @@ -866,8 +870,16 @@ def make_model(model: str, max_tokens: int | None = 8192, **engine_kwargs): ) if model_provider == "OpenRouter": - # Special case we need to override API url and key - return ChatOpenAI( + # Gemini's core policy filter rejects a small fraction of prompts with + # a hard PROHIBITED_CONTENT error that safety_settings cannot override; + # the subclass converts those into stub refusals so batch generation + # (e.g. benchmark baselines) completes instead of crashing. + chat_model_cls = ( + OpenRouterGeminiSafetyTolerantChatOpenAI + if is_openrouter_gemini_model(model) + else ChatOpenAI + ) + return chat_model_cls( api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1", model=model_name, diff --git a/tests/test_instruction_dataset.py b/tests/test_instruction_dataset.py index 93334f3..849fd94 100644 --- a/tests/test_instruction_dataset.py +++ b/tests/test_instruction_dataset.py @@ -45,6 +45,20 @@ def test_arena_hard_baselines_mapping_matches_upstream(): } +def test_mt_bench_native_baseline_is_flat_string(): + from judgearena.instruction_dataset.mt_bench import ( + MT_BENCH_BASELINES, + is_mt_bench_dataset, + mt_bench_native_baseline, + ) + + assert is_mt_bench_dataset("mt-bench") is True + assert is_mt_bench_dataset("alpaca-eval") is False + assert mt_bench_native_baseline("mt-bench") == "gpt-4" + assert mt_bench_native_baseline("alpaca-eval") is None + assert MT_BENCH_BASELINES == {"mt-bench": "gpt-4"} + + def test_normalize_official_arena_hard_v01_drops_no_category(): raw_df = pd.DataFrame( { diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 98d146b..377b435 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -338,7 +338,7 @@ def fake_run_mt_bench_fastchat(**kwargs): mt_bench_utils.run_mt_bench(args, ignore_cache=False) - assert args.swap_mode == "both" + assert args.swap_mode == "fixed" assert args.max_out_tokens_judge == 24576 assert args.max_model_len == 16384 assert args.max_judge_model_len == 28672 @@ -355,7 +355,7 @@ def fake_run_mt_bench_fastchat(**kwargs): "limit_event_model_spec": "VLLM/Qwen/Qwen3.5-27B-FP8", "limit_event_tracker": captured["make_model"]["kwargs"]["limit_event_tracker"], } - assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "both" + assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "fixed" assert captured["run_mt_bench_fastchat"]["prompt_preset"] == "default" assert ( captured["run_mt_bench_fastchat"]["args"].strip_thinking_before_judging is False From 891c4174c96f846e6f7b2b322b5fcb9187e2d366 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 22 Apr 2026 01:09:18 +0200 Subject: [PATCH 22/28] handle prohibited content errors for gemini in openrouter --- judgearena/chat_models/__init__.py | 15 +++ judgearena/chat_models/openrouter_gemini.py | 102 ++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 judgearena/chat_models/__init__.py create mode 100644 judgearena/chat_models/openrouter_gemini.py diff --git a/judgearena/chat_models/__init__.py b/judgearena/chat_models/__init__.py new file mode 100644 index 0000000..91aa08d --- /dev/null +++ b/judgearena/chat_models/__init__.py @@ -0,0 +1,15 @@ +"""Chat-model adapters with provider-specific hardening.""" + +from judgearena.chat_models.openrouter_gemini import ( + GEMINI_SAFETY_REFUSAL_MARKER, + OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON, + OpenRouterGeminiSafetyTolerantChatOpenAI, + is_openrouter_gemini_model, +) + +__all__ = [ + "GEMINI_SAFETY_REFUSAL_MARKER", + "OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON", + "OpenRouterGeminiSafetyTolerantChatOpenAI", + "is_openrouter_gemini_model", +] diff --git a/judgearena/chat_models/openrouter_gemini.py b/judgearena/chat_models/openrouter_gemini.py new file mode 100644 index 0000000..bce637d --- /dev/null +++ b/judgearena/chat_models/openrouter_gemini.py @@ -0,0 +1,102 @@ +"""ChatOpenAI subclass tolerant to Gemini's PROHIBITED_CONTENT hard-refusals. + +Google's core policy filter rejects a small fraction of prompts (e.g. graphic +violence, sexual content involving minors) with HTTP 403 ``PROHIBITED_CONTENT`` +*regardless* of the adjustable ``safety_settings`` thresholds. These refusals +are legitimate, reproducible model behavior that a benchmark like +``m-arena-hard-v2.0`` surfaces: the baseline should contain them so the judge +can score them, not crash the run. + +The subclass intercepts the error response before LangChain raises, returns +a stub assistant message with a clearly marked refusal payload and +``finish_reason="content_filter"``, and lets the rest of the pipeline proceed +unchanged. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_openai import ChatOpenAI + +GEMINI_SAFETY_REFUSAL_MARKER = ( + "[Gemini safety refusal: PROHIBITED_CONTENT — Google's core policy filter " + "blocked this prompt regardless of safety_settings.]" +) +OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON = "content_filter" + +_PROHIBITED_CONTENT_TOKEN = "PROHIBITED_CONTENT" + + +def is_openrouter_gemini_model(model_spec: str) -> bool: + """Return True when ``model_spec`` targets a Gemini model via OpenRouter. + + Matches ``OpenRouter/google/gemini-2.5-flash`` and related variants. + """ + provider, sep, model_name = model_spec.partition("/") + if not sep: + return False + lowered = model_name.lower() + return provider == "OpenRouter" and ( + lowered.startswith("google/gemini") or lowered.startswith("google/gemma") + ) + + +def _error_is_prohibited_content(error: object) -> bool: + if error is None: + return False + return _PROHIBITED_CONTENT_TOKEN in str(error) + + +def _build_prohibited_content_stub_payload( + *, original_response: dict[str, Any], model_name: str +) -> dict[str, Any]: + stub_message = { + "role": "assistant", + "content": GEMINI_SAFETY_REFUSAL_MARKER, + } + stub_choice = { + "index": 0, + "message": stub_message, + "finish_reason": OPENROUTER_GEMINI_SAFETY_REFUSAL_FINISH_REASON, + } + return { + "id": original_response.get("id") or "openrouter-gemini-safety-stub", + "object": "chat.completion", + "created": original_response.get("created") or 0, + "model": original_response.get("model") or model_name, + "choices": [stub_choice], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + +class OpenRouterGeminiSafetyTolerantChatOpenAI(ChatOpenAI): + """ChatOpenAI that converts Gemini PROHIBITED_CONTENT errors to stubs. + + Only intercepts the specific OpenRouter error surface for Gemini's core + policy filter; all other errors propagate unchanged. The stub message has + ``content == GEMINI_SAFETY_REFUSAL_MARKER`` and ``finish_reason == + "content_filter"`` so upstream validators and judges see the refusal + explicitly rather than a silent drop. + """ + + def _create_chat_result( # type: ignore[override] + self, + response, + generation_info: dict | None = None, + ): + response_dict = ( + response if isinstance(response, dict) else response.model_dump() + ) + error = response_dict.get("error") + if _error_is_prohibited_content(error): + stub = _build_prohibited_content_stub_payload( + original_response=response_dict, + model_name=self.model_name, + ) + return super()._create_chat_result(stub, generation_info=generation_info) + return super()._create_chat_result(response, generation_info=generation_info) From fb361542156915e14b0aae0bdd883e6799475b7b Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 22 Apr 2026 02:44:01 +0200 Subject: [PATCH 23/28] update system prompt with alpaca eval version, fix mismatch for expected token count for max_model_len --- judgearena/evaluate.py | 201 +++++++++++++++++++++++++++ judgearena/generate_and_evaluate.py | 4 + judgearena/prompts/system-prompt.txt | 4 +- 3 files changed, 206 insertions(+), 3 deletions(-) diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index 205f739..5e6a407 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -38,6 +39,13 @@ truncate_with_metadata, ) +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +_PREFLIGHT_MAX_ITERATIONS = 3 +_PREFLIGHT_RESERVED_TOKENS = 256 +_PREFLIGHT_MIN_COMPLETION_CHARS = 512 + class PairScore: def __init__(self, *, parser_mode: str = "score"): @@ -317,6 +325,9 @@ def annotate_battles( usage_phase: str | None = None, usage_model_spec: str | None = None, limit_event_tracker: LimitEventTracker | None = None, + judge_tokenizer: "PreTrainedTokenizerBase | None" = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, ) -> list[JudgeAnnotation]: """ Directly evaluate from list of instructions and completions @@ -344,6 +355,16 @@ def annotate_battles( :param user_prompt_template: :param truncate_input_chars: Max characters to truncate completions before sending to judge. :param use_tqdm: + :param judge_tokenizer: Optional HF tokenizer matching the judge model; when + supplied together with ``max_judge_model_len`` triggers a preflight + tokenize-and-retry pass that shrinks per-completion character caps until + the rendered prompt fits the judge context window. Converts the hard + ``VLLMValidationError`` class into a soft ``judge_input_token_truncation`` + limit event. + :param max_judge_model_len: Judge-side ``max_model_len``; required for the + preflight pass to be active. + :param max_out_tokens_judge: Judge-side output budget subtracted from + ``max_judge_model_len`` to derive the per-request prompt budget. :return: """ # alternatively pass list of tuples @@ -446,6 +467,19 @@ def annotate_battles( ) inputs = prompt_template.batch(input_payloads) + if judge_tokenizer is not None and max_judge_model_len: + inputs = _preflight_shrink_to_judge_budget( + prompt_template=prompt_template, + inputs=inputs, + input_payloads=input_payloads, + annotation_input_metadata=annotation_input_metadata, + case_ids=case_ids, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + limit_event_tracker=limit_event_tracker, + ) + print(f"Start LLM judge annotation ({len(inputs)} annotations).") judge_completions = do_inference( chat_model=judge_chat_model, @@ -505,6 +539,9 @@ def judge_and_parse_prefs( usage_phase: str | None = None, usage_model_spec: str | None = None, limit_event_tracker: LimitEventTracker | None = None, + judge_tokenizer: "PreTrainedTokenizerBase | None" = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, ) -> tuple[list[JudgeAnnotation], list[JudgeAnnotation] | None, pd.Series]: """Run judge annotation and parse preferences, handling swap_mode='both'. @@ -537,6 +574,9 @@ def judge_and_parse_prefs( usage_phase=usage_phase, usage_model_spec=usage_model_spec, limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, ) annotations_reversed = None @@ -558,6 +598,9 @@ def judge_and_parse_prefs( usage_phase=usage_phase, usage_model_spec=usage_model_spec, limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, ) def _none_to_nan(x): @@ -579,3 +622,161 @@ def _none_to_nan(x): prefs = pd.concat([prefs, (1 - prefs_reversed)]).reset_index(drop=True) return annotations, annotations_reversed, prefs + + +_LC_ROLE_MAP = {"human": "user", "ai": "assistant", "system": "system"} + + +def _count_chat_tokens(prompt_value: Any, tokenizer: Any) -> int: + """Count tokens the way vLLM's ``llm.chat()`` tokenizes after applying the + tokenizer's chat template. Falls back to raw-string encoding for tokenizers + without a chat template or if template application raises.""" + if hasattr(prompt_value, "to_messages"): + messages = [ + { + "role": _LC_ROLE_MAP.get(msg.type, msg.type), + "content": msg.content, + } + for msg in prompt_value.to_messages() + ] + try: + return len(tokenizer.apply_chat_template(messages, tokenize=True)) + except Exception: + pass + if hasattr(prompt_value, "to_string"): + text = prompt_value.to_string() + else: + text = str(prompt_value) + return len(tokenizer.encode(text)) + + +def _find_token_overflows( + inputs: list[Any], tokenizer: Any, safe_budget: int +) -> list[tuple[int, int]]: + """Return ``(index, token_count)`` for inputs whose tokenized length exceeds + ``safe_budget``.""" + overflows: list[tuple[int, int]] = [] + for idx, item in enumerate(inputs): + token_count = _count_chat_tokens(item, tokenizer) + if token_count > safe_budget: + overflows.append((idx, token_count)) + return overflows + + +def _chars_per_token(text: str, tokenizer: Any) -> float: + """Return a conservative char-to-token ratio for ``text``, floored at 1.0. + + Short/empty inputs yield a low ratio, which under-truncates rather than + overflowing - the safe direction for the preflight shrink loop. + """ + text = text if isinstance(text, str) else "" + if not text: + return 1.0 + token_count = max(1, len(tokenizer.encode(text))) + return max(1.0, len(text) / token_count) + + +def _render_with_empty_completions( + prompt_template: ChatPromptTemplate, user_prompt: str +) -> Any: + """Render the prompt template with empty completions so the fixed template + + user-prompt overhead can be measured per case. ``ChatPromptTemplate`` + uses ``str.format()`` on each message, so empty strings substitute cleanly + for both completion slots.""" + return prompt_template.invoke( + { + "user_prompt": user_prompt, + "completion_A": "", + "completion_B": "", + } + ) + + +def _preflight_shrink_to_judge_budget( + *, + prompt_template: ChatPromptTemplate, + inputs: list[Any], + input_payloads: list[dict[str, str]], + annotation_input_metadata: list[dict[str, object]], + case_ids: list[object], + judge_tokenizer: Any, + max_judge_model_len: int, + max_out_tokens_judge: int | None, + limit_event_tracker: LimitEventTracker | None, +) -> list[Any]: + """Bounded shrink-and-re-render loop that converts judge-context overflows + into soft ``judge_input_token_truncation`` limit events instead of a hard + ``VLLMValidationError`` at request time. + + The per-completion budget subtracts the case-specific template + user-prompt + overhead so that one iteration typically suffices; the 3-iteration bound is + a genuine safety net for the rare pathological case where the char-to-token + ratio shifts after truncation (e.g. dropping multi-byte glyphs). + """ + safe_budget = ( + max_judge_model_len - (max_out_tokens_judge or 0) - _PREFLIGHT_RESERVED_TOKENS + ) + for _ in range(_PREFLIGHT_MAX_ITERATIONS): + overflows = _find_token_overflows(inputs, judge_tokenizer, safe_budget) + if not overflows: + return inputs + for idx, _token_count in overflows: + payload = input_payloads[idx] + fixed_tokens = _count_chat_tokens( + _render_with_empty_completions(prompt_template, payload["user_prompt"]), + judge_tokenizer, + ) + per_completion_budget = max(256, (safe_budget - fixed_tokens) // 2) + ratio_A = _chars_per_token(payload["completion_A"], judge_tokenizer) + ratio_B = _chars_per_token(payload["completion_B"], judge_tokenizer) + new_cap_A = max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int(per_completion_budget * ratio_A * 0.9), + ) + new_cap_B = max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int(per_completion_budget * ratio_B * 0.9), + ) + payload["completion_A"], shrunk_A = truncate_with_metadata( + payload["completion_A"], + max_len=new_cap_A, + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field="completion_A", + case_id=case_ids[idx], + ) + payload["completion_B"], shrunk_B = truncate_with_metadata( + payload["completion_B"], + max_len=new_cap_B, + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field="completion_B", + case_id=case_ids[idx], + ) + metadata_row = annotation_input_metadata[idx] + metadata_row["completion_A_for_judge"] = payload["completion_A"] + metadata_row["completion_B_for_judge"] = payload["completion_B"] + if shrunk_A: + metadata_row["completion_A_truncated_for_judge"] = True + if shrunk_B: + metadata_row["completion_B_truncated_for_judge"] = True + inputs = prompt_template.batch(input_payloads) + + final_overflows = _find_token_overflows(inputs, judge_tokenizer, safe_budget) + for idx, token_count in final_overflows: + if limit_event_tracker is not None: + limit_event_tracker.record( + "judge_input_token_truncation_failed", + stage="judge_input", + case_id=case_ids[idx], + original_length=token_count, + final_length=safe_budget, + note=( + f"{_PREFLIGHT_MAX_ITERATIONS} shrink iterations did not " + f"bring tokens under {safe_budget}; falling through to " + "vLLM validation." + ), + ) + return inputs diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 3ca1904..3c06f83 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -451,6 +451,7 @@ def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Serie chat_template=args.chat_template, **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) + judge_tokenizer = getattr(judge_chat_model, "tokenizer", None) name = ( f"{args.dataset}-{args.model_A}-{baseline_plan.display_name}-{args.judge_model}" @@ -499,6 +500,9 @@ def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Serie usage_phase="judge", usage_model_spec=args.judge_model, limit_event_tracker=limit_event_tracker, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=args.effective_judge_max_model_len(), + max_out_tokens_judge=args.max_out_tokens_judge, ) eval_instruction_index = instructions.head(n_instructions).index.tolist() diff --git a/judgearena/prompts/system-prompt.txt b/judgearena/prompts/system-prompt.txt index 41cb930..41b078f 100644 --- a/judgearena/prompts/system-prompt.txt +++ b/judgearena/prompts/system-prompt.txt @@ -1,3 +1 @@ -You are a highly efficient assistant, who evaluates and selects the best large language model based on the quality of their responses to a given instruction. -You will be shown one instruction and the output of Assistant A and Assistant B and will have to decide which one was best. -Make sure to not over-confidently prefer one assistant or the other and also make sure to not bias your preference based on the ordering or on the length of the answers. +You are a highly efficient assistant, who evaluates and ranks large language models (LLMs) based on the quality of their responses to given prompts. This process will create a leaderboard reflecting the most accurate and human-preferred answers. From f33f19172476bbcfbe7050261f2aff5176b32a97 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 28 Apr 2026 15:05:31 +0200 Subject: [PATCH 24/28] roll back to the default system prompt --- judgearena/prompts/system-prompt.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/judgearena/prompts/system-prompt.txt b/judgearena/prompts/system-prompt.txt index 41b078f..41cb930 100644 --- a/judgearena/prompts/system-prompt.txt +++ b/judgearena/prompts/system-prompt.txt @@ -1 +1,3 @@ -You are a highly efficient assistant, who evaluates and ranks large language models (LLMs) based on the quality of their responses to given prompts. This process will create a leaderboard reflecting the most accurate and human-preferred answers. +You are a highly efficient assistant, who evaluates and selects the best large language model based on the quality of their responses to a given instruction. +You will be shown one instruction and the output of Assistant A and Assistant B and will have to decide which one was best. +Make sure to not over-confidently prefer one assistant or the other and also make sure to not bias your preference based on the ordering or on the length of the answers. From e21639e2442618f0daa17c899bd52de27d3d7846 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Tue, 28 Apr 2026 15:07:20 +0200 Subject: [PATCH 25/28] update dependencies for qwen3.5 and gemma4 runs --- pyproject.toml | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1318c6f..8dfd62d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,12 +61,28 @@ exclude = ["slurmpilot_scripts*"] [dependency-groups] dev = [ - "llmcompressor>=0.4.0", "pre-commit>=4.5.1", "pytest>=8.4.2", "ruff>=0.11.0", "slurmpilot @ git+https://github.com/geoalgo/slurmpilot.git@main", ] +# `llmcompressor` pins older `compressed-tensors` / `transformers` that clash +# with the `vllm` extra (vLLM 0.19.1 pulls `compressed-tensors==0.15.0.1` and +# `transformers>=5.5.1` for Gemma-4). Keep it available under its own group so +# quantization workflows can still install it, but mark it mutually exclusive +# with the `vllm` extra via `[tool.uv] conflicts` so the universal lock can +# still resolve both sides. +quantize = [ + "llmcompressor>=0.4.0", +] + +[tool.uv] +conflicts = [ + [ + { extra = "vllm" }, + { group = "quantize" }, + ], +] [tool.ruff] target-version = "py312" @@ -82,7 +98,14 @@ quote-style = "double" indent-style = "space" [project.optional-dependencies] -# vLLM on PyPI pins transformers<5; optional extra matches that so `uv lock` can resolve. -# JudgeArena relies on v0.19+ for Qwen3.5 thinking_token_budget support and FP8 fixes. -vllm = ["vllm>=0.19.0,<1.0.0", "transformers>=4.56.0,<5.0.0"] +# vLLM 0.19.1 is the pinned judge/battle runtime on this branch. Its own +# Requires-Dist is `transformers!=5.0.*,...,!=5.5.0,>=4.56.0`, so the +# only sub-range >=5.x that vLLM 0.19.1 accepts is >=5.5.1 — hence the +# lower bound here. Qwen3.5 loading works under either 4.56.x or 5.5.x +# because vLLM ships its own `qwen3_5` config shim and model impl and +# registers them into `AutoConfig` at runtime. +vllm = [ + "vllm>=0.19.1,<1.0.0", + "transformers>=5.5.1,<6.0.0", +] llamacpp = ["llama-cpp-python>=0.3.0"] From 41d925ee073720b4bc71c70bb463164835528d91 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 29 Apr 2026 13:37:06 +0200 Subject: [PATCH 26/28] Improve pairwise benchmark run controls and accounting Add generation-only runs, turn-1 thinking cleanup for MT-Bench carryover, native baseline resolution across pairwise tasks, and API-reported token usage accounting for OpenRouter-style models. Add vLLM init retries to clearly transient CUDA startup failures. --- judgearena/cli.py | 12 +- judgearena/cli_common.py | 34 ++++ judgearena/generate.py | 24 ++- judgearena/generate_and_evaluate.py | 62 +++++- judgearena/instruction_dataset/__init__.py | 5 +- judgearena/instruction_dataset/arena_hard.py | 13 +- judgearena/mt_bench/mt_bench_utils.py | 37 +++- judgearena/openrouter_reference_pricing.py | 42 ++++ judgearena/utils.py | 122 +++++++++++- tests/test_cli.py | 31 ++- .../test_generate_and_evaluate_arena_hard.py | 41 +++- tests/test_mt_bench_downloads.py | 4 +- tests/test_openrouter_reference_pricing.py | 121 ++++++++++++ tests/test_strip_thinking_carryover.py | 185 ++++++++++++++++++ tests/test_utils.py | 148 ++++++++++++++ 15 files changed, 852 insertions(+), 29 deletions(-) create mode 100644 tests/test_strip_thinking_carryover.py diff --git a/judgearena/cli.py b/judgearena/cli.py index 3d22db3..d8f0408 100644 --- a/judgearena/cli.py +++ b/judgearena/cli.py @@ -17,7 +17,7 @@ ) from judgearena.estimate_elo_ratings import CliEloArgs from judgearena.estimate_elo_ratings import main as main_elo -from judgearena.generate_and_evaluate import CliArgs +from judgearena.generate_and_evaluate import CliArgs, native_pairwise_baseline from judgearena.generate_and_evaluate import main as main_generate_and_evaluate from judgearena.log import configure_logging, get_logger @@ -199,6 +199,8 @@ def _build_elo_args( judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, + strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, + skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, @@ -218,8 +220,10 @@ def _build_elo_args( def _build_generate_and_evaluate_args( args: argparse.Namespace, task: str, model_a: str | None ) -> CliArgs: - if model_a is None: - raise SystemExit(f"--model_A is required for task {task!r}.") + if model_a is None or ( + args.model_B is None and native_pairwise_baseline(task) is None + ): + raise SystemExit(f"--model_A and --model_B are required for task {task!r}.") return CliArgs( task=task, model_A=model_a, @@ -233,6 +237,8 @@ def _build_generate_and_evaluate_args( judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, + strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, + skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 6a33ce5..112cc58 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -27,6 +27,8 @@ class BaseCliArgs: judge_prompt_preset: str = "default" battle_thinking_token_budget: int | None = None strip_thinking_before_judging: bool = False + strip_thinking_in_turn_1_carryover: bool = True + skip_judging: bool = False truncate_all_input_chars: int = 8192 truncate_judge_input_chars: int | None = None max_out_tokens_models: int = 32768 @@ -162,6 +164,38 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: "before sending them to the judge." ), ) + parser.add_argument( + "--strip_thinking_in_turn_1_carryover", + nargs="?", + const=True, + default=True, + type=parse_optional_bool, + help=( + "When building the turn-2 prompt for multi-turn datasets, strip " + "... blocks (or vLLM forced thinking-budget closers) " + "from the turn-1 answer before the character-level truncation fires. " + "Matches what the Qwen3 chat template does internally for historical " + "assistant turns and prevents the turn-1 char cap from landing inside " + "a block and silently destroying the closer. Enabled " + "by default." + ), + ) + parser.add_argument( + "--skip_judging", + nargs="?", + const=True, + default=False, + type=parse_optional_bool, + help=( + "If specified, generate battle-model completions and write a " + "generation-only summary (gen-results-.json with limit_events) " + "but skip judge-model construction and the judging loop entirely. " + "Useful for decoupling expensive paid-judge calls from the cheap " + "local generation phase: run once with --skip_judging=True to " + "materialize the completion cache, inspect cap rates, then rerun " + "with --skip_judging=False to judge from cache." + ), + ) parser.add_argument( "--result_folder", type=str, diff --git a/judgearena/generate.py b/judgearena/generate.py index 97254f7..a25e8ec 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -7,6 +7,7 @@ LimitEventTracker, do_inference, make_model, + strip_thinking_tags_with_metadata, truncate_with_metadata, ) @@ -192,6 +193,7 @@ def generate_multiturn( usage_tracker=None, usage_phase: str | None = None, limit_event_tracker: LimitEventTracker | None = None, + strip_thinking_in_turn_1_carryover: bool = True, **model_kwargs, ) -> pd.DataFrame: """Generate two-turn completions for MT-Bench style questions.""" @@ -283,6 +285,7 @@ def generate_multiturn( turn2_turn1_truncated: list[bool] = [] turn2_answer_truncated: list[bool] = [] turn2_prompt_truncated: list[bool] = [] + turn2_turn1_answer_thinking_stripped: list[bool] = [] for (question_id, row), t1_answer in zip( questions.iterrows(), completions_turn_1, strict=True ): @@ -290,6 +293,7 @@ def generate_multiturn( turn2_turn1_truncated.append(False) turn2_answer_truncated.append(False) turn2_prompt_truncated.append(False) + turn2_turn1_answer_thinking_stripped.append(False) turn2_inputs.append( turn1_template.invoke({"user_prompt": "No follow-up question."}) ) @@ -312,8 +316,23 @@ def generate_multiturn( case_id=question_id, model_spec=model, ) + # Strip ... from the turn-1 answer before the + # character cap fires. Mirrors what the Qwen3 chat template does + # natively for historical assistant turns; applying it here + # ensures a 30K-char cap lands on the visible answer rather than + # deep inside a runaway reasoning block, which would silently + # destroy the closer and force the whole thinking + # fragment into the turn-2 prompt. + t1_answer_str = str(t1_answer) + if strip_thinking_in_turn_1_carryover: + t1_answer_str, thinking_stripped = strip_thinking_tags_with_metadata( + t1_answer_str + ) + else: + thinking_stripped = False + turn2_turn1_answer_thinking_stripped.append(thinking_stripped) truncated_turn_1_answer, answer_was_truncated = truncate_with_metadata( - str(t1_answer), + t1_answer_str, max_len=truncate_input_chars, tracker=limit_event_tracker, kind="generation_input_char_truncation", @@ -388,6 +407,9 @@ def generate_multiturn( "generation_turn_1_hit_token_limit": turn1_hit_token_limit, "generation_turn_2_turn_1_prompt_truncated": turn2_turn1_truncated, "generation_turn_2_turn_1_answer_truncated": turn2_answer_truncated, + "generation_turn_2_turn_1_answer_thinking_stripped": ( + turn2_turn1_answer_thinking_stripped + ), "generation_turn_2_prompt_truncated": turn2_prompt_truncated, "generation_turn_2_finish_reason": [ metadata_row.get("finish_reason") for metadata_row in turn2_metadata diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index a9fbc3a..edec74a 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -24,10 +24,15 @@ from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions from judgearena.instruction_dataset.arena_hard import ( - arena_hard_native_baseline, + ARENA_HARD_BASELINES, download_arena_hard, is_arena_hard_dataset, ) +from judgearena.instruction_dataset.m_arenahard import ( + M_ARENA_HARD_BASELINES, + split_m_arena_hard_dataset, +) +from judgearena.instruction_dataset.mt_bench import MT_BENCH_BASELINES from judgearena.judge_prompt_presets import DEFAULT_JUDGE_PROMPT_PRESET from judgearena.log import ( attach_file_handler, @@ -55,6 +60,17 @@ logger = get_logger(__name__) +ALPACA_EVAL_BASELINES: dict[str, str] = { + "alpaca-eval": "gpt4_1106_preview", +} + +PAIRWISE_BASELINES: dict[str, str | Mapping[str, str]] = { + **ALPACA_EVAL_BASELINES, + **ARENA_HARD_BASELINES, + **M_ARENA_HARD_BASELINES, + **MT_BENCH_BASELINES, +} + def try_load_dataset_completions( dataset: str, model: str, n_instructions: int | None @@ -158,6 +174,8 @@ def parse_args(cls): judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, + strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, + skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, max_out_tokens_models=args.max_out_tokens_models, @@ -229,12 +247,12 @@ def _resolve_baseline_plan( """ if args.model_B is not None: return BaselinePlan.flat(args.model_B, index=instructions_df.index) - if not is_arena_hard_dataset(args.task): + native = native_pairwise_baseline(args.task) + if native is None: raise ValueError( - f"--model_B is required for dataset '{args.task}'; only Arena-Hard " - "datasets ship a dataset-native baseline." + f"--model_B is required for dataset '{args.task}'; no dataset-native " + "baseline is registered." ) - native = arena_hard_native_baseline(args.task) if isinstance(native, str): return BaselinePlan.flat(native, index=instructions_df.index) if isinstance(native, Mapping): @@ -257,6 +275,17 @@ def _resolve_baseline_plan( raise ValueError(f"Unsupported baseline shape for dataset '{args.task}'.") +def native_pairwise_baseline(task: str) -> str | Mapping[str, str] | None: + """Return the dataset-native pairwise baseline, if the task defines one.""" + if task in PAIRWISE_BASELINES: + return PAIRWISE_BASELINES[task] + parsed_m_arena_hard = split_m_arena_hard_dataset(task) + if parsed_m_arena_hard is not None: + version_key, _lang_or_subset = parsed_m_arena_hard + return PAIRWISE_BASELINES[version_key] + return None + + def load_contexts(dataset: str) -> pd.Series: path = data_root / "contexts" / dataset return pd.read_csv(path).loc[:, "instruction"] @@ -482,6 +511,29 @@ def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Serie baseline_plan.display_name, completions_B.values[0], ) + if args.skip_judging: + with open(res_folder / f"args-{name}.json", "w") as f: + json.dump(asdict(args), f, indent=2) + generation_summary = { + "task": args.task, + "model_A": args.model_A, + "model_B": baseline_plan.display_name, + "baseline_assignment": "per-row" if not baseline_plan.is_flat else "flat", + "baseline_models": baseline_plan.unique_models, + "judge_model": args.judge_model, + "n_instructions": n_instructions, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "limit_events": limit_event_tracker.build_summary(), + "skip_judging": True, + } + with open(res_folder / f"gen-results-{name}.json", "w") as f: + json.dump(_to_jsonable(generation_summary), f, indent=2, allow_nan=False) + logger.info( + "skip_judging=True: wrote gen-results-%s.json and returning before judge construction.", + name, + ) + return None logger.info("Evaluating completions with judge %s.", args.judge_model) judge_chat_model = make_model( diff --git a/judgearena/instruction_dataset/__init__.py b/judgearena/instruction_dataset/__init__.py index 56920ff..7d3185c 100644 --- a/judgearena/instruction_dataset/__init__.py +++ b/judgearena/instruction_dataset/__init__.py @@ -9,7 +9,6 @@ split_m_arena_hard_dataset, ) from judgearena.log import get_logger -from judgearena.utils import data_root, download_hf, read_df logger = get_logger(__name__) @@ -21,6 +20,8 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat df_instructions = load_mt_bench() elif (parsed := split_m_arena_hard_dataset(dataset)) is not None: + from judgearena.utils import data_root + version_key, lang_or_subset = parsed logger.info( "Loading %s with language specification set to %s", @@ -47,6 +48,8 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.Dat "arena-hard-v0.1", "arena-hard-v2.0", ] + from judgearena.utils import data_root, download_hf, read_df + local_path_tables = data_root / "tables" if is_arena_hard_dataset(dataset): download_arena_hard(dataset=dataset, local_tables_path=local_path_tables) diff --git a/judgearena/instruction_dataset/arena_hard.py b/judgearena/instruction_dataset/arena_hard.py index 4dc70a4..5a59db0 100644 --- a/judgearena/instruction_dataset/arena_hard.py +++ b/judgearena/instruction_dataset/arena_hard.py @@ -8,9 +8,16 @@ ARENA_HARD_HF_REPO_ID = "lmarena-ai/arena-hard-auto" # Mirrors upstream's `JUDGE_SETTINGS` baseline assignment in -# `arena-hard-auto/utils/judge_utils.py`: v0.1 has a single flat baseline, -# v2.0 routes per question category. `is_arena_hard_dataset` and the -# dispatcher in `generate_and_evaluate.py` key off this map. +# `arena-hard-auto/utils/judge_utils.py` verbatim: v0.1 has a single flat +# baseline, v2.0 routes per question category. `is_arena_hard_dataset` and +# the dispatcher in `generate_and_evaluate.py` key off this map. +# +# Note: the released v2.0 `question.jsonl` only tags rows as `hard_prompt` +# (500) or `creative_writing` (250); `coding` and `math` are inert keys +# upstream ships for forward compatibility (no question carries those +# labels, so the dispatcher never looks them up). We keep them so any +# future re-tagging upstream lights up automatically without a code +# change here. ARENA_HARD_BASELINES: dict[str, str | Mapping[str, str]] = { "arena-hard-v0.1": "gpt-4-0314", "arena-hard-v2.0": { diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 78e921c..4b3a505 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -95,6 +95,7 @@ def _mt_bench_generation_cache_name(args: CliArgs, *, model_name: str) -> str: "max_model_len": args.max_model_len, "chat_template": args.chat_template, "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_in_turn_1_carryover": args.strip_thinking_in_turn_1_carryover, "engine_kwargs": _build_mt_bench_generation_kwargs( args=args, model_spec=model_name ), @@ -142,6 +143,9 @@ def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: usage_tracker=usage_tracker, usage_phase=usage_phase, limit_event_tracker=limit_event_tracker, + strip_thinking_in_turn_1_carryover=( + args.strip_thinking_in_turn_1_carryover + ), **_build_mt_bench_generation_kwargs(args=args, model_spec=model_name), ) @@ -305,8 +309,8 @@ def run_mt_bench( args: CliArgs, ignore_cache: bool, *, - res_folder: Path, - result_name: str, + res_folder: Path | None = None, + result_name: str | None = None, ): """MT-Bench pipeline with FastChat-compatible pairwise judging.""" run_started_at = datetime.now(UTC) @@ -325,6 +329,11 @@ def run_mt_bench( f"--model_B is required for dataset '{args.task}'; " "no dataset-native baseline registered." ) + if result_name is None: + result_name = _build_mt_bench_result_name(args, suffix="mtbench") + if res_folder is None: + res_folder = Path(args.result_folder) / result_name + res_folder.mkdir(parents=True, exist_ok=True) if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: logger.info( "MT-Bench judge prompts require room for budgeted thinking, the " @@ -346,6 +355,30 @@ def run_mt_bench( usage_tracker=usage_tracker, limit_event_tracker=limit_event_tracker, ) + if args.skip_judging: + res_folder.mkdir(parents=True, exist_ok=True) + with open(res_folder / f"args-{result_name}.json", "w") as f: + json.dump(_to_jsonable(asdict(args)), f, indent=2, allow_nan=False) + generation_summary = { + "task": args.task, + "model_A": args.model_A, + "model_B": args.model_B, + "judge_model": args.judge_model, + "n_instructions": args.n_instructions + if args.n_instructions is not None + else len(questions_df), + "battle_thinking_token_budget": args.battle_thinking_token_budget, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "limit_events": limit_event_tracker.build_summary(), + "skip_judging": True, + } + with open(res_folder / f"gen-results-{result_name}.json", "w") as f: + json.dump(_to_jsonable(generation_summary), f, indent=2, allow_nan=False) + logger.info( + "skip_judging=True: wrote gen-results-%s.json and returning before judge model construction.", + result_name, + ) + return None effective_judge_max_model_len = args.effective_judge_max_model_len() if ( effective_judge_max_model_len is not None diff --git a/judgearena/openrouter_reference_pricing.py b/judgearena/openrouter_reference_pricing.py index cf13820..a1bf052 100644 --- a/judgearena/openrouter_reference_pricing.py +++ b/judgearena/openrouter_reference_pricing.py @@ -163,6 +163,48 @@ def record_batch_from_model( ) return True + def record_batch_from_usage_metadata( + self, + *, + phase: str, + model_spec: str, + usages: list[dict[str, Any] | None], + ) -> bool: + """Record API-reported token usage extracted from LangChain AIMessages. + + Each entry in ``usages`` is either ``None`` (no usage data) or a dict + carrying ``input_tokens``/``output_tokens`` (langchain-core shape) or + ``prompt_tokens``/``completion_tokens`` (OpenAI-shape). This is the + path used for OpenRouter / ChatOpenAI-backed models, which do not + expose ``count_*_batch`` helpers but do return per-call usage from + the upstream API. + + Returns ``True`` if at least one entry produced a record, else + ``False`` so callers can fall back to ``record_batch_from_model``. + """ + recorded = False + for usage in usages: + if not isinstance(usage, dict): + continue + prompt = usage.get("input_tokens") + if prompt is None: + prompt = usage.get("prompt_tokens") + completion = usage.get("output_tokens") + if completion is None: + completion = usage.get("completion_tokens") + if prompt is None and completion is None: + continue + self._records.append( + TokenUsageRecord( + phase=phase, + model_spec=model_spec, + prompt_tokens=int(prompt or 0), + completion_tokens=int(completion or 0), + ) + ) + recorded = True + return recorded + def _parse_catalog_model(raw_model: dict[str, Any]) -> OpenRouterModelEntry: raw_pricing = raw_model.get("pricing") or {} diff --git a/judgearena/utils.py b/judgearena/utils.py index 02f59eb..09253db 100644 --- a/judgearena/utils.py +++ b/judgearena/utils.py @@ -388,6 +388,39 @@ def _extract_ai_message_metadata(result: object) -> dict[str, Any]: return {"finish_reason": finish_reason, "stop_reason": stop_reason} +def _extract_token_usage(result: object) -> dict[str, int] | None: + """Pull API-reported token usage from a LangChain AIMessage-like result. + + Two shapes coexist depending on langchain-openai version: + - langchain-core AIMessage.usage_metadata: ``{"input_tokens", "output_tokens", "total_tokens"}`` + - response_metadata.token_usage (OpenAI-shape): ``{"prompt_tokens", "completion_tokens", "total_tokens"}`` + + Returns the first shape that carries non-null counts, or ``None`` if neither + is present (e.g. provider that does not surface usage). Used by + ``OpenRouterReferencePricingTracker.record_batch_from_usage_metadata`` to + capture per-call billing tokens for OpenRouter / ChatOpenAI runs, which + cannot be tokenised via ``count_*_batch`` helpers. + """ + usage_metadata = getattr(result, "usage_metadata", None) + if isinstance(usage_metadata, dict) and ( + usage_metadata.get("input_tokens") is not None + or usage_metadata.get("output_tokens") is not None + ): + return dict(usage_metadata) + response_metadata = getattr(result, "response_metadata", None) or {} + token_usage = ( + response_metadata.get("token_usage") + if isinstance(response_metadata, dict) + else None + ) + if isinstance(token_usage, dict) and ( + token_usage.get("prompt_tokens") is not None + or token_usage.get("completion_tokens") is not None + ): + return dict(token_usage) + return None + + def do_inference( chat_model, inputs, @@ -407,7 +440,14 @@ def do_inference( if use_tqdm: # perform inference asynchronously to be able to update tqdm, chat_model.batch does not work as it blocks until # all requests are received + # JUDGEARENA_JUDGE_MAX_CONCURRENCY caps simultaneous in-flight ainvokes + # (e.g. against OpenRouter). Unset = unbounded, preserving prior behaviour. + cap_raw = os.environ.get("JUDGEARENA_JUDGE_MAX_CONCURRENCY") + cap = int(cap_raw) if cap_raw and int(cap_raw) > 0 else None + async def process_with_real_progress(chat_model, inputs, pbar): + sem = asyncio.Semaphore(cap) if cap else None + async def process_single(input_item, max_retries=5, base_delay=1.0): for attempt in range(max_retries): try: @@ -427,8 +467,14 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): ) await asyncio.sleep(delay) + async def gated(inp): + if sem is None: + return await process_single(inp) + async with sem: + return await process_single(inp) + # asyncio.gather preserves order (unlike as_completed) - results = await asyncio.gather(*[process_single(inp) for inp in inputs]) + results = await asyncio.gather(*[gated(inp) for inp in inputs]) return results with logging_redirect_tqdm(), tqdm(total=len(inputs)) as pbar: @@ -437,8 +483,9 @@ async def process_single(input_item, max_retries=5, base_delay=1.0): chat_model=chat_model, inputs=inputs, pbar=pbar ) ) - if return_metadata: - metadata = [_extract_ai_message_metadata(r) for r in res] + # Always materialize metadata; it is cheap and keeps return_metadata + # behavior consistent with the batch path. + metadata = [_extract_ai_message_metadata(r) for r in res] else: def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): @@ -488,6 +535,11 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): res, metadata = batch_with_retry(inputs) + # Pull per-call usage from AIMessage objects BEFORE flattening to .content; + # OpenRouter / ChatOpenAI surface API-billed token counts here but lose + # them after .content extraction. + per_call_usages = [_extract_token_usage(r) for r in res] + # Not sure why the API of Langchain returns sometime a string and sometimes an AIMessage object # is it because of using Chat and barebones models? # when using OpenAI, the output is AIMessage not a string... @@ -498,13 +550,19 @@ def batch_with_retry(batch_inputs, max_retries=5, base_delay=1.0): and usage_model_spec is not None ): try: - usage_tracker.record_batch_from_model( + recorded = usage_tracker.record_batch_from_usage_metadata( phase=usage_phase, model_spec=usage_model_spec, - chat_model=chat_model, - inputs=list(inputs), - outputs=res, + usages=per_call_usages, ) + if not recorded: + usage_tracker.record_batch_from_model( + phase=usage_phase, + model_spec=usage_model_spec, + chat_model=chat_model, + inputs=list(inputs), + outputs=res, + ) except Exception as e: print( f"Warning: failed to record token usage for phase " @@ -530,6 +588,52 @@ async def ainvoke(self, input, **invoke_kwargs): return self.message +_VLLM_INIT_RETRY_SIGNATURES = ( + "cudaErrorDevicesUnavailable", + "CUDA-capable device(s) is/are busy or unavailable", + "CUDA error: initialization error", +) +_VLLM_INIT_MAX_ATTEMPTS = int(os.getenv("JUDGEARENA_VLLM_INIT_MAX_ATTEMPTS", "4")) +_VLLM_INIT_BACKOFF_SECONDS = int( + os.getenv("JUDGEARENA_VLLM_INIT_BACKOFF_SECONDS", "20") +) + + +def _init_llm_with_retry(llm_cls, **kwargs): + """Instantiate ``vllm.LLM`` with retries on transient GPU-init races. + + On shared Slurm nodes with MaxGRESPerAccount throttling, freshly scheduled + jobs can hit ``cudaErrorDevicesUnavailable`` because the previous tenant's + driver cleanup has not finished when our process starts. This manifests as + an immediate engine-core init failure and is almost always resolved by a + 15-30 s sleep + retry on the same GPU. We retry up to + ``JUDGEARENA_VLLM_INIT_MAX_ATTEMPTS`` times with exponential backoff before + giving up, which keeps persistent configuration errors from looping. + """ + last_exc: Exception | None = None + for attempt in range(1, _VLLM_INIT_MAX_ATTEMPTS + 1): + try: + return llm_cls(**kwargs) + except Exception as exc: + message = f"{type(exc).__name__}: {exc}" + if not any(sig in message for sig in _VLLM_INIT_RETRY_SIGNATURES): + raise + last_exc = exc + if attempt == _VLLM_INIT_MAX_ATTEMPTS: + break + delay = _VLLM_INIT_BACKOFF_SECONDS * (2 ** (attempt - 1)) + warnings.warn( + f"vLLM init attempt {attempt}/{_VLLM_INIT_MAX_ATTEMPTS} failed " + f"with transient GPU-init signature ({message.splitlines()[0]}); " + f"sleeping {delay}s before retry.", + RuntimeWarning, + stacklevel=2, + ) + time.sleep(delay) + assert last_exc is not None + raise last_exc + + class ChatVLLM: """VLLM wrapper that auto-detects whether to use chat() or generate(). @@ -652,7 +756,9 @@ def __init__( ) self.sampling_params = SamplingParams(**self._sampling_params_kwargs) - self.llm = LLM(model=model, trust_remote_code=True, **vllm_kwargs) + self.llm = _init_llm_with_retry( + LLM, model=model, trust_remote_code=True, **vllm_kwargs + ) self.tokenizer = self.llm.get_tokenizer() # Resolve chat template: diff --git a/tests/test_cli.py b/tests/test_cli.py index 30be4fa..5d905ea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -248,12 +248,12 @@ def test_unknown_elo_task_raises(capture_mains): ) -def test_generate_and_evaluate_requires_model_a_and_b(capture_mains): +def test_pairwise_task_without_native_baseline_requires_model_a_and_b(capture_mains): with pytest.raises(SystemExit, match="--model_A and --model_B are required"): cli_module.cli( [ "--task", - "alpaca-eval", + "fluency-french", "--model_A", "Dummy/A", "--judge", @@ -262,6 +262,33 @@ def test_generate_and_evaluate_requires_model_a_and_b(capture_mains): ) +@pytest.mark.parametrize( + "task", + [ + "alpaca-eval", + "m-arena-hard-v2.0-EU", + ], +) +def test_pairwise_task_allows_missing_model_b_when_native_baseline_exists( + capture_mains, task: str +): + cli_module.cli( + [ + "--task", + task, + "--model_A", + "Dummy/A", + "--judge", + "Dummy/J", + ] + ) + assert capture_mains["module"] == "generate_and_evaluate" + ge_args: CliArgs = capture_mains["args"] + assert ge_args.task == task + assert ge_args.model_A == "Dummy/A" + assert ge_args.model_B is None + + def test_deprecated_model_flag_routes_into_pairwise_task(capture_mains): """`--model` is a deprecated alias for `--model_A` even on pairwise tasks.""" with pytest.warns(DeprecationWarning, match="--model is deprecated"): diff --git a/tests/test_generate_and_evaluate_arena_hard.py b/tests/test_generate_and_evaluate_arena_hard.py index 1dd51ff..aa184a9 100644 --- a/tests/test_generate_and_evaluate_arena_hard.py +++ b/tests/test_generate_and_evaluate_arena_hard.py @@ -2,15 +2,18 @@ import pytest from judgearena.generate_and_evaluate import ( + ALPACA_EVAL_BASELINES, + PAIRWISE_BASELINES, BaselinePlan, CliArgs, _resolve_baseline_plan, + native_pairwise_baseline, ) def _make_args(dataset, model_b=None): return CliArgs( - dataset=dataset, + task=dataset, model_A="A", model_B=model_b, judge_model="J", @@ -46,6 +49,38 @@ def test_resolve_plan_v20_routes_per_category(): assert plan.baseline_by_index.loc["qc"] == "gemini-2.0-flash-001" +def test_resolve_plan_alpaca_eval_uses_native_baseline(): + plan = _resolve_baseline_plan( + args=_make_args("alpaca-eval"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "gpt4_1106_preview" + + +def test_native_pairwise_baseline_mapping_covers_flat_tasks(): + assert ALPACA_EVAL_BASELINES == {"alpaca-eval": "gpt4_1106_preview"} + assert PAIRWISE_BASELINES["alpaca-eval"] == "gpt4_1106_preview" + assert native_pairwise_baseline("alpaca-eval") == "gpt4_1106_preview" + assert native_pairwise_baseline("mt-bench") == "gpt-4" + + +def test_resolve_plan_m_arena_hard_uses_native_baseline(): + plan = _resolve_baseline_plan( + args=_make_args("m-arena-hard-v2.0-EU"), + instructions_df=_instructions(["q1", "q2"]), + ) + assert plan.is_flat + assert plan.single_model == "google/gemini-2.5-flash" + + +def test_native_pairwise_baseline_resolves_m_arena_hard_variants(): + assert ( + native_pairwise_baseline("m-arena-hard-v0.1-uk") == "CohereLabs/aya-expanse-8b" + ) + assert native_pairwise_baseline("m-arena-hard-v2.0-EU") == "google/gemini-2.5-flash" + + def test_resolve_plan_explicit_model_b_overrides_native(): plan = _resolve_baseline_plan( args=_make_args("arena-hard-v2.0", model_b="override"), @@ -57,10 +92,10 @@ def test_resolve_plan_explicit_model_b_overrides_native(): assert plan.single_model == "override" -def test_resolve_plan_non_arena_hard_requires_model_b(): +def test_resolve_plan_task_without_native_baseline_requires_model_b(): with pytest.raises(ValueError, match="model_B"): _resolve_baseline_plan( - args=_make_args("alpaca-eval"), + args=_make_args("fluency-french"), instructions_df=_instructions(["q1"]), ) diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 377b435..3c50429 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -24,7 +24,7 @@ def _mt_bench_args( ``effective_judge_*`` fallback helpers instead of a duplicate shim. """ args = BaseCliArgs(**base_overrides) - args.dataset = dataset + args.task = dataset args.model_A = model_A args.model_B = model_B args.use_tqdm = use_tqdm @@ -155,6 +155,7 @@ def fake_generate_multiturn( usage_tracker, usage_phase, limit_event_tracker, + strip_thinking_in_turn_1_carryover, **engine_kwargs, ): generated_models.append(model) @@ -194,6 +195,7 @@ def fake_generate_multiturn( max_model_len=16384, chat_template=None, battle_thinking_token_budget=None, + strip_thinking_in_turn_1_carryover=True, engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, ) diff --git a/tests/test_openrouter_reference_pricing.py b/tests/test_openrouter_reference_pricing.py index 1dbe0c2..c26ce48 100644 --- a/tests/test_openrouter_reference_pricing.py +++ b/tests/test_openrouter_reference_pricing.py @@ -47,6 +47,127 @@ def test_do_inference_records_token_usage(): ] +class _FakeAIMessage: + """Minimal AIMessage stand-in: .content + langchain-core .usage_metadata.""" + + def __init__(self, content: str, usage_metadata: dict[str, int] | None) -> None: + self.content = content + self.usage_metadata = usage_metadata + self.response_metadata: dict[str, object] = {} + + +class _OpenRouterShapeModel: + """Mimics ChatOpenAI: returns AIMessage objects with usage_metadata, no + count_*_batch helpers (so the fallback tokeniser path is unavailable).""" + + def __init__(self, usages: list[dict[str, int] | None]) -> None: + self._usages = usages + + def batch(self, inputs, **invoke_kwargs): + return [ + _FakeAIMessage(content=f"output-{idx}", usage_metadata=self._usages[idx]) + for idx, _ in enumerate(inputs) + ] + + +def test_do_inference_records_usage_metadata_for_openrouter_shape_models(): + """For OpenRouter ChatOpenAI calls, the tracker must record API-reported + token counts pulled from AIMessage.usage_metadata; ``record_batch_from_model`` + is a no-op because ChatOpenAI lacks ``count_*_batch`` helpers.""" + tracker = pricing.OpenRouterReferencePricingTracker() + model = _OpenRouterShapeModel( + usages=[ + {"input_tokens": 1234, "output_tokens": 17, "total_tokens": 1251}, + {"input_tokens": 800, "output_tokens": 42, "total_tokens": 842}, + ] + ) + + outputs = do_inference( + chat_model=model, + inputs=["prompt-1", "prompt-2"], + usage_tracker=tracker, + usage_phase="judge", + usage_model_spec="OpenRouter/google/gemma-4-31b-it", + ) + + assert outputs == ["output-0", "output-1"] + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=1234, + completion_tokens=17, + requests=1, + ), + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=800, + completion_tokens=42, + requests=1, + ), + ] + + +def test_do_inference_falls_back_to_count_batch_when_usage_metadata_missing(): + """When AIMessage results carry no ``usage_metadata`` (e.g. local vLLM path), + ``do_inference`` must fall back to ``record_batch_from_model`` and use the + chat_model's ``count_*_batch`` helpers.""" + tracker = pricing.OpenRouterReferencePricingTracker() + model = CountingModel() + + do_inference( + chat_model=model, + inputs=["abc", "de"], + usage_tracker=tracker, + usage_phase="generation_model_A", + usage_model_spec="VLLM/org/model", + ) + + assert [(r.prompt_tokens, r.completion_tokens) for r in tracker.records] == [ + (3, 8), + (2, 8), + ] + + +def test_record_batch_from_usage_metadata_returns_false_on_all_none(): + """A batch where every entry is ``None`` must signal "no records added" so + the caller can fall through to the tokeniser-based path.""" + tracker = pricing.OpenRouterReferencePricingTracker() + + recorded = tracker.record_batch_from_usage_metadata( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + usages=[None, None, None], + ) + + assert recorded is False + assert tracker.records == [] + + +def test_record_batch_from_usage_metadata_accepts_openai_shape_keys(): + """OpenAI-shape keys (``prompt_tokens``/``completion_tokens``) appear on + ``response_metadata.token_usage`` for older langchain-openai versions.""" + tracker = pricing.OpenRouterReferencePricingTracker() + + recorded = tracker.record_batch_from_usage_metadata( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + usages=[{"prompt_tokens": 100, "completion_tokens": 25, "total_tokens": 125}], + ) + + assert recorded is True + assert tracker.records == [ + pricing.TokenUsageRecord( + phase="judge", + model_spec="OpenRouter/google/gemma-4-31b-it", + prompt_tokens=100, + completion_tokens=25, + requests=1, + ) + ] + + def test_build_reference_pricing_summary_uses_exact_match_and_reports_partial_cost( monkeypatch, ): diff --git a/tests/test_strip_thinking_carryover.py b/tests/test_strip_thinking_carryover.py new file mode 100644 index 0000000..f33abed --- /dev/null +++ b/tests/test_strip_thinking_carryover.py @@ -0,0 +1,185 @@ +"""Tests for the strip-thinking-before-char-cap fix applied to turn-1 +answers when constructing the turn-2 prompt in MT-Bench generation. + +Background: ``truncate_all_input_chars`` fires before the chat template +renders the turn-2 prompt. If the turn-1 answer contains a +``...`` block (Qwen3.5, SmolLM3 thinking mode) or a vLLM +forced-budget closer, a char cap landing inside the reasoning span would +destroy the ```` tag and force the full reasoning fragment into the +turn-2 context. Stripping the visible reasoning span first mirrors what +the Qwen3 chat template does natively for historical assistant turns and +keeps the cap on the visible answer. + +These tests pin down the composition that +``judgearena.generate.generate_multiturn`` performs: +``strip_thinking_tags_with_metadata`` -> ``truncate_with_metadata``. +""" + +from __future__ import annotations + +from dataclasses import replace + +from judgearena.cli_common import BaseCliArgs +from judgearena.generate_and_evaluate import CliArgs +from judgearena.mt_bench.mt_bench_utils import _mt_bench_generation_cache_name +from judgearena.utils import ( + VLLM_REASONING_END_STR, + strip_thinking_tags_with_metadata, + truncate_with_metadata, +) + + +def _strip_then_cap( + answer: str, cap: int, *, strip: bool = True +) -> tuple[str, bool, bool]: + """Reproduce the exact sequence inside ``generate_multiturn``'s turn-2 loop.""" + if strip: + stripped_text, thinking_stripped = strip_thinking_tags_with_metadata(answer) + else: + stripped_text, thinking_stripped = answer, False + truncated, was_truncated = truncate_with_metadata(stripped_text, max_len=cap) + return truncated, was_truncated, thinking_stripped + + +def test_well_formed_think_block_is_stripped_before_cap(): + """Nominal Qwen3.5 case: a complete ``...`` wrapper sits + in front of the visible answer. Stripping removes the whole span; the + char cap then applies to the visible answer only.""" + reasoning = "so let me think through this... " * 400 # ~12K chars + visible = "The capital of France is Paris." + answer = f"{reasoning}\n\n{visible}" + # Cap below the reasoning length but above the visible answer length + # so the old behaviour would have clipped inside . + cap = 1024 + + truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap) + + assert thinking_stripped is True + assert was_truncated is False + assert truncated == visible + assert "" not in truncated + assert "" not in truncated + + +def test_vllm_forced_thinking_budget_closer_is_stripped(): + """When the thinking budget is exhausted, vLLM inserts a forced closer + (``VLLM_REASONING_END_STR``) without a paired ```` opener. The + strip helper treats everything up to and including that marker as + reasoning and drops it before the cap fires.""" + forced_reasoning = "step 1... " * 500 # ~5K chars of runaway thought + visible = "Final answer: 42." + answer = f"{forced_reasoning}{VLLM_REASONING_END_STR}{visible}" + + truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=256) + + assert thinking_stripped is True + assert was_truncated is False + assert truncated == visible + + +def test_dangling_closing_tag_is_stripped(): + """Qwen3.5 sometimes emits ```` without a preceding ```` + opener (e.g. when the opener was chopped off during generation rollover). + The strip helper drops the preamble up to ```` and keeps the + postamble. Without this, the cap would land inside the dangling + preamble and the ```` closer would survive in the turn-2 + context, confusing the chat template.""" + preamble = "leftover reasoning fragment " * 100 + visible = "Answer: yes." + answer = f"{preamble}\n{visible}" + + truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=512) + + assert thinking_stripped is True + assert was_truncated is False + assert truncated == visible + + +def test_no_thinking_tags_passthrough(): + """Non-thinking models (e.g. EuroLLM, Apertus) produce answers without + any ```` markers. Strip is a no-op; the cap behaves exactly as + before the fix.""" + visible = "Paris is the capital of France. " * 50 # ~1.6K chars + cap = 512 + + truncated, was_truncated, thinking_stripped = _strip_then_cap(visible, cap) + + assert thinking_stripped is False + assert was_truncated is True + assert truncated == visible[:cap] + + +def test_unclosed_think_block_is_unfixable_by_stripping(): + """Pathological case: the model writes ```` and hits the + generation limit before emitting ````. No ```` tag or + vLLM closer appears anywhere in the output, so the strip helper + returns the text unchanged and the cap still clips inside the + reasoning span. Stripping cannot fix this; the escape hatch is a + larger ``battle_thinking_token_budget``.""" + reasoning = "still reasoning " * 1000 + answer = f"{reasoning}" + + truncated, was_truncated, thinking_stripped = _strip_then_cap(answer, cap=256) + + assert thinking_stripped is False + assert was_truncated is True + assert truncated.startswith("") + + +def test_strip_disabled_reverts_to_pre_fix_behaviour(): + """With ``strip_thinking_in_turn_1_carryover=False`` (the pre-fix + behaviour, kept as a reproduction knob), the cap clips inside the + ```` block and the ```` closer is lost.""" + reasoning = "deep thinking " * 400 + visible = "Short answer." + answer = f"{reasoning}\n{visible}" + + truncated, was_truncated, thinking_stripped = _strip_then_cap( + answer, cap=1024, strip=False + ) + + assert thinking_stripped is False + assert was_truncated is True + assert truncated.startswith("") + assert "" not in truncated + + +def test_default_flag_is_enabled_in_base_cli_args(): + """Guard the default value: the fix ships enabled so existing runs + (including Phase A of the Gemma-4 benchmark) pick it up without a + launcher change.""" + args = BaseCliArgs(judge_model="OpenRouter/google/gemma-4-31b-it") + assert args.strip_thinking_in_turn_1_carryover is True + + +def _make_mt_bench_cli_args(**overrides) -> CliArgs: + args = CliArgs( + judge_model="OpenRouter/google/gemma-4-31b-it", + dataset="mt-bench", + model_A="VLLM/Qwen/Qwen3.5-9B", + model_B="VLLM/Qwen/Qwen3.5-9B", + n_instructions=3, + truncate_all_input_chars=30000, + max_out_tokens_models=49152, + max_model_len=57344, + battle_thinking_token_budget=32768, + ) + return replace(args, **overrides) if overrides else args + + +def test_mt_bench_cache_key_changes_when_flag_flipped(): + """The flag participates in the MT-Bench generation cache key so that + flipping it off to reproduce pre-fix behaviour does not silently reuse + post-fix completions (and vice versa). Without this, a rerun with the + same numeric knobs would reuse stale cache for multi-turn datasets.""" + args_on = _make_mt_bench_cli_args(strip_thinking_in_turn_1_carryover=True) + args_off = _make_mt_bench_cli_args(strip_thinking_in_turn_1_carryover=False) + + key_on = _mt_bench_generation_cache_name(args_on, model_name="VLLM/Qwen/Qwen3.5-9B") + key_off = _mt_bench_generation_cache_name( + args_off, model_name="VLLM/Qwen/Qwen3.5-9B" + ) + + assert key_on != key_off + assert key_on.startswith("mt-bench_VLLM/Qwen/Qwen3.5-9B_3_") + assert key_off.startswith("mt-bench_VLLM/Qwen/Qwen3.5-9B_3_") diff --git a/tests/test_utils.py b/tests/test_utils.py index 508263c..aea4990 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,9 @@ +import asyncio from types import SimpleNamespace from unittest.mock import MagicMock +import pytest + import judgearena.utils as utils from judgearena.utils import make_model @@ -56,6 +59,79 @@ async def fake_ainvoke(_input, **_kwargs): ] +def _build_inflight_tracking_chat_model(*, hold_seconds: float = 0.05): + """Helper: mock chat model whose `ainvoke` records peak concurrent in-flight calls.""" + + state = {"in_flight": 0, "peak": 0} + + async def fake_ainvoke(input_item, **_kwargs): + state["in_flight"] += 1 + state["peak"] = max(state["peak"], state["in_flight"]) + try: + await asyncio.sleep(hold_seconds) + return SimpleNamespace( + content=f"out-{input_item}", + response_metadata={"finish_reason": "stop"}, + ) + finally: + state["in_flight"] -= 1 + + return SimpleNamespace(ainvoke=fake_ainvoke), state + + +def test_do_inference_async_path_respects_concurrency_cap(monkeypatch): + """With JUDGEARENA_JUDGE_MAX_CONCURRENCY=4 and 16 inputs, peak in-flight must stay <= 4.""" + monkeypatch.setenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", "4") + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] <= 4, ( + f"Concurrency cap violated: peak in-flight={state['peak']}, expected <= 4" + ) + assert state["peak"] >= 1 + + +def test_do_inference_async_path_unbounded_when_env_unset(monkeypatch): + """Without JUDGEARENA_JUDGE_MAX_CONCURRENCY set, all 16 calls fire concurrently.""" + monkeypatch.delenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", raising=False) + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] > 4, ( + f"Expected unbounded concurrency to overshoot the capped variant; got peak={state['peak']}" + ) + + +def test_do_inference_async_path_zero_cap_is_unbounded(monkeypatch): + """JUDGEARENA_JUDGE_MAX_CONCURRENCY=0 falls back to unbounded (defensive default).""" + monkeypatch.setenv("JUDGEARENA_JUDGE_MAX_CONCURRENCY", "0") + chat_model, state = _build_inflight_tracking_chat_model() + + inputs = [f"prompt-{i}" for i in range(16)] + results = utils.do_inference( + chat_model=chat_model, + inputs=inputs, + use_tqdm=True, + ) + + assert len(results) == 16 + assert state["peak"] > 4 + + def test_do_inference_batch_path_propagates_finish_reason_without_batch_with_metadata(): batch_results = [ SimpleNamespace( @@ -146,3 +222,75 @@ def test_make_model_openrouter_strips_vllm_only_kwargs(monkeypatch): assert "chat_template" not in model.model_kwargs assert model.max_tokens == 16 assert model.temperature == 0.5 + + +def test_init_llm_with_retry_recovers_from_transient_cuda_error(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 3) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + monkeypatch.setattr(utils.time, "sleep", lambda *_a, **_k: None) + + calls: list[dict] = [] + + def fake_llm(**kwargs): + calls.append(kwargs) + if len(calls) < 3: + raise RuntimeError( + "CUDA error: CUDA-capable device(s) is/are busy or unavailable\n" + "Search for 'cudaErrorDevicesUnavailable' ..." + ) + return "llm" + + result = utils._init_llm_with_retry(fake_llm, model="m", trust_remote_code=True) + assert result == "llm" + assert len(calls) == 3 + + +def test_init_llm_with_retry_gives_up_after_max_attempts(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 2) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + monkeypatch.setattr(utils.time, "sleep", lambda *_a, **_k: None) + + def always_fails(**_kwargs): + raise RuntimeError("cudaErrorDevicesUnavailable") + + with pytest.raises(RuntimeError, match="cudaErrorDevicesUnavailable"): + utils._init_llm_with_retry(always_fails, model="m") + + +def test_init_llm_with_retry_reraises_non_matching_errors_immediately(monkeypatch): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 4) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + + call_count = 0 + + def fails_once(**_kwargs): + nonlocal call_count + call_count += 1 + raise ValueError("bad config") + + with pytest.raises(ValueError, match="bad config"): + utils._init_llm_with_retry(fails_once, model="m") + assert call_count == 1 + + +@pytest.mark.parametrize( + "message", + [ + "CUDA error: unknown error", + "NCCL error", + ], +) +def test_init_llm_with_retry_does_not_retry_broad_runtime_errors(monkeypatch, message): + monkeypatch.setattr(utils, "_VLLM_INIT_MAX_ATTEMPTS", 4) + monkeypatch.setattr(utils, "_VLLM_INIT_BACKOFF_SECONDS", 0) + + call_count = 0 + + def fails_once(**_kwargs): + nonlocal call_count + call_count += 1 + raise RuntimeError(message) + + with pytest.raises(RuntimeError, match=message): + utils._init_llm_with_retry(fails_once, model="m") + assert call_count == 1 From 16dc5e1346392a52277677010a293d041e21404c Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 29 Apr 2026 14:39:09 +0200 Subject: [PATCH 27/28] Clean up judge argument handling Make MT-Bench judge budget floors warning-only, remove implicit judge argument fallbacks, collapse MT-Bench thinking stripping onto the judge stripping flag, and restore default prompt completion-label rendering. --- judgearena/cli.py | 2 - judgearena/cli_common.py | 60 +++---------------- judgearena/estimate_elo_ratings.py | 7 +-- judgearena/generate.py | 4 +- judgearena/generate_and_evaluate.py | 7 +-- judgearena/judge_prompt_presets.py | 9 ++- judgearena/mt_bench/mt_bench_utils.py | 43 ++++++------- .../prompts/prompt-with-explanation.txt | 8 +-- judgearena/prompts/prompt.txt | 8 +-- tests/test_local_completion_loading.py | 30 +++++++--- tests/test_mt_bench_downloads.py | 37 +++++++----- tests/test_strip_thinking_carryover.py | 24 ++------ 12 files changed, 101 insertions(+), 138 deletions(-) diff --git a/judgearena/cli.py b/judgearena/cli.py index d8f0408..c8c778c 100644 --- a/judgearena/cli.py +++ b/judgearena/cli.py @@ -199,7 +199,6 @@ def _build_elo_args( judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, - strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, @@ -237,7 +236,6 @@ def _build_generate_and_evaluate_args( judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, - strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 112cc58..258a378 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -27,7 +27,6 @@ class BaseCliArgs: judge_prompt_preset: str = "default" battle_thinking_token_budget: int | None = None strip_thinking_before_judging: bool = False - strip_thinking_in_turn_1_carryover: bool = True skip_judging: bool = False truncate_all_input_chars: int = 8192 truncate_judge_input_chars: int | None = None @@ -49,26 +48,6 @@ def __post_init__(self): f"Only {supported_modes} modes are supported but got {self.swap_mode}." ) - def effective_judge_truncation(self) -> int: - """Character cap applied to judge-side inputs (completions, reference, etc.). - - Falls back to the generation-side ``truncate_all_input_chars`` when a - dedicated judge cap is not configured. - """ - if self.truncate_judge_input_chars is not None: - return int(self.truncate_judge_input_chars) - return int(self.truncate_all_input_chars) - - def effective_judge_max_model_len(self) -> int | None: - """Total context window for the judge vLLM instance. - - Falls back to the generation-side ``max_model_len`` when a dedicated - judge context window is not configured. - """ - if self.max_judge_model_len is not None: - return int(self.max_judge_model_len) - return self.max_model_len - def parse_optional_bool(raw: str | None) -> bool: if raw is None: @@ -164,22 +143,6 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: "before sending them to the judge." ), ) - parser.add_argument( - "--strip_thinking_in_turn_1_carryover", - nargs="?", - const=True, - default=True, - type=parse_optional_bool, - help=( - "When building the turn-2 prompt for multi-turn datasets, strip " - "... blocks (or vLLM forced thinking-budget closers) " - "from the turn-1 answer before the character-level truncation fires. " - "Matches what the Qwen3 chat template does internally for historical " - "assistant turns and prevents the turn-1 char cap from landing inside " - "a block and silently destroying the closer. Enabled " - "by default." - ), - ) parser.add_argument( "--skip_judging", nargs="?", @@ -213,9 +176,7 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: default=8192, help=( "Character-level truncation applied to generation-side inputs: " - "truncates each instruction before model A/B generation. When " - "--truncate_judge_input_chars is not set, this value also caps the " - "judge-side inputs (completions, reference, etc.)." + "truncates each instruction before model A/B generation." ), ) parser.add_argument( @@ -225,10 +186,8 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: default=None, help=( "Character cap applied to judge-side inputs (completions, " - "reference, instruction) before judge evaluation. Falls back to " - "--truncate_all_input_chars when not specified. Set much higher " - "than the generation cap to avoid cutting model completions before " - "they reach the judge." + "reference, instruction) before judge evaluation. When omitted, " + "judge inputs are not character-truncated by this CLI setting." ), ) parser.add_argument( @@ -260,8 +219,7 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: "Optional total context window for the battle-generation VLLM " "instances (prompt + generation). Independent from " "--max_out_tokens_models/--max_out_tokens_judge, which only cap " - "generated tokens. When --max_judge_model_len is not set, this " - "value also sizes the judge instance." + "generated tokens." ), ) parser.add_argument( @@ -270,11 +228,11 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: required=False, default=None, help=( - "Optional total context window for the judge VLLM instance. Falls " - "back to --max_model_len when not specified. Set higher than the " - "battle model_len when the judge needs to see longer prompts " - "(e.g. long completions from both A and B) than the battle " - "generator can fit." + "Optional total context window for the judge VLLM instance. When " + "omitted, no judge max_model_len override is passed. Set higher " + "than the battle model_len when the judge needs to see longer " + "prompts (e.g. long completions from both A and B) than the " + "battle generator can fit." ), ) parser.add_argument( diff --git a/judgearena/estimate_elo_ratings.py b/judgearena/estimate_elo_ratings.py index c743ebf..15222fc 100644 --- a/judgearena/estimate_elo_ratings.py +++ b/judgearena/estimate_elo_ratings.py @@ -269,9 +269,8 @@ def replace_slash(s: str) -> str: ] judge_extra_kwargs = {} - effective_judge_max_model_len = args.effective_judge_max_model_len() - if effective_judge_max_model_len is not None: - judge_extra_kwargs["max_model_len"] = effective_judge_max_model_len + if args.max_judge_model_len is not None: + judge_extra_kwargs["max_model_len"] = args.max_judge_model_len if args.chat_template is not None: judge_extra_kwargs["chat_template"] = args.chat_template @@ -288,7 +287,7 @@ def run_judge() -> pd.DataFrame: completions_B=completions_B, swap_mode=args.swap_mode, provide_explanation=args.provide_explanation, - truncate_input_chars=args.effective_judge_truncation(), + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=use_tqdm, ) return pd.DataFrame( diff --git a/judgearena/generate.py b/judgearena/generate.py index a25e8ec..d33c865 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -193,7 +193,7 @@ def generate_multiturn( usage_tracker=None, usage_phase: str | None = None, limit_event_tracker: LimitEventTracker | None = None, - strip_thinking_in_turn_1_carryover: bool = True, + strip_thinking_before_turn_2_prompt: bool = False, **model_kwargs, ) -> pd.DataFrame: """Generate two-turn completions for MT-Bench style questions.""" @@ -324,7 +324,7 @@ def generate_multiturn( # destroy the closer and force the whole thinking # fragment into the turn-2 prompt. t1_answer_str = str(t1_answer) - if strip_thinking_in_turn_1_carryover: + if strip_thinking_before_turn_2_prompt: t1_answer_str, thinking_stripped = strip_thinking_tags_with_metadata( t1_answer_str ) diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index edec74a..7482865 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -174,7 +174,6 @@ def parse_args(cls): judge_prompt_preset=args.judge_prompt_preset, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, - strip_thinking_in_turn_1_carryover=args.strip_thinking_in_turn_1_carryover, skip_judging=args.skip_judging, truncate_all_input_chars=args.truncate_all_input_chars, truncate_judge_input_chars=args.truncate_judge_input_chars, @@ -539,7 +538,7 @@ def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Serie judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, - max_model_len=args.effective_judge_max_model_len(), + max_model_len=args.max_judge_model_len, chat_template=args.chat_template, **_build_judge_model_kwargs(args=args, limit_event_tracker=limit_event_tracker), ) @@ -578,14 +577,14 @@ def _load_or_generate_completions(model_spec: str, usage_phase: str) -> pd.Serie strip_thinking_before_judging=args.strip_thinking_before_judging, system_prompt=resolved_prompt.system_prompt, user_prompt_template=resolved_prompt.user_prompt_template, - truncate_input_chars=args.effective_judge_truncation(), + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=args.use_tqdm, usage_tracker=usage_tracker, usage_phase="judge", usage_model_spec=args.judge_model, limit_event_tracker=limit_event_tracker, judge_tokenizer=judge_tokenizer, - max_judge_model_len=args.effective_judge_max_model_len(), + max_judge_model_len=args.max_judge_model_len, max_out_tokens_judge=args.max_out_tokens_judge, ) diff --git a/judgearena/judge_prompt_presets.py b/judgearena/judge_prompt_presets.py index 7a14313..00c3ea3 100644 --- a/judgearena/judge_prompt_presets.py +++ b/judgearena/judge_prompt_presets.py @@ -14,6 +14,8 @@ ) _PROMPTS_DIR = Path(__file__).resolve().parent / "prompts" +_COMPLETION_LABEL_SINGLE = "Answer" +_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" _EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" _SCORE_FENCE = "\n```" @@ -54,9 +56,13 @@ class ResolvedJudgePrompt: def _render_user_prompt_template( - raw_template: str, *, provide_explanation: bool + raw_template: str, *, provide_explanation: bool, multi_turn: bool ) -> str: template = raw_template.replace( + "{completion_label}", + _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, + ) + template = template.replace( "{explanation_suffix}", _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, ) @@ -91,6 +97,7 @@ def resolve_pairwise_judge_prompt( default_user_prompt_template = _render_user_prompt_template( (_PROMPTS_DIR / prompt_filename).read_text(encoding="utf-8"), provide_explanation=provide_explanation, + multi_turn=multi_turn, ) return ResolvedJudgePrompt( preset_name=preset.name, diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 4b3a505..51d3b81 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -50,8 +50,7 @@ from judgearena.generate_and_evaluate import CliArgs -# Original MT-Bench prompts include a visible explanation before the final verdict, -# and Qwen can spend thousands of visible tokens after reasoning ends on turn 2. +# Original MT-Bench prompts include a visible explanation before the final verdict. _MIN_MT_BENCH_JUDGE_TOKENS = 24576 _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN = 28672 @@ -95,7 +94,7 @@ def _mt_bench_generation_cache_name(args: CliArgs, *, model_name: str) -> str: "max_model_len": args.max_model_len, "chat_template": args.chat_template, "battle_thinking_token_budget": args.battle_thinking_token_budget, - "strip_thinking_in_turn_1_carryover": args.strip_thinking_in_turn_1_carryover, + "strip_thinking_before_judging": args.strip_thinking_before_judging, "engine_kwargs": _build_mt_bench_generation_kwargs( args=args, model_spec=model_name ), @@ -143,9 +142,7 @@ def _run_generation(model_name: str, usage_phase: str) -> pd.DataFrame: usage_tracker=usage_tracker, usage_phase=usage_phase, limit_event_tracker=limit_event_tracker, - strip_thinking_in_turn_1_carryover=( - args.strip_thinking_in_turn_1_carryover - ), + strip_thinking_before_turn_2_prompt=args.strip_thinking_before_judging, **_build_mt_bench_generation_kwargs(args=args, model_spec=model_name), ) @@ -251,7 +248,7 @@ def _run_mt_bench_fastchat( model_b=args.model_B, turns_mode="both", swap_mode=args.swap_mode, - truncate_input_chars=args.effective_judge_truncation(), + truncate_input_chars=args.truncate_judge_input_chars, use_tqdm=args.use_tqdm, prompt_preset=prompt_preset, strip_thinking_before_judging=args.strip_thinking_before_judging, @@ -335,13 +332,13 @@ def run_mt_bench( res_folder = Path(args.result_folder) / result_name res_folder.mkdir(parents=True, exist_ok=True) if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: - logger.info( - "MT-Bench judge prompts require room for budgeted thinking, the " - "original explanation, and the final verdict; " - f"overriding max_out_tokens_judge from {args.max_out_tokens_judge} " - f"to {_MIN_MT_BENCH_JUDGE_TOKENS}." + logger.warning( + "MT-Bench judge prompts request an explanation before the final " + "verdict; max_out_tokens_judge=%s may be too small " + "(recommended >= %s).", + args.max_out_tokens_judge, + _MIN_MT_BENCH_JUDGE_TOKENS, ) - args.max_out_tokens_judge = _MIN_MT_BENCH_JUDGE_TOKENS questions_df = load_instructions("mt-bench", n_instructions=args.n_instructions) logger.info( "Generating multi-turn completions for MT-Bench with %s and %s.", @@ -379,24 +376,22 @@ def run_mt_bench( result_name, ) return None - effective_judge_max_model_len = args.effective_judge_max_model_len() if ( - effective_judge_max_model_len is not None - and effective_judge_max_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN + args.max_judge_model_len is not None + and args.max_judge_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN ): - logger.info( - "MT-Bench judge prompts require a larger total context window for " - "prompt plus completion; " - f"overriding judge max_model_len from {effective_judge_max_model_len} " - f"to {_MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN}." + logger.warning( + "MT-Bench judge prompts request an explanation before the final " + "verdict; max_judge_model_len=%s may be too small for prompt plus " + "completion (recommended >= %s).", + args.max_judge_model_len, + _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN, ) - args.max_judge_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN - effective_judge_max_model_len = _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN judge_chat_model = make_model( model=args.judge_model, max_tokens=args.max_out_tokens_judge, temperature=0.0, - max_model_len=effective_judge_max_model_len, + max_model_len=args.max_judge_model_len, chat_template=args.chat_template, **_build_mt_bench_judge_model_kwargs( args=args, limit_event_tracker=limit_event_tracker diff --git a/judgearena/prompts/prompt-with-explanation.txt b/judgearena/prompts/prompt-with-explanation.txt index 6600f51..3d9eb41 100644 --- a/judgearena/prompts/prompt-with-explanation.txt +++ b/judgearena/prompts/prompt-with-explanation.txt @@ -1,13 +1,13 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's Answer|> +<|The Start of Assistant A's {completion_label}|> {completion_A} -<|The End of Assistant A's Answer|> +<|The End of Assistant A's {completion_label}|> -<|The Start of Assistant B's Answer|> +<|The Start of Assistant B's {completion_label}|> {completion_B} -<|The End of Assistant B's Answer|> +<|The End of Assistant B's {completion_label}|> # Your output diff --git a/judgearena/prompts/prompt.txt b/judgearena/prompts/prompt.txt index 1b93858..38021e6 100644 --- a/judgearena/prompts/prompt.txt +++ b/judgearena/prompts/prompt.txt @@ -1,13 +1,13 @@ <|User Prompt|> {user_prompt} -<|The Start of Assistant A's Answer|> +<|The Start of Assistant A's {completion_label}|> {completion_A} -<|The End of Assistant A's Answer|> +<|The End of Assistant A's {completion_label}|> -<|The Start of Assistant B's Answer|> +<|The Start of Assistant B's {completion_label}|> {completion_B} -<|The End of Assistant B's Answer|> +<|The End of Assistant B's {completion_label}|> # Your output diff --git a/tests/test_local_completion_loading.py b/tests/test_local_completion_loading.py index c9b22e2..eef3916 100644 --- a/tests/test_local_completion_loading.py +++ b/tests/test_local_completion_loading.py @@ -31,6 +31,17 @@ def test_load_judge_prompt_with_explanation_uses_freeform_scores(): assert "Assistant B's Answer" in user_prompt +def test_load_judge_prompt_multi_turn_uses_conversation_label(): + _system_prompt, user_prompt = evaluate.load_judge_system_and_user_prompt( + provide_explanation=False, + multi_turn=True, + ) + + assert "Assistant A's Conversation with User" in user_prompt + assert "Assistant B's Conversation with User" in user_prompt + assert "Assistant A's Answer" not in user_prompt + + def test_parse_optional_bool_accepts_explicit_true_false_values(): assert parse_optional_bool(None) is True assert parse_optional_bool("true") is True @@ -99,7 +110,7 @@ def fake_judge_and_parse_prefs(**kwargs): prefs = main_generate_and_eval( CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", @@ -166,7 +177,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs prefs = main_generate_and_eval( CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", @@ -220,7 +231,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs prefs = main_generate_and_eval( CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="VLLM/meta-llama/Llama-3.3-70B-Instruct", @@ -238,7 +249,7 @@ def fake_make_model(*, model, max_tokens, max_model_len, chat_template, **kwargs assert captured["make_model"]["thinking_token_budget"] == 2048 -def test_annotate_battles_warns_when_judge_inputs_are_truncated(monkeypatch, capsys): +def test_annotate_battles_warns_when_judge_inputs_are_truncated(monkeypatch, caplog): captured = {} def fake_do_inference( @@ -254,6 +265,7 @@ def fake_do_inference( return ["score_A: 0\nscore_B: 10"] monkeypatch.setattr(evaluate, "do_inference", fake_do_inference) + caplog.set_level("WARNING", logger=evaluate.__name__) annotations = evaluate.annotate_battles( judge_chat_model=object(), @@ -263,9 +275,9 @@ def fake_do_inference( truncate_input_chars=3, ) - stdout = capsys.readouterr().out assert ( - "Warning: truncated 2 judge inputs to 3 characters before evaluation." in stdout + "Warning: truncated 2 judge inputs to 3 characters before evaluation." + in caplog.text ) assert "Ans" in captured["judge_prompt"] assert "Answer A" not in captured["judge_prompt"] @@ -407,7 +419,7 @@ def fake_judge_and_parse_prefs(**kwargs): prefs = main_generate_and_eval( CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="VLLM/Qwen/Qwen3.5-27B-FP8", model_B="VLLM/allenai/Olmo-3-7B-Instruct", judge_model="Dummy/judge", @@ -434,7 +446,7 @@ def fake_judge_and_parse_prefs(**kwargs): def test_generation_cache_name_changes_with_generation_settings(): args = CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="Dummy/judge", @@ -443,7 +455,7 @@ def test_generation_cache_name_changes_with_generation_settings(): battle_thinking_token_budget=256, ) changed_args = CliArgs( - dataset="alpaca-eval", + task="alpaca-eval", model_A="Dummy/model-a", model_B="Dummy/model-b", judge_model="Dummy/judge", diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 3c50429..7097204 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -20,8 +20,8 @@ def _mt_bench_args( ) -> BaseCliArgs: """Construct a ``BaseCliArgs`` with MT-Bench CLI-style extras attached. - Using the real dataclass here ensures tests exercise the production - ``effective_judge_*`` fallback helpers instead of a duplicate shim. + Using the real dataclass here keeps tests close to the production CLI + contract while attaching the task/model fields owned by ``CliArgs``. """ args = BaseCliArgs(**base_overrides) args.task = dataset @@ -137,6 +137,7 @@ def test_generate_mt_bench_completions_uses_pregenerated_baseline(monkeypatch): ) generated_models = [] generation_kwargs = [] + generation_strip_flags = [] monkeypatch.setattr( mt_bench_utils, "cache_function_dataframe", lambda fun, **_kwargs: fun() @@ -155,11 +156,12 @@ def fake_generate_multiturn( usage_tracker, usage_phase, limit_event_tracker, - strip_thinking_in_turn_1_carryover, + strip_thinking_before_turn_2_prompt, **engine_kwargs, ): generated_models.append(model) generation_kwargs.append(engine_kwargs) + generation_strip_flags.append(strip_thinking_before_turn_2_prompt) return pd.DataFrame( { "instruction_index": [1, 2], @@ -195,7 +197,7 @@ def fake_generate_multiturn( max_model_len=16384, chat_template=None, battle_thinking_token_budget=None, - strip_thinking_in_turn_1_carryover=True, + strip_thinking_before_judging=True, engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, ) @@ -211,6 +213,7 @@ def fake_generate_multiturn( assert generation_kwargs == [ {"gpu_memory_utilization": 0.7, "language_model_only": True} ] + assert generation_strip_flags == [True] assert completions_a.loc[1, "completion_turn_1"] == "Gen A1" assert completions_b.loc[1, "completion_turn_1"] == "Base A1" assert completions_b.loc[2, "completion_turn_2"] == "Base B2" @@ -258,7 +261,7 @@ def test_conservative_winner_marks_one_sided_parse_failures_as_error(): assert fastchat_compat._conservative_winner("model_A", "model_B") == ("tie", True) -def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch): +def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch, caplog): questions_df = pd.DataFrame( {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, index=pd.Index([1], name="instruction_index"), @@ -338,16 +341,15 @@ def fake_run_mt_bench_fastchat(**kwargs): engine_kwargs={"gpu_memory_utilization": 0.7, "language_model_only": True}, ) + caplog.set_level("WARNING", logger=mt_bench_utils.__name__) mt_bench_utils.run_mt_bench(args, ignore_cache=False) assert args.swap_mode == "fixed" - assert args.max_out_tokens_judge == 24576 + assert args.max_out_tokens_judge == 256 assert args.max_model_len == 16384 - assert args.max_judge_model_len == 28672 - assert args.effective_judge_max_model_len() == 28672 - assert args.effective_judge_truncation() == 8192 - assert captured["make_model"]["max_tokens"] == 24576 - assert captured["make_model"]["max_model_len"] == 28672 + assert args.max_judge_model_len is None + assert captured["make_model"]["max_tokens"] == 256 + assert captured["make_model"]["max_model_len"] is None assert captured["make_model"]["kwargs"] == { "gpu_memory_utilization": 0.7, "language_model_only": True, @@ -362,6 +364,11 @@ def fake_run_mt_bench_fastchat(**kwargs): assert ( captured["run_mt_bench_fastchat"]["args"].strip_thinking_before_judging is False ) + assert ( + "MT-Bench judge prompts request an explanation before the final verdict" + in caplog.text + ) + assert "max_out_tokens_judge=256" in caplog.text def test_select_prompt_supports_optional_skywork_mt_bench_preset(): @@ -444,7 +451,7 @@ def fake_run_mt_bench_fastchat(**kwargs): assert captured["kwargs"]["prompt_preset"] == SKYWORK_JUDGE_PROMPT_PRESET assert captured["kwargs"]["args"].strip_thinking_before_judging is True - assert args.effective_judge_max_model_len() == 65536 - assert args.effective_judge_truncation() == 80000 - assert captured["kwargs"]["args"].effective_judge_truncation() == 80000 - assert captured["kwargs"]["args"].effective_judge_max_model_len() == 65536 + assert args.max_judge_model_len == 65536 + assert args.truncate_judge_input_chars == 80000 + assert captured["kwargs"]["args"].truncate_judge_input_chars == 80000 + assert captured["kwargs"]["args"].max_judge_model_len == 65536 diff --git a/tests/test_strip_thinking_carryover.py b/tests/test_strip_thinking_carryover.py index f33abed..7506313 100644 --- a/tests/test_strip_thinking_carryover.py +++ b/tests/test_strip_thinking_carryover.py @@ -19,7 +19,6 @@ from dataclasses import replace -from judgearena.cli_common import BaseCliArgs from judgearena.generate_and_evaluate import CliArgs from judgearena.mt_bench.mt_bench_utils import _mt_bench_generation_cache_name from judgearena.utils import ( @@ -127,8 +126,7 @@ def test_unclosed_think_block_is_unfixable_by_stripping(): def test_strip_disabled_reverts_to_pre_fix_behaviour(): - """With ``strip_thinking_in_turn_1_carryover=False`` (the pre-fix - behaviour, kept as a reproduction knob), the cap clips inside the + """When turn-1 carryover stripping is disabled, the cap clips inside the ```` block and the ```` closer is lost.""" reasoning = "deep thinking " * 400 visible = "Short answer." @@ -144,18 +142,10 @@ def test_strip_disabled_reverts_to_pre_fix_behaviour(): assert "" not in truncated -def test_default_flag_is_enabled_in_base_cli_args(): - """Guard the default value: the fix ships enabled so existing runs - (including Phase A of the Gemma-4 benchmark) pick it up without a - launcher change.""" - args = BaseCliArgs(judge_model="OpenRouter/google/gemma-4-31b-it") - assert args.strip_thinking_in_turn_1_carryover is True - - def _make_mt_bench_cli_args(**overrides) -> CliArgs: args = CliArgs( judge_model="OpenRouter/google/gemma-4-31b-it", - dataset="mt-bench", + task="mt-bench", model_A="VLLM/Qwen/Qwen3.5-9B", model_B="VLLM/Qwen/Qwen3.5-9B", n_instructions=3, @@ -168,12 +158,10 @@ def _make_mt_bench_cli_args(**overrides) -> CliArgs: def test_mt_bench_cache_key_changes_when_flag_flipped(): - """The flag participates in the MT-Bench generation cache key so that - flipping it off to reproduce pre-fix behaviour does not silently reuse - post-fix completions (and vice versa). Without this, a rerun with the - same numeric knobs would reuse stale cache for multi-turn datasets.""" - args_on = _make_mt_bench_cli_args(strip_thinking_in_turn_1_carryover=True) - args_off = _make_mt_bench_cli_args(strip_thinking_in_turn_1_carryover=False) + """Judge-side thinking stripping now also controls MT-Bench turn-1 + carryover stripping, so it must participate in the generation cache key.""" + args_on = _make_mt_bench_cli_args(strip_thinking_before_judging=True) + args_off = _make_mt_bench_cli_args(strip_thinking_before_judging=False) key_on = _mt_bench_generation_cache_name(args_on, model_name="VLLM/Qwen/Qwen3.5-9B") key_off = _mt_bench_generation_cache_name( From 5411ff8c15c06e0aec6b1d21e3d38df72f1481f7 Mon Sep 17 00:00:00 2001 From: ErlisLushtaku Date: Wed, 29 Apr 2026 17:18:55 +0200 Subject: [PATCH 28/28] Add default score-based verdict mode for fastchat --- judgearena/cli.py | 2 + judgearena/cli_common.py | 20 + judgearena/generate_and_evaluate.py | 1 + judgearena/mt_bench/common.py | 7 + judgearena/mt_bench/fastchat_compat.py | 48 +- judgearena/mt_bench/mt_bench_utils.py | 178 ++++++- judgearena/mt_bench/preset_judging.py | 645 ++++++++++++++++++++++++ judgearena/mt_bench/prompt_templates.py | 31 ++ tests/test_cli.py | 36 ++ tests/test_mt_bench_downloads.py | 384 +++++++++++++- 10 files changed, 1281 insertions(+), 71 deletions(-) create mode 100644 judgearena/mt_bench/preset_judging.py create mode 100644 judgearena/mt_bench/prompt_templates.py diff --git a/judgearena/cli.py b/judgearena/cli.py index c8c778c..649fe9e 100644 --- a/judgearena/cli.py +++ b/judgearena/cli.py @@ -197,6 +197,7 @@ def _build_elo_args( swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, skip_judging=args.skip_judging, @@ -234,6 +235,7 @@ def _build_generate_and_evaluate_args( swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, skip_judging=args.skip_judging, diff --git a/judgearena/cli_common.py b/judgearena/cli_common.py index 258a378..6cfe627 100644 --- a/judgearena/cli_common.py +++ b/judgearena/cli_common.py @@ -13,6 +13,8 @@ from judgearena.judge_prompt_presets import JUDGE_PROMPT_PRESETS +MT_BENCH_JUDGE_MODES = ("default", "fastchat_original") + @dataclass class BaseCliArgs: @@ -25,6 +27,7 @@ class BaseCliArgs: swap_mode: str = "fixed" ignore_cache: bool = False judge_prompt_preset: str = "default" + mt_bench_judge_mode: str = "default" battle_thinking_token_budget: int | None = None strip_thinking_before_judging: bool = False skip_judging: bool = False @@ -47,6 +50,11 @@ def __post_init__(self): assert self.swap_mode in supported_modes, ( f"Only {supported_modes} modes are supported but got {self.swap_mode}." ) + assert self.mt_bench_judge_mode in MT_BENCH_JUDGE_MODES, ( + "Only " + f"{list(MT_BENCH_JUDGE_MODES)} MT-Bench judge modes are supported but " + f"got {self.mt_bench_judge_mode!r}." + ) def parse_optional_bool(raw: str | None) -> bool: @@ -122,6 +130,18 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: "verdict-first preset." ), ) + parser.add_argument( + "--mt_bench_judge_mode", + type=str, + choices=MT_BENCH_JUDGE_MODES, + default="default", + help=( + "MT-Bench-only judging mode. 'default' makes MT-Bench obey " + "--judge_prompt_preset like the other benchmarks, while " + "'fastchat_original' preserves the original FastChat-style " + "prompting and [[A]]/[[B]]/[[C]] verdict parsing." + ), + ) parser.add_argument( "--battle_thinking_token_budget", type=int, diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 7482865..5a0b49b 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -172,6 +172,7 @@ def parse_args(cls): swap_mode=args.swap_mode, ignore_cache=args.ignore_cache, judge_prompt_preset=args.judge_prompt_preset, + mt_bench_judge_mode=args.mt_bench_judge_mode, battle_thinking_token_budget=args.battle_thinking_token_budget, strip_thinking_before_judging=args.strip_thinking_before_judging, skip_judging=args.skip_judging, diff --git a/judgearena/mt_bench/common.py b/judgearena/mt_bench/common.py index 51b0963..9c5b095 100644 --- a/judgearena/mt_bench/common.py +++ b/judgearena/mt_bench/common.py @@ -7,6 +7,13 @@ from judgearena.utils import safe_text_with_metadata +MT_BENCH_REFERENCE_CATEGORIES: set[str] = { + "math", + "reasoning", + "coding", + "arena-hard-200", +} + @dataclass(frozen=True) class MTBenchPairwiseRow: diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py index dcaec19..d7c4c84 100644 --- a/judgearena/mt_bench/fastchat_compat.py +++ b/judgearena/mt_bench/fastchat_compat.py @@ -4,7 +4,6 @@ import math from dataclasses import dataclass -from pathlib import Path from typing import Any, Literal import pandas as pd @@ -14,7 +13,14 @@ DEFAULT_JUDGE_PROMPT_PRESET, SKYWORK_JUDGE_PROMPT_PRESET, ) -from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows +from judgearena.mt_bench.common import ( + MT_BENCH_REFERENCE_CATEGORIES, + iter_mt_bench_pairwise_rows, +) +from judgearena.mt_bench.prompt_templates import ( + build_mt_bench_user_prompt_template, + render_mt_bench_prompt_text, +) from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker from judgearena.utils import ( LimitEventTracker, @@ -35,13 +41,6 @@ "arena-hard-200": 0.0, } -FASTCHAT_NEED_REF_CATS: set[str] = { - "math", - "reasoning", - "coding", - "arena-hard-200", -} - FastChatVerdict = Literal["A", "B", "tie", "error"] PairwiseWinner = Literal["model_A", "model_B", "tie", "error"] @@ -55,21 +54,7 @@ class FastChatPairwisePrompt: ref_based: bool -_PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" _SYSTEM_BASE_FILE = "system-base.txt" -_USER_SINGLE_BASE_FILE = "user-single-base.txt" -_USER_MULTI_BASE_FILE = "user-multi-base.txt" -_USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" -_USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" - - -def _load_prompt_text(filename: str) -> str: - path = _PROMPTS_DIR / filename - return path.read_text(encoding="utf-8") - - -def _render_prompt_text(filename: str, **kwargs: str) -> str: - return _load_prompt_text(filename).format(**kwargs) def _build_system_prompt( @@ -80,7 +65,7 @@ def _build_system_prompt( focus_line: str = "", ) -> str: focus_segment = f"{focus_line} " if focus_line else "" - return _render_prompt_text( + return render_mt_bench_prompt_text( _SYSTEM_BASE_FILE, user_subject=user_subject, task_description=task_description, @@ -89,17 +74,6 @@ def _build_system_prompt( ) -def _build_user_prompt_template(*, multi_turn: bool, ref_based: bool) -> str: - base_filename = _USER_MULTI_BASE_FILE if multi_turn else _USER_SINGLE_BASE_FILE - reference_block = "" - if ref_based: - ref_block_filename = ( - _USER_MULTI_REF_BLOCK_FILE if multi_turn else _USER_SINGLE_REF_BLOCK_FILE - ) - reference_block = _load_prompt_text(ref_block_filename).rstrip("\n") + "\n\n" - return _render_prompt_text(base_filename, reference_block=reference_block) - - def _load_pairwise_prompt( *, name: str, @@ -120,7 +94,7 @@ def _load_pairwise_prompt( begin_instruction=system_begin_instruction, focus_line=system_focus_line, ), - user_prompt_template=_build_user_prompt_template( + user_prompt_template=build_mt_bench_user_prompt_template( multi_turn=multi_turn, ref_based=ref_based, ), @@ -321,7 +295,7 @@ def _select_prompt( raise ValueError( f"Unsupported MT-Bench prompt preset '{prompt_preset}'. Choose from: {supported}." ) - needs_ref = (category or "") in FASTCHAT_NEED_REF_CATS + needs_ref = (category or "") in MT_BENCH_REFERENCE_CATEGORIES if needs_ref and multi_turn: return prompt_variants["multi_ref"] if needs_ref: diff --git a/judgearena/mt_bench/mt_bench_utils.py b/judgearena/mt_bench/mt_bench_utils.py index 51d3b81..67afefa 100644 --- a/judgearena/mt_bench/mt_bench_utils.py +++ b/judgearena/mt_bench/mt_bench_utils.py @@ -29,6 +29,7 @@ FASTCHAT_TEMPERATURE_CONFIG, judge_mt_bench_pairwise_fastchat, ) +from judgearena.mt_bench.preset_judging import judge_mt_bench_with_preset from judgearena.openrouter_reference_pricing import ( OpenRouterReferencePricingTracker, build_openrouter_reference_pricing_summary, @@ -182,6 +183,23 @@ def _build_mt_bench_result_name(args: CliArgs, suffix: str | None = None) -> str return name.replace("/", "_") +def _build_mt_bench_input_payloads( + *, + questions_df: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, +) -> dict[str, object]: + return { + "instruction_index": questions_df.index.tolist(), + "turn_1": questions_df["turn_1"].tolist(), + "turn_2": questions_df["turn_2"].tolist(), + "completion_turn_1_A": completions_a["completion_turn_1"].tolist(), + "completion_turn_2_A": completions_a["completion_turn_2"].tolist(), + "completion_turn_1_B": completions_b["completion_turn_1"].tolist(), + "completion_turn_2_B": completions_b["completion_turn_2"].tolist(), + } + + def _save_mt_bench_results( *, args: CliArgs, @@ -192,14 +210,14 @@ def _save_mt_bench_results( questions_df: pd.DataFrame, pricing_reference: dict[str, object] | None, started_at_utc: datetime, - name_suffix: str | None = None, + input_payloads: dict[str, object] | None = None, + judge_system_prompt: str | None = None, + judge_user_prompt_template: str | None = None, ) -> None: """Persist MT-Bench arguments, annotations, and aggregate results.""" - name = _build_mt_bench_result_name(args, suffix=name_suffix) - res_folder = Path(args.result_folder) / name res_folder.mkdir(parents=True, exist_ok=True) - with open(res_folder / f"args-{name}.json", "w") as f: + with open(res_folder / f"args-{result_name}.json", "w") as f: json.dump(_to_jsonable(asdict(args)), f, indent=2, allow_nan=False) annotations_df.to_csv(res_folder / f"{result_name}-annotations.csv", index=False) @@ -212,11 +230,14 @@ def _save_mt_bench_results( entrypoint="judgearena.mt_bench.mt_bench_utils.run_mt_bench", run=asdict(args), results=results, - input_payloads={ + input_payloads=input_payloads + or { "instruction_index": questions_df.index.tolist(), "turn_1": questions_df["turn_1"].tolist(), "turn_2": questions_df["turn_2"].tolist(), }, + judge_system_prompt=judge_system_prompt, + judge_user_prompt_template=judge_user_prompt_template, started_at_utc=started_at_utc, pricing_reference=pricing_reference, ) @@ -265,6 +286,7 @@ def _run_mt_bench_fastchat( "model_B": args.model_B, "judge_model": args.judge_model, "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, "strip_thinking_before_judging": args.strip_thinking_before_judging, "battle_thinking_token_budget": args.battle_thinking_token_budget, "num_inconsistent": num_inconsistent, @@ -297,7 +319,114 @@ def _run_mt_bench_fastchat( questions_df=questions_df, pricing_reference=pricing_reference, started_at_utc=started_at_utc, - name_suffix="mtbench", + input_payloads=_build_mt_bench_input_payloads( + questions_df=questions_df, + completions_a=completions_a, + completions_b=completions_b, + ), + ) + return prefs + + +def _run_mt_bench_preset( + *, + args: CliArgs, + res_folder: Path, + result_name: str, + questions_df: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + judge_chat_model, + prompt_preset: str, + usage_tracker: OpenRouterReferencePricingTracker, + limit_event_tracker: LimitEventTracker | None, + started_at_utc: datetime, +) -> pd.Series: + prefs, annotations, combined_metadata, _num_inconsistent = ( + judge_mt_bench_with_preset( + judge_chat_model=judge_chat_model, + judge_model=args.judge_model, + questions=questions_df, + completions_a=completions_a, + completions_b=completions_b, + model_a=args.model_A, + model_b=args.model_B, + turns_mode="both", + swap_mode=args.swap_mode, + truncate_input_chars=args.truncate_judge_input_chars, + use_tqdm=args.use_tqdm, + prompt_preset=prompt_preset, + provide_explanation=args.provide_explanation, + strip_thinking_before_judging=args.strip_thinking_before_judging, + judge_tokenizer=getattr(judge_chat_model, "tokenizer", None), + max_judge_model_len=args.max_judge_model_len, + max_out_tokens_judge=args.max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase="judge", + limit_event_tracker=limit_event_tracker, + ) + ) + + stats = compute_pref_summary(prefs) + results = { + "task": args.task, + "model_A": args.model_A, + "model_B": args.model_B, + "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, + "strip_thinking_before_judging": args.strip_thinking_before_judging, + "battle_thinking_token_budget": args.battle_thinking_token_budget, + **stats, + "limit_events": limit_event_tracker.build_summary() + if limit_event_tracker is not None + else {}, + "per_category": _compute_grouped_stats(prefs, combined_metadata, "category"), + "per_turn": _compute_grouped_stats(prefs, combined_metadata, "turn"), + "preferences": prefs.tolist(), + "date": str(datetime.now().isoformat()), + "user": os.getenv("USER", ""), + } + print_results(results) + pricing_reference = build_openrouter_reference_pricing_summary( + tracker=usage_tracker, + phase_model_specs={ + "generation_model_A": args.model_A, + "generation_model_B": args.model_B, + "judge": args.judge_model, + }, + ) + print(format_openrouter_reference_pricing_summary(pricing_reference)) + unique_system_prompts = { + row.get("system_prompt") + for row in annotations + if row.get("system_prompt") is not None + } + unique_user_templates = { + row.get("user_prompt_template") + for row in annotations + if row.get("user_prompt_template") is not None + } + _save_mt_bench_results( + args=args, + res_folder=res_folder, + result_name=result_name, + results=results, + annotations_df=pd.DataFrame(annotations), + questions_df=questions_df, + pricing_reference=pricing_reference, + started_at_utc=started_at_utc, + input_payloads=_build_mt_bench_input_payloads( + questions_df=questions_df, + completions_a=completions_a, + completions_b=completions_b, + ), + judge_system_prompt=next(iter(unique_system_prompts), None) + if len(unique_system_prompts) == 1 + else None, + judge_user_prompt_template=next(iter(unique_user_templates), None) + if len(unique_user_templates) == 1 + else None, ) return prefs @@ -309,12 +438,17 @@ def run_mt_bench( res_folder: Path | None = None, result_name: str | None = None, ): - """MT-Bench pipeline with FastChat-compatible pairwise judging.""" + """MT-Bench pipeline with preset or FastChat-original pairwise judging.""" run_started_at = datetime.now(UTC) usage_tracker = OpenRouterReferencePricingTracker() limit_event_tracker = LimitEventTracker() prompt_preset = args.judge_prompt_preset or DEFAULT_JUDGE_PROMPT_PRESET - if prompt_preset == DEFAULT_JUDGE_PROMPT_PRESET and not args.provide_explanation: + fastchat_mode = args.mt_bench_judge_mode == "fastchat_original" + if ( + fastchat_mode + and prompt_preset == DEFAULT_JUDGE_PROMPT_PRESET + and not args.provide_explanation + ): logger.info( "MT-Bench ignores provide_explanation=False and keeps the original " "FastChat-style explanation-plus-verdict prompt." @@ -331,7 +465,7 @@ def run_mt_bench( if res_folder is None: res_folder = Path(args.result_folder) / result_name res_folder.mkdir(parents=True, exist_ok=True) - if args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: + if fastchat_mode and args.max_out_tokens_judge < _MIN_MT_BENCH_JUDGE_TOKENS: logger.warning( "MT-Bench judge prompts request an explanation before the final " "verdict; max_out_tokens_judge=%s may be too small " @@ -361,6 +495,8 @@ def run_mt_bench( "model_A": args.model_A, "model_B": args.model_B, "judge_model": args.judge_model, + "judge_prompt_preset": prompt_preset, + "mt_bench_judge_mode": args.mt_bench_judge_mode, "n_instructions": args.n_instructions if args.n_instructions is not None else len(questions_df), @@ -376,7 +512,7 @@ def run_mt_bench( result_name, ) return None - if ( + if fastchat_mode and ( args.max_judge_model_len is not None and args.max_judge_model_len < _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN ): @@ -387,17 +523,23 @@ def run_mt_bench( args.max_judge_model_len, _MIN_MT_BENCH_JUDGE_MAX_MODEL_LEN, ) - judge_chat_model = make_model( - model=args.judge_model, - max_tokens=args.max_out_tokens_judge, - temperature=0.0, - max_model_len=args.max_judge_model_len, - chat_template=args.chat_template, + judge_model_kwargs = { + "model": args.judge_model, + "max_tokens": args.max_out_tokens_judge, + "max_model_len": args.max_judge_model_len, + "chat_template": args.chat_template, **_build_mt_bench_judge_model_kwargs( - args=args, limit_event_tracker=limit_event_tracker + args=args, + limit_event_tracker=limit_event_tracker, ), + } + if fastchat_mode: + judge_model_kwargs["temperature"] = 0.0 + judge_chat_model = make_model( + **judge_model_kwargs, ) - return _run_mt_bench_fastchat( + runner = _run_mt_bench_fastchat if fastchat_mode else _run_mt_bench_preset + return runner( args=args, res_folder=res_folder, result_name=result_name, diff --git a/judgearena/mt_bench/preset_judging.py b/judgearena/mt_bench/preset_judging.py new file mode 100644 index 0000000..ded587b --- /dev/null +++ b/judgearena/mt_bench/preset_judging.py @@ -0,0 +1,645 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + +import pandas as pd +from langchain_core.prompts import ChatPromptTemplate + +from judgearena.evaluate import ( + _PREFLIGHT_MAX_ITERATIONS, + _PREFLIGHT_MIN_COMPLETION_CHARS, + _PREFLIGHT_RESERVED_TOKENS, + PairScore, + _chars_per_token, + _count_chat_tokens, + _find_token_overflows, +) +from judgearena.judge_prompt_presets import ( + DEFAULT_JUDGE_PROMPT_PRESET, + ResolvedJudgePrompt, + resolve_pairwise_judge_prompt, +) +from judgearena.log import get_logger +from judgearena.mt_bench.common import ( + MT_BENCH_REFERENCE_CATEGORIES, + iter_mt_bench_pairwise_rows, +) +from judgearena.mt_bench.prompt_templates import build_mt_bench_user_prompt_template +from judgearena.openrouter_reference_pricing import OpenRouterReferencePricingTracker +from judgearena.utils import ( + LimitEventTracker, + do_inference, + strip_thinking_tags_with_metadata, + truncate_with_metadata, +) + +logger = get_logger(__name__) + + +@dataclass(frozen=True) +class MTBenchPresetPrompt: + name: str + preset_name: str + parser_mode: str + system_prompt: str | None + user_prompt_template: str + multi_turn: bool + ref_based: bool + + +def _extract_output_section(user_prompt_template: str) -> str: + marker = "# Your output" + marker_index = user_prompt_template.find(marker) + if marker_index < 0: + raise ValueError("Could not find '# Your output' section in preset template.") + return user_prompt_template[marker_index:].lstrip() + + +def _extract_user_preamble(user_prompt_template: str) -> str: + marker = "[User Question]" + marker_index = user_prompt_template.find(marker) + if marker_index < 0: + raise ValueError("Could not find '[User Question]' section in preset template.") + return user_prompt_template[:marker_index].rstrip() + + +def _build_mt_bench_preset_user_prompt_template( + *, + resolved_prompt: ResolvedJudgePrompt, + multi_turn: bool, + ref_based: bool, +) -> str: + base_template = build_mt_bench_user_prompt_template( + multi_turn=multi_turn, + ref_based=ref_based, + ) + if resolved_prompt.system_prompt is None: + user_preamble = _extract_user_preamble(resolved_prompt.user_prompt_template) + return f"{user_preamble}\n\n{base_template}" + output_section = _extract_output_section(resolved_prompt.user_prompt_template) + return f"{base_template}\n\n{output_section}" + + +def _select_preset_prompt( + category: str | None, + multi_turn: bool, + *, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool, +) -> MTBenchPresetPrompt: + ref_based = (category or "") in MT_BENCH_REFERENCE_CATEGORIES + resolved_prompt = resolve_pairwise_judge_prompt( + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + multi_turn=multi_turn, + ) + suffix = "multi" if multi_turn else "single" + if ref_based: + suffix += "_ref" + return MTBenchPresetPrompt( + name=f"{resolved_prompt.preset_name}-{suffix}", + preset_name=resolved_prompt.preset_name, + parser_mode=resolved_prompt.parser_mode, + system_prompt=resolved_prompt.system_prompt, + user_prompt_template=_build_mt_bench_preset_user_prompt_template( + resolved_prompt=resolved_prompt, + multi_turn=multi_turn, + ref_based=ref_based, + ), + multi_turn=multi_turn, + ref_based=ref_based, + ) + + +def _group_indices_by_prompt(items: list[dict[str, Any]]) -> dict[str, list[int]]: + grouped: dict[str, list[int]] = {} + for idx, item in enumerate(items): + grouped.setdefault(item["prompt_name"], []).append(idx) + return grouped + + +def _swap_prompt_kwargs(kwargs: dict[str, str], *, multi_turn: bool) -> dict[str, str]: + swapped = dict(kwargs) + if multi_turn: + swapped["answer_a_1"], swapped["answer_b_1"] = ( + swapped["answer_b_1"], + swapped["answer_a_1"], + ) + swapped["answer_a_2"], swapped["answer_b_2"] = ( + swapped["answer_b_2"], + swapped["answer_a_2"], + ) + return swapped + swapped["answer_a"], swapped["answer_b"] = swapped["answer_b"], swapped["answer_a"] + return swapped + + +def _build_chat_prompt_template(prompt: MTBenchPresetPrompt) -> ChatPromptTemplate: + message_templates: list[tuple[str, str]] = [] + if prompt.system_prompt is not None: + message_templates.append(("system", prompt.system_prompt)) + message_templates.append(("user", prompt.user_prompt_template)) + return ChatPromptTemplate.from_messages(message_templates) + + +def _answer_field_names(prompt: MTBenchPresetPrompt) -> tuple[str, ...]: + if prompt.multi_turn: + return ("answer_a_1", "answer_a_2", "answer_b_1", "answer_b_2") + return ("answer_a", "answer_b") + + +def _truncation_flag_name(field: str) -> str: + if field == "answer_a": + return "answer_a_1_truncated" + if field == "answer_b": + return "answer_b_1_truncated" + return f"{field}_truncated" + + +def _preflight_prompt_group_to_judge_budget( + *, + prompt_template: ChatPromptTemplate, + prompt_kwargs_batch: list[dict[str, str]], + batch_items: list[dict[str, Any]], + judge_tokenizer: Any, + max_judge_model_len: int, + max_out_tokens_judge: int | None, + limit_event_tracker: LimitEventTracker | None, +) -> list[Any]: + prompt_inputs = prompt_template.batch(prompt_kwargs_batch) + safe_budget = ( + max_judge_model_len - (max_out_tokens_judge or 0) - _PREFLIGHT_RESERVED_TOKENS + ) + + for _ in range(_PREFLIGHT_MAX_ITERATIONS): + overflows = _find_token_overflows(prompt_inputs, judge_tokenizer, safe_budget) + if not overflows: + return prompt_inputs + + for idx, _token_count in overflows: + prompt_kwargs = prompt_kwargs_batch[idx] + item = batch_items[idx] + answer_fields = item["answer_fields"] + if not answer_fields: + continue + + empty_kwargs = dict(prompt_kwargs) + for field in answer_fields: + empty_kwargs[field] = "" + fixed_tokens = _count_chat_tokens( + prompt_template.invoke(empty_kwargs), + judge_tokenizer, + ) + per_answer_budget = max( + 256, (safe_budget - fixed_tokens) // len(answer_fields) + ) + + for field in answer_fields: + prompt_kwargs[field], shrunk = truncate_with_metadata( + prompt_kwargs[field], + max_len=max( + _PREFLIGHT_MIN_COMPLETION_CHARS, + int( + per_answer_budget + * _chars_per_token(prompt_kwargs[field], judge_tokenizer) + * 0.9 + ), + ), + tracker=limit_event_tracker, + kind="judge_input_token_truncation", + stage="judge_input", + field=field, + case_id=item["case_id"], + ) + if shrunk: + item["limit_flags"][_truncation_flag_name(field)] = True + + prompt_inputs = prompt_template.batch(prompt_kwargs_batch) + + final_overflows = _find_token_overflows(prompt_inputs, judge_tokenizer, safe_budget) + for idx, token_count in final_overflows: + if limit_event_tracker is not None: + limit_event_tracker.record( + "judge_input_token_truncation_failed", + stage="judge_input", + case_id=batch_items[idx]["case_id"], + original_length=token_count, + final_length=safe_budget, + note=( + f"{_PREFLIGHT_MAX_ITERATIONS} shrink iterations did not " + f"bring tokens under {safe_budget}; falling through to " + "vLLM validation." + ), + ) + return prompt_inputs + + +def _infer_by_prompt_groups( + *, + judge_chat_model, + items: list[dict[str, Any]], + use_tqdm: bool, + swap_answers: bool, + judge_tokenizer: Any | None = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + usage_model_spec: str | None = None, +) -> tuple[list[str], list[dict[str, str]]]: + judgments: list[str] = [""] * len(items) + used_prompt_kwargs: list[dict[str, str]] = [{} for _ in items] + for idxs in _group_indices_by_prompt(items).values(): + prompt: MTBenchPresetPrompt = items[idxs[0]]["prompt"] + prompt_template = _build_chat_prompt_template(prompt) + + batch_kwargs: list[dict[str, str]] = [] + batch_items = [items[item_index] for item_index in idxs] + for item_index in idxs: + prompt_kwargs = dict(items[item_index]["prompt_kwargs"]) + if swap_answers: + prompt_kwargs = _swap_prompt_kwargs( + prompt_kwargs, + multi_turn=prompt.multi_turn, + ) + batch_kwargs.append(prompt_kwargs) + + if judge_tokenizer is not None and max_judge_model_len is not None: + prompt_inputs = _preflight_prompt_group_to_judge_budget( + prompt_template=prompt_template, + prompt_kwargs_batch=batch_kwargs, + batch_items=batch_items, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + limit_event_tracker=items[idxs[0]].get("limit_event_tracker"), + ) + else: + prompt_inputs = prompt_template.batch(batch_kwargs) + outputs = do_inference( + chat_model=judge_chat_model, + inputs=prompt_inputs, + use_tqdm=use_tqdm, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=usage_model_spec, + ) + for item_index, output, prompt_kwargs in zip( + idxs, outputs, batch_kwargs, strict=True + ): + judgments[item_index] = str(output) + used_prompt_kwargs[item_index] = prompt_kwargs + return judgments, used_prompt_kwargs + + +def _build_mt_bench_preset_items( + *, + questions: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + eval_single: bool, + eval_multi: bool, + truncate_input_chars: int | None, + prompt_preset: str, + provide_explanation: bool, + strip_thinking_before_judging: bool, + limit_event_tracker: LimitEventTracker | None, +) -> list[dict[str, Any]]: + items: list[dict[str, Any]] = [] + truncated_field_count = 0 + + def _record_mt_bench_truncation( + *, + case_id: str, + field: str, + truncated: bool, + ) -> None: + nonlocal truncated_field_count + if truncated and limit_event_tracker is not None: + limit_event_tracker.record( + "judge_input_char_truncation", + stage="judge_input", + field=field, + case_id=case_id, + ) + truncated_field_count += int(truncated) + + def _prepare_answer(answer: str, *, case_id: str, field: str) -> tuple[str, bool]: + if not strip_thinking_before_judging: + return answer, False + stripped_answer, stripped = strip_thinking_tags_with_metadata(answer) + if stripped and limit_event_tracker is not None: + limit_event_tracker.record( + "thinking_trace_stripped_before_judging", + stage="judge_input", + field=field, + case_id=case_id, + original_length=len(answer), + final_length=len(stripped_answer), + ) + return stripped_answer, stripped + + for pair_row in iter_mt_bench_pairwise_rows( + questions=questions, + completions_a=completions_a, + completions_b=completions_b, + truncate_input_chars=truncate_input_chars, + ): + category = pair_row.category + if eval_single: + case_id = f"{pair_row.question_id}:turn1" + prompt = _select_preset_prompt( + category, + multi_turn=False, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + ) + answer_a, answer_a_stripped = _prepare_answer( + pair_row.answer_a_1, + case_id=case_id, + field="answer_a_1", + ) + answer_b, answer_b_stripped = _prepare_answer( + pair_row.answer_b_1, + case_id=case_id, + field="answer_b_1", + ) + _record_mt_bench_truncation( + case_id=case_id, + field="turn_1_question", + truncated=pair_row.turn_1_question_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_a_1", + truncated=pair_row.answer_a_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="answer_b_1", + truncated=pair_row.answer_b_1_truncated, + ) + prompt_kwargs: dict[str, str] = { + "question": pair_row.turn_1_question, + "answer_a": answer_a, + "answer_b": answer_b, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_a_1_reasoning_stripped": answer_a_stripped, + "answer_b_1_reasoning_stripped": answer_b_stripped, + } + if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + prompt_kwargs["ref_answer_1"] = pair_row.ref_1 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + items.append( + { + "case_id": case_id, + "question_id": pair_row.question_id, + "category": category, + "turn": 1, + "prompt": prompt, + "prompt_name": prompt.name, + "prompt_kwargs": prompt_kwargs, + "answer_fields": _answer_field_names(prompt), + "limit_flags": limit_flags, + "limit_event_tracker": limit_event_tracker, + } + ) + + if eval_multi and pair_row.turn_2_question: + case_id = f"{pair_row.question_id}:turn2" + prompt = _select_preset_prompt( + category, + multi_turn=True, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + ) + answer_a_1, answer_a_1_stripped = _prepare_answer( + pair_row.answer_a_1, + case_id=case_id, + field="answer_a_1", + ) + answer_a_2, answer_a_2_stripped = _prepare_answer( + pair_row.answer_a_2, + case_id=case_id, + field="answer_a_2", + ) + answer_b_1, answer_b_1_stripped = _prepare_answer( + pair_row.answer_b_1, + case_id=case_id, + field="answer_b_1", + ) + answer_b_2, answer_b_2_stripped = _prepare_answer( + pair_row.answer_b_2, + case_id=case_id, + field="answer_b_2", + ) + for field, truncated in ( + ("turn_1_question", pair_row.turn_1_question_truncated), + ("turn_2_question", pair_row.turn_2_question_truncated), + ("answer_a_1", pair_row.answer_a_1_truncated), + ("answer_a_2", pair_row.answer_a_2_truncated), + ("answer_b_1", pair_row.answer_b_1_truncated), + ("answer_b_2", pair_row.answer_b_2_truncated), + ): + _record_mt_bench_truncation( + case_id=case_id, + field=field, + truncated=truncated, + ) + prompt_kwargs = { + "question_1": pair_row.turn_1_question, + "question_2": pair_row.turn_2_question, + "answer_a_1": answer_a_1, + "answer_a_2": answer_a_2, + "answer_b_1": answer_b_1, + "answer_b_2": answer_b_2, + } + limit_flags = { + "turn_1_question_truncated": pair_row.turn_1_question_truncated, + "turn_2_question_truncated": pair_row.turn_2_question_truncated, + "answer_a_1_truncated": pair_row.answer_a_1_truncated, + "answer_a_2_truncated": pair_row.answer_a_2_truncated, + "answer_b_1_truncated": pair_row.answer_b_1_truncated, + "answer_b_2_truncated": pair_row.answer_b_2_truncated, + "answer_a_1_reasoning_stripped": answer_a_1_stripped, + "answer_a_2_reasoning_stripped": answer_a_2_stripped, + "answer_b_1_reasoning_stripped": answer_b_1_stripped, + "answer_b_2_reasoning_stripped": answer_b_2_stripped, + } + if prompt.ref_based: + _record_mt_bench_truncation( + case_id=case_id, + field="ref_1", + truncated=pair_row.ref_1_truncated, + ) + _record_mt_bench_truncation( + case_id=case_id, + field="ref_2", + truncated=pair_row.ref_2_truncated, + ) + prompt_kwargs["ref_answer_1"] = pair_row.ref_1 + prompt_kwargs["ref_answer_2"] = pair_row.ref_2 + limit_flags["ref_1_truncated"] = pair_row.ref_1_truncated + limit_flags["ref_2_truncated"] = pair_row.ref_2_truncated + items.append( + { + "case_id": case_id, + "question_id": pair_row.question_id, + "category": category, + "turn": 2, + "prompt": prompt, + "prompt_name": prompt.name, + "prompt_kwargs": prompt_kwargs, + "answer_fields": _answer_field_names(prompt), + "limit_flags": limit_flags, + "limit_event_tracker": limit_event_tracker, + } + ) + if truncated_field_count: + logger.warning( + "Warning: truncated %s judge inputs to %s characters before evaluation.", + truncated_field_count, + truncate_input_chars, + ) + return items + + +def _normalize_preference(preference: float | None, *, swapped: bool) -> float: + if preference is None: + return math.nan + return 1.0 - preference if swapped else float(preference) + + +def judge_mt_bench_with_preset( + *, + judge_chat_model, + judge_model: str, + questions: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + model_a: str, + model_b: str, + turns_mode: str, + swap_mode: str, + truncate_input_chars: int | None, + use_tqdm: bool, + prompt_preset: str = DEFAULT_JUDGE_PROMPT_PRESET, + provide_explanation: bool = False, + strip_thinking_before_judging: bool = False, + judge_tokenizer: Any | None = None, + max_judge_model_len: int | None = None, + max_out_tokens_judge: int | None = None, + usage_tracker: OpenRouterReferencePricingTracker | None = None, + usage_phase: str | None = None, + limit_event_tracker: LimitEventTracker | None = None, +) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]: + assert turns_mode in ("both", "single", "multi") + assert swap_mode in ("fixed", "both") + + eval_single = turns_mode in ("both", "single") + eval_multi = turns_mode in ("both", "multi") + + items = _build_mt_bench_preset_items( + questions=questions, + completions_a=completions_a, + completions_b=completions_b, + eval_single=eval_single, + eval_multi=eval_multi, + truncate_input_chars=truncate_input_chars, + prompt_preset=prompt_preset, + provide_explanation=provide_explanation, + strip_thinking_before_judging=strip_thinking_before_judging, + limit_event_tracker=limit_event_tracker, + ) + + judgments, prompt_kwargs_used = _infer_by_prompt_groups( + judge_chat_model=judge_chat_model, + items=items, + use_tqdm=use_tqdm, + swap_answers=False, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, + ) + + annotations: list[dict[str, Any]] = [] + metadata: list[dict[str, object]] = [] + preferences: list[float] = [] + + def _append_results( + raw_judgments: list[str], + used_prompt_kwargs: list[dict[str, str]], + *, + swapped: bool, + ) -> None: + for item, raw_judgment, prompt_kwargs in zip( + items, raw_judgments, used_prompt_kwargs, strict=True + ): + prompt: MTBenchPresetPrompt = item["prompt"] + parsed_preference = PairScore( + parser_mode=prompt.parser_mode + ).parse_model_raw(raw_judgment) + normalized_preference = _normalize_preference( + parsed_preference, + swapped=swapped, + ) + annotation_row = { + "question_id": item["question_id"], + "category": item["category"], + "turn": item["turn"], + "model_A": model_b if swapped else model_a, + "model_B": model_a if swapped else model_b, + "judge": judge_model, + "prompt_name": prompt.name, + "prompt_preset": prompt.preset_name, + "parser_mode": prompt.parser_mode, + "system_prompt": prompt.system_prompt, + "user_prompt_template": prompt.user_prompt_template, + "user_prompt": prompt.user_prompt_template.format(**prompt_kwargs), + "judge_completion": raw_judgment, + "preference": normalized_preference, + "swapped": swapped, + } + annotation_row.update(item.get("limit_flags", {})) + annotations.append(annotation_row) + metadata.append( + { + "question_id": item["question_id"], + "category": item["category"], + "turn": item["turn"], + } + ) + preferences.append(normalized_preference) + + _append_results(judgments, prompt_kwargs_used, swapped=False) + + if swap_mode == "both": + swapped_judgments, swapped_prompt_kwargs = _infer_by_prompt_groups( + judge_chat_model=judge_chat_model, + items=items, + use_tqdm=use_tqdm, + swap_answers=True, + judge_tokenizer=judge_tokenizer, + max_judge_model_len=max_judge_model_len, + max_out_tokens_judge=max_out_tokens_judge, + usage_tracker=usage_tracker, + usage_phase=usage_phase, + usage_model_spec=judge_model, + ) + _append_results(swapped_judgments, swapped_prompt_kwargs, swapped=True) + + return pd.Series(preferences, dtype=float), annotations, metadata, 0 diff --git a/judgearena/mt_bench/prompt_templates.py b/judgearena/mt_bench/prompt_templates.py new file mode 100644 index 0000000..edef887 --- /dev/null +++ b/judgearena/mt_bench/prompt_templates.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from pathlib import Path + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" +_USER_SINGLE_BASE_FILE = "user-single-base.txt" +_USER_MULTI_BASE_FILE = "user-multi-base.txt" +_USER_SINGLE_REF_BLOCK_FILE = "user-single-reference-block.txt" +_USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" + + +def load_mt_bench_prompt_text(filename: str) -> str: + path = _PROMPTS_DIR / filename + return path.read_text(encoding="utf-8") + + +def render_mt_bench_prompt_text(filename: str, **kwargs: str) -> str: + return load_mt_bench_prompt_text(filename).format(**kwargs) + + +def build_mt_bench_user_prompt_template(*, multi_turn: bool, ref_based: bool) -> str: + base_filename = _USER_MULTI_BASE_FILE if multi_turn else _USER_SINGLE_BASE_FILE + reference_block = "" + if ref_based: + ref_block_filename = ( + _USER_MULTI_REF_BLOCK_FILE if multi_turn else _USER_SINGLE_REF_BLOCK_FILE + ) + reference_block = ( + load_mt_bench_prompt_text(ref_block_filename).rstrip("\n") + "\n\n" + ) + return render_mt_bench_prompt_text(base_filename, reference_block=reference_block) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5d905ea..e19f0d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -356,3 +356,39 @@ def test_engine_kwargs_parsed_as_json(capture_mains): ) ge_args: CliArgs = capture_mains["args"] assert ge_args.engine_kwargs == {"tensor_parallel_size": 4} + + +def test_mt_bench_defaults_to_default_judge_mode(capture_mains): + cli_module.cli( + [ + "--task", + "mt-bench", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.mt_bench_judge_mode == "default" + + +def test_mt_bench_forwards_fastchat_original_mode(capture_mains): + cli_module.cli( + [ + "--task", + "mt-bench", + "--model_A", + "Dummy/A", + "--model_B", + "Dummy/B", + "--judge", + "Dummy/J", + "--mt_bench_judge_mode", + "fastchat_original", + ] + ) + ge_args: CliArgs = capture_mains["args"] + assert ge_args.mt_bench_judge_mode == "fastchat_original" diff --git a/tests/test_mt_bench_downloads.py b/tests/test_mt_bench_downloads.py index 7097204..368cdab 100644 --- a/tests/test_mt_bench_downloads.py +++ b/tests/test_mt_bench_downloads.py @@ -1,3 +1,5 @@ +import importlib +from datetime import UTC, datetime from types import SimpleNamespace import pandas as pd @@ -248,6 +250,243 @@ def test_pair_v2_system_prompt_matches_original_fastchat_contract(): ) +def test_mt_bench_prompt_templates_preserve_multi_turn_reference_blocks(): + prompt_templates = importlib.import_module("judgearena.mt_bench.prompt_templates") + + rendered = prompt_templates.build_mt_bench_user_prompt_template( + multi_turn=True, + ref_based=True, + ) + + assert "<|The Start of Reference Answer|>" in rendered + assert "### User:\n{question_1}" in rendered + assert "### Reference answer:\n{ref_answer_2}" in rendered + assert "### Assistant A:\n{answer_a_2}" in rendered + assert "### Assistant B:\n{answer_b_2}" in rendered + + +def test_select_preset_prompt_uses_default_score_mode_with_mt_bench_template(): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + + prompt = preset_judging._select_preset_prompt( + "math", + multi_turn=True, + prompt_preset="default", + provide_explanation=False, + ) + + assert prompt.parser_mode == "score" + assert prompt.ref_based is True + assert prompt.multi_turn is True + assert prompt.system_prompt + assert "<|The Start of Reference Answer|>" in prompt.user_prompt_template + assert "### Assistant A:\n{answer_a_2}" in prompt.user_prompt_template + + +def test_select_preset_prompt_uses_skywork_verdict_mode_with_mt_bench_template(): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + + prompt = preset_judging._select_preset_prompt( + "writing", + multi_turn=True, + prompt_preset=SKYWORK_JUDGE_PROMPT_PRESET, + provide_explanation=True, + ) + + assert prompt.parser_mode == "verdict" + assert prompt.ref_based is False + assert prompt.multi_turn is True + assert prompt.system_prompt is None + assert "Please briefly explain your reasoning first" in prompt.user_prompt_template + assert "### Assistant B:\n{answer_b_2}" in prompt.user_prompt_template + + +def test_preset_judging_uses_shared_swap_mode_both_semantics(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + questions_df = pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Q1"], + "turn_2": ["Q2"], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ) + completions_a = pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ) + completions_b = pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ) + call_count = {"count": 0} + + def fake_do_inference(**kwargs): + call_count["count"] += 1 + if call_count["count"] <= 2: + return ["score_A: 9\nscore_B: 3"] + return ["score_A: 2\nscore_B: 7"] + + monkeypatch.setattr(preset_judging, "do_inference", fake_do_inference) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=questions_df, + completions_a=completions_a, + completions_b=completions_b, + model_a="Model/A", + model_b="Model/B", + turns_mode="both", + swap_mode="both", + truncate_input_chars=None, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + ) + ) + + assert len(prefs) == 4 + assert len(annotations) == 4 + assert len(metadata) == 4 + assert num_inconsistent == 0 + assert [row["swapped"] for row in annotations] == [False, False, True, True] + + +def test_preset_judging_uses_shared_char_truncation_event_kind(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + tracker = utils.LimitEventTracker() + + monkeypatch.setattr( + preset_judging, + "do_inference", + lambda **kwargs: ["score_A: 8\nscore_B: 4"], + ) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Q" * 20], + "turn_2": [""], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ), + completions_a=pd.DataFrame( + {"completion_turn_1": ["A" * 30], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + completions_b=pd.DataFrame( + {"completion_turn_1": ["B" * 30], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + model_a="Model/A", + model_b="Model/B", + turns_mode="single", + swap_mode="fixed", + truncate_input_chars=10, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + limit_event_tracker=tracker, + ) + ) + + assert len(prefs) == 1 + assert len(annotations) == 1 + assert len(metadata) == 1 + assert num_inconsistent == 0 + assert tracker.build_summary()["counts_by_kind"]["judge_input_char_truncation"] == 3 + + +def test_preset_judging_preflights_token_budget_for_default_mode(monkeypatch): + preset_judging = importlib.import_module("judgearena.mt_bench.preset_judging") + tracker = utils.LimitEventTracker() + captured = {} + + class FakeTokenizer: + def apply_chat_template(self, messages, tokenize=True): + text = "".join( + str( + message["content"] if isinstance(message, dict) else message.content + ) + for message in messages + ) + return [0] * len(text) + + def encode(self, text): + return [0] * len(text) + + def fake_do_inference(*, inputs, **kwargs): + captured["inputs"] = inputs + return ["score_A: 8\nscore_B: 4"] + + monkeypatch.setattr(preset_judging, "do_inference", fake_do_inference) + + prefs, annotations, metadata, num_inconsistent = ( + preset_judging.judge_mt_bench_with_preset( + judge_chat_model=object(), + judge_model="Dummy/J", + questions=pd.DataFrame( + { + "category": ["writing"], + "turn_1": ["Question"], + "turn_2": [""], + "reference_turn_1": [""], + "reference_turn_2": [""], + }, + index=pd.Index([1], name="question_id"), + ), + completions_a=pd.DataFrame( + {"completion_turn_1": ["A" * 1200], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + completions_b=pd.DataFrame( + {"completion_turn_1": ["B" * 1200], "completion_turn_2": [""]}, + index=pd.Index([1], name="question_id"), + ), + model_a="Model/A", + model_b="Model/B", + turns_mode="single", + swap_mode="fixed", + truncate_input_chars=None, + use_tqdm=False, + prompt_preset="default", + provide_explanation=False, + limit_event_tracker=tracker, + judge_tokenizer=FakeTokenizer(), + max_judge_model_len=2300, + max_out_tokens_judge=32, + ) + ) + + prompt_value = captured["inputs"][0] + assert len(prefs) == 1 + assert len(annotations) == 1 + assert len(metadata) == 1 + assert num_inconsistent == 0 + assert len(FakeTokenizer().apply_chat_template(prompt_value.to_messages())) <= 2012 + assert annotations[0]["answer_a_1_truncated"] is True + assert annotations[0]["answer_b_1_truncated"] is True + assert ( + tracker.build_summary()["counts_by_kind"]["judge_input_token_truncation"] >= 1 + ) + + def test_conservative_winner_marks_one_sided_parse_failures_as_error(): assert fastchat_compat._conservative_winner("model_A", "error") == ( "error", @@ -295,7 +534,13 @@ def test_run_mt_bench_forwards_engine_kwargs_to_judge(monkeypatch, caplog): ) def fake_make_model( - *, model, max_tokens, temperature, max_model_len, chat_template, **kwargs + *, + model, + max_tokens, + temperature=None, + max_model_len, + chat_template, + **kwargs, ): captured["make_model"] = { "model": model, @@ -309,8 +554,8 @@ def fake_make_model( monkeypatch.setattr(mt_bench_utils, "make_model", fake_make_model) - def fake_run_mt_bench_fastchat(**kwargs): - captured["run_mt_bench_fastchat"] = kwargs + def fake_run_mt_bench_preset(**kwargs): + captured["run_mt_bench_preset"] = kwargs return pd.Series( kwargs["questions_df"].index.to_list(), dtype=float, @@ -318,8 +563,8 @@ def fake_run_mt_bench_fastchat(**kwargs): monkeypatch.setattr( mt_bench_utils, - "_run_mt_bench_fastchat", - fake_run_mt_bench_fastchat, + "_run_mt_bench_preset", + fake_run_mt_bench_preset, ) args = _mt_bench_args( @@ -359,16 +604,13 @@ def fake_run_mt_bench_fastchat(**kwargs): "limit_event_model_spec": "VLLM/Qwen/Qwen3.5-27B-FP8", "limit_event_tracker": captured["make_model"]["kwargs"]["limit_event_tracker"], } - assert captured["run_mt_bench_fastchat"]["args"].swap_mode == "fixed" - assert captured["run_mt_bench_fastchat"]["prompt_preset"] == "default" - assert ( - captured["run_mt_bench_fastchat"]["args"].strip_thinking_before_judging is False - ) + assert captured["make_model"]["temperature"] is None + assert captured["run_mt_bench_preset"]["args"].swap_mode == "fixed" + assert captured["run_mt_bench_preset"]["prompt_preset"] == "default" assert ( - "MT-Bench judge prompts request an explanation before the final verdict" - in caplog.text + captured["run_mt_bench_preset"]["args"].strip_thinking_before_judging is False ) - assert "max_out_tokens_judge=256" in caplog.text + assert "MT-Bench ignores provide_explanation=False" not in caplog.text def test_select_prompt_supports_optional_skywork_mt_bench_preset(): @@ -416,14 +658,14 @@ def test_run_mt_bench_keeps_skywork_prompt_preset(monkeypatch): ) monkeypatch.setattr(mt_bench_utils, "make_model", lambda **kwargs: object()) - def fake_run_mt_bench_fastchat(**kwargs): + def fake_run_mt_bench_preset(**kwargs): captured["kwargs"] = kwargs return pd.Series([0.0], dtype=float) monkeypatch.setattr( mt_bench_utils, - "_run_mt_bench_fastchat", - fake_run_mt_bench_fastchat, + "_run_mt_bench_preset", + fake_run_mt_bench_preset, ) args = _mt_bench_args( @@ -455,3 +697,113 @@ def fake_run_mt_bench_fastchat(**kwargs): assert args.truncate_judge_input_chars == 80000 assert captured["kwargs"]["args"].truncate_judge_input_chars == 80000 assert captured["kwargs"]["args"].max_judge_model_len == 65536 + + +def test_run_mt_bench_default_respects_judge_temperature_from_engine_kwargs( + monkeypatch, +): + questions_df = pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q1b"]}, + index=pd.Index([1], name="instruction_index"), + ) + captured = {} + + monkeypatch.setattr( + mt_bench_utils, + "load_instructions", + lambda dataset, n_instructions=None: questions_df, + ) + monkeypatch.setattr( + mt_bench_utils, + "_generate_mt_bench_completions", + lambda args, questions_df, ignore_cache, usage_tracker, limit_event_tracker: ( + pd.DataFrame( + { + "completion_turn_1": ["A1"], + "completion_turn_2": ["A2"], + }, + index=questions_df.index, + ), + pd.DataFrame( + { + "completion_turn_1": ["B1"], + "completion_turn_2": ["B2"], + }, + index=questions_df.index, + ), + ), + ) + + def fake_make_model( + *, + model, + max_tokens, + temperature=None, + max_model_len, + chat_template, + **kwargs, + ): + captured["temperature"] = temperature + return object() + + monkeypatch.setattr(mt_bench_utils, "make_model", fake_make_model) + monkeypatch.setattr( + mt_bench_utils, + "_run_mt_bench_preset", + lambda **kwargs: pd.Series([0.0], dtype=float), + ) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + n_instructions=1, + max_out_tokens_judge=256, + engine_kwargs={"temperature": 0.8}, + judge_engine_kwargs={"temperature": 0.4}, + ) + + mt_bench_utils.run_mt_bench(args, ignore_cache=False) + + assert captured["temperature"] == 0.4 + + +def test_save_mt_bench_results_uses_explicit_res_folder(tmp_path, monkeypatch): + captured = {} + + def fake_write_run_metadata(**kwargs): + captured["output_dir"] = kwargs["output_dir"] + return kwargs["output_dir"] / "run-metadata.v1.json" + + monkeypatch.setattr(mt_bench_utils, "write_run_metadata", fake_write_run_metadata) + + args = _mt_bench_args( + dataset="mt-bench", + model_A="VLLM/example/model-a", + model_B="gpt-4", + judge_model="VLLM/Qwen/Qwen3.5-27B-FP8", + result_folder=str(tmp_path / "results-root"), + ) + explicit_res_folder = tmp_path / "explicit-run" + result_name = "mt-bench-run" + + mt_bench_utils._save_mt_bench_results( + args=args, + res_folder=explicit_res_folder, + result_name=result_name, + results={"task": "mt-bench"}, + annotations_df=pd.DataFrame([{"question_id": 1, "turn": 1}]), + questions_df=pd.DataFrame( + {"turn_1": ["Q1"], "turn_2": ["Q2"]}, + index=pd.Index([1], name="question_id"), + ), + pricing_reference=None, + started_at_utc=datetime.now(UTC), + ) + + assert captured["output_dir"] == explicit_res_folder + assert (explicit_res_folder / f"args-{result_name}.json").exists() + assert (explicit_res_folder / f"results-{result_name}.json").exists() + assert (explicit_res_folder / f"{result_name}-annotations.csv").exists() + assert not (tmp_path / "results-root" / result_name).exists()