From ff1460c4ec9ff79ac851e7e920a372859563dbbd Mon Sep 17 00:00:00 2001
From: mali-git
Date: Sat, 28 Jun 2025 19:21:37 +0200
Subject: [PATCH 1/3] feat: evaluate translations

---
 src/ml_filter/__main__.py                    |  26 ++++-
 src/ml_filter/{ => translation}/translate.py |   0
 .../translation/translation_evaluation.py    | 102 ++++++++++++++++++
 tests/conftest.py                            |   5 +-
 tests/test_translate.py                      |   2 +-
 5 files changed, 129 insertions(+), 6 deletions(-)
 rename src/ml_filter/{ => translation}/translate.py (100%)
 create mode 100644 src/ml_filter/translation/translation_evaluation.py

diff --git a/src/ml_filter/__main__.py b/src/ml_filter/__main__.py
index c0bf556c..8207407d 100644
--- a/src/ml_filter/__main__.py
+++ b/src/ml_filter/__main__.py
@@ -18,7 +18,8 @@
 from ml_filter.llm_client import LLMClient
 from ml_filter.sample_from_hf_dataset import sample_from_hf_dataset, upload_file_to_hf
 from ml_filter.training.annotator_model_pipeline import run_annotator_training_pipeline
-from ml_filter.translate import TranslationServiceType, TranslatorFactory
+from ml_filter.translation.translate import TranslationServiceType, TranslatorFactory
+from ml_filter.translation.translation_evaluation import evaluate_translations
 from ml_filter.utils.chunk_data import chunk_jsonl
 from ml_filter.utils.manipulate_datasets import apply_score_transforms, convert_hf_dataset_to_jsonl, split_dataset
 from ml_filter.utils.manipulate_documents import merge_and_sort_jsonl_files
@@ -757,5 +758,28 @@ def _get_target_language_codes_list_helper(target_language_codes: str) -> list[s
     return [lang_code.strip().lower() for lang_code in target_language_codes.split(",")]
 
 
+@main.command(name="evaluate_translations")
+@click.option("--data-dir", required=True, help="Directory containing translation JSONL files")
+@click.option("--gold-path", required=True, help="Path to gold reference JSONL file")
+@click.option("--model-name", default="Unbabel/wmt22-cometkiwi-da", help="COMET model to use")
+@click.option("--languages", type=str, required=True, help="Comma-separated list of supported language codes")
+@click.option("--batch-size", help="Batch size for processing translations")
+def evaluate_translations_cli(
+    data_dir: str,
+    gold_path: str,
+    model_name: str,
+    languages: str,
+    batch_size: int,
+):
+    """CLI entry point for evaluating translation quality."""
+    evaluate_translations(
+        data_dir=data_dir,
+        gold_path=gold_path,
+        languages=languages.split(","),
+        model_name=model_name,
+        batch_size=batch_size,
+    )
+
+
 if __name__ == "__main__":
     main()
diff --git a/src/ml_filter/translate.py b/src/ml_filter/translation/translate.py
similarity index 100%
rename from src/ml_filter/translate.py
rename to src/ml_filter/translation/translate.py
diff --git a/src/ml_filter/translation/translation_evaluation.py b/src/ml_filter/translation/translation_evaluation.py
new file mode 100644
index 00000000..80b06445
--- /dev/null
+++ b/src/ml_filter/translation/translation_evaluation.py
@@ -0,0 +1,102 @@
+import json
+import logging
+import os
+
+import numpy as np
+from comet import download_model, load_from_checkpoint
+
+logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+
+
+def _load_gold_dict(gold_path: str) -> dict[str, str]:
+    """Load reference translations from a JSONL file.
+
+    Args:
+        gold_path: Path to the gold reference JSONL file.
+
+    Returns:
+        A dictionary mapping document IDs to reference texts.
+ """ + gold_dict = {} + with open(gold_path, "r") as f: + for line in f: + item = json.loads(line) + gold_dict[item["document_id"]] = item["text"] + return gold_dict + + +def _prepare_translation_input(file_path: str, gold_dict: dict[str, str]) -> list[dict[str, str]]: + """Extract source and machine-translated texts from a JSONL file. + + Args: + file_path: Path to the target JSONL file. + lang: Language code. + gold_dict: Dictionary of gold references. + + Returns: + A list of dictionaries containing 'src' and 'mt' keys. + """ + target_texts = [] + with open(file_path, "r") as f: + for line_num, line in enumerate(f, 1): + if not line: + continue + try: + document = json.loads(line) + doc_id = document["document_id"] + text = document["text"] + + if doc_id not in gold_dict: + logging.warning(f"doc_id {doc_id} not found in gold references.") + continue + + target_texts.append({"src": gold_dict[doc_id], "mt": text}) + except json.JSONDecodeError as e: + logging.warning(f"Skipping invalid line {line_num} in {file_path}: {e}") + continue + return target_texts + + +def evaluate_translations( + data_dir: str, + gold_path: str, + languages: list[str], + batch_size: int, + model_name: str = "Unbabel/wmt22-cometkiwi-da", +) -> None: + """Evaluate translation quality for a set of files using a COMET model. + + Args: + data_dir: Directory containing translation JSONL files. + gold_path: Path to gold reference JSONL file. + languages: List of supported language codes. + model_name: COMET model to use. + """ + model_path = download_model(model_name) + model = load_from_checkpoint(model_path) + + gold_dict = _load_gold_dict(gold_path) + quality_dict = {} + + for filename in os.listdir(data_dir): + if filename.endswith(".jsonl"): + file_path = os.path.join(data_dir, filename) + lang = filename.split("_")[5] + + if lang not in languages: + logging.info(f"Skipping file with unsupported language: {file_path}") + continue + + target_texts = _prepare_translation_input(file_path, gold_dict) + + if target_texts: + # TODO: ;ultiple GPUs handling + model_output = model.predict(target_texts, batch_size=batch_size, gpus=1, accelerator="gpu") + quality_dict[lang] = model_output.scores + logging.info(f"Processed {len(target_texts)} documents for language '{lang}' in file {file_path}") + else: + logging.info(f"No valid documents for language '{lang}' in file {file_path}") + + logging.info("Translation quality scores:") + for lang, scores in quality_dict.items(): + logging.info(f"Mean score for {lang}: {np.mean(scores):.4f}") diff --git a/tests/conftest.py b/tests/conftest.py index 8b25f11a..e6bfb324 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,13 +4,10 @@ import pandas as pd import pytest -import torch import yaml from omegaconf import OmegaConf -from transformers import AutoConfig, BertForSequenceClassification -from ml_filter.models.annotator_model_head import MultiTargetClassificationHead, MultiTargetRegressionHead -from ml_filter.translate import DeepLClient, OpenAIClient, Translator +from ml_filter.translation.translate import DeepLClient, OpenAIClient, Translator @pytest.fixture diff --git a/tests/test_translate.py b/tests/test_translate.py index 12131019..0883df6a 100644 --- a/tests/test_translate.py +++ b/tests/test_translate.py @@ -9,7 +9,7 @@ import pytest import yaml -from ml_filter.translate import Translator +from ml_filter.translation.translate import Translator @dataclass From 56070ad46cd89b6c536f680995391e3747134cc5 Mon Sep 17 00:00:00 2001 From: mali-git Date: Sun, 29 Jun 2025 
Subject: [PATCH 2/3] feat: plot quality score distributions

---
 .../translation/translation_evaluation.py | 149 +++++++++++++++++-
 1 file changed, 147 insertions(+), 2 deletions(-)

diff --git a/src/ml_filter/translation/translation_evaluation.py b/src/ml_filter/translation/translation_evaluation.py
index 80b06445..4a827331 100644
--- a/src/ml_filter/translation/translation_evaluation.py
+++ b/src/ml_filter/translation/translation_evaluation.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from pathlib import Path
 
 import numpy as np
 from comet import download_model, load_from_checkpoint
@@ -81,7 +82,11 @@ def evaluate_translations(
     for filename in os.listdir(data_dir):
         if filename.endswith(".jsonl"):
             file_path = os.path.join(data_dir, filename)
-            lang = filename.split("_")[5]
+            parts = filename.split("_")
+            if len(parts) != 8:
+                logging.warning(f"Skipping file with unexpected format: {file_path}")
+                continue
+            lang = parts[5]
 
             if lang not in languages:
                 logging.info(f"Skipping file with unsupported language: {file_path}")
@@ -90,7 +95,7 @@ def evaluate_translations(
             target_texts = _prepare_translation_input(file_path, gold_dict)
 
             if target_texts:
-                # TODO: ;ultiple GPUs handling
+                # TODO: Multiple GPUs handling
                 model_output = model.predict(target_texts, batch_size=batch_size, gpus=1, accelerator="gpu")
                 quality_dict[lang] = model_output.scores
                 logging.info(f"Processed {len(target_texts)} documents for language '{lang}' in file {file_path}")
@@ -100,3 +105,143 @@ def evaluate_translations(
     logging.info("Translation quality scores:")
     for lang, scores in quality_dict.items():
         logging.info(f"Mean score for {lang}: {np.mean(scores):.4f}")
+
+
+def _plot_translation_scores_histogram_relative_to_gt(
+    id_to_translation_score: dict[str, str], id_to_gt_quality_score: dict[str, float], lang: str, output_path: str
+) -> None:
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    score_classes = ["fine", "minor", "major", "critical"]
+
+    # Step 1: Get all GT scores from the GT dict (not only those that appear in translation dict)
+    all_gt_scores = sorted(set(float(v) for v in id_to_gt_quality_score.values()))
+
+    # Step 2: Build empty counts for all GT scores
+    counts = {gt: [0] * len(score_classes) for gt in all_gt_scores}
+
+    # Step 3: Count translation scores
+    for sample_id, trans_score in id_to_translation_score.items():
+        if sample_id in id_to_gt_quality_score:
+            gt_score = float(id_to_gt_quality_score[sample_id])
+            trans_score = trans_score.lower()
+            if trans_score in score_classes:
+                idx = score_classes.index(trans_score)
+                counts[gt_score][idx] += 1
+
+    # Step 4: Plot
+    x = np.array(all_gt_scores)
+    bar_width = 0.12
+    fig, ax = plt.subplots()
+
+    for i, score_class in enumerate(score_classes):
+        offsets = x + i * bar_width
+        heights = [counts[gt][i] for gt in all_gt_scores]
+        ax.bar(offsets, heights, width=bar_width, label=score_class.capitalize())
+
+    # X-ticks in center of bar groups
+    tick_positions = x + bar_width * (len(score_classes) - 1) / 2
+    tick_labels = [str(int(v)) if v.is_integer() else "" for v in x]
+
+    ax.set_xticks(tick_positions)
+    ax.set_xticklabels(tick_labels)
+
+    ax.set_xlabel("Ground Truth Quality Score")
+    ax.set_ylabel("Number of Translations")
+    ax.set_title(f"Translation Quality vs Ground Truth Quality for {lang}")
+    ax.legend(title="Translation Score Class")
+    ax.grid(axis="y", alpha=0.5)
+    plt.tight_layout()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def _plot_translation_scores_histogram(scores: list[str], lang: str, output_path: str) -> None:
+    """Plot a histogram of translation quality scores and save it to a file.
+
+    Args:
+        scores: List of quality scores.
+        lang: Language code for the histogram title.
+        output_path: Path to save the histogram figure.
+    """
+    from collections import Counter
+
+    import matplotlib.pyplot as plt
+    from matplotlib.ticker import MaxNLocator
+
+    # Enforce lowercase and fixed order
+    score_classes = ["fine", "minor", "major", "critical"]
+    scores = [s.lower() for s in scores]
+    counts = Counter(scores)
+    values = [counts[cls] for cls in score_classes]
+
+    plt.bar(score_classes, values, alpha=0.7)
+    plt.title(f"Translation Quality Scores for {lang}")
+    plt.xlabel("Translation Score")
+    plt.ylabel("Frequency")
+    plt.grid(axis="y", alpha=0.75)
+
+    # Ensure y-axis ticks are integers
+    plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
+
+    plt.tight_layout()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def plot_translation_quality_results(
+    data_dir: Path,
+    gt_path: Path,
+    languages: list[str],
+    output_dir: Path,
+) -> None:
+    """Plot histograms for translation quality results.
+
+    Args:
+        data_dir: Directory containing translation JSON files.
+        languages: List of supported language codes.
+        output_dir: Directory to save the histogram plots.
+    """
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    id_to_gt_quality_score = {}
+
+    with open(gt_path, "r") as f:
+        for line in f:
+            item = json.loads(line)
+            id_to_gt_quality_score[item["document_id"]] = float(item["score"])
+
+    for filename in os.listdir(data_dir):
+        if filename.endswith(".json"):
+            parts = filename.split("_")
+            if len(parts) != 8:
+                continue
+            lang = parts[5]
+
+            if lang not in languages:
+                logging.warning(f"Skipping file with unsupported language: {filename}")
+                continue
+
+            file_path = os.path.join(data_dir, filename)
+            id_to_translation_score = {}
+            with open(file_path, "r", encoding="utf-8") as f:
+                data = json.load(f)  # Load the full list
+                for item in data:
+                    id_to_translation_score[item["document_id"]] = item["translation_score"].lower()
+
+            output_path = os.path.join(output_dir, f"{lang}_translation_quality_histogram.png")
+            _plot_translation_scores_histogram(
+                scores=list(id_to_translation_score.values()),
+                lang=lang,
+                output_path=output_path,
+            )
+
+            output_path = os.path.join(output_dir, f"{lang}_translation_quality_vs_gt_histogram.png")
+            _plot_translation_scores_histogram_relative_to_gt(
+                id_to_translation_score=id_to_translation_score,
+                id_to_gt_quality_score=id_to_gt_quality_score,
+                lang=lang,
+                output_path=output_path,
+            )

From f2286f16b89f64cde354abe9356516180c528631 Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 30 Jun 2025 14:38:39 +0200
Subject: [PATCH 3/3] refactor: only translate context length//2 tokens

---
 src/constants.py                           |   3 +
 src/ml_filter/__main__.py                  |  36 +++-
 .../translation/translation_evaluation.py  | 194 +++++++++++++++---
 3 files changed, 205 insertions(+), 28 deletions(-)

diff --git a/src/constants.py b/src/constants.py
index d21cf48f..fc50c605 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -38,6 +38,7 @@
     "ro": "Romanian",
     "sh": "Serbo-Croation",
     "sr": "Serbian",
+    "sr-cyrl": "Serbian (Cyrillic)",
     "sk": "Slovak",
     "sl": "Slovenian",
     "es": "Spanish",
@@ -59,3 +60,5 @@
     "snowflake/snowflake-arctic-embed-m-v2.0": BertForSequenceClassification,
     "snowflake/snowflake-arctic-embed-l-v2.0": BertForSequenceClassification,
 }
+
+TRANSLATION_SCORE_CLASSES = ["fine", "minor", "major", "critical"]
diff --git a/src/ml_filter/__main__.py b/src/ml_filter/__main__.py
index 8207407d..d5ab1813 100644
--- a/src/ml_filter/__main__.py
+++ b/src/ml_filter/__main__.py
@@ -19,7 +19,10 @@
 from ml_filter.sample_from_hf_dataset import sample_from_hf_dataset, upload_file_to_hf
 from ml_filter.training.annotator_model_pipeline import run_annotator_training_pipeline
 from ml_filter.translation.translate import TranslationServiceType, TranslatorFactory
-from ml_filter.translation.translation_evaluation import evaluate_translations
+from ml_filter.translation.translation_evaluation import (
+    evaluate_translations,
+    save_human_eval_translation_quality_results,
+)
 from ml_filter.utils.chunk_data import chunk_jsonl
 from ml_filter.utils.manipulate_datasets import apply_score_transforms, convert_hf_dataset_to_jsonl, split_dataset
 from ml_filter.utils.manipulate_documents import merge_and_sort_jsonl_files
@@ -763,13 +766,17 @@ def _get_target_language_codes_list_helper(target_language_codes: str) -> list[s
 @click.option("--gold-path", required=True, help="Path to gold reference JSONL file")
 @click.option("--model-name", default="Unbabel/wmt22-cometkiwi-da", help="COMET model to use")
 @click.option("--languages", type=str, required=True, help="Comma-separated list of supported language codes")
-@click.option("--batch-size", help="Batch size for processing translations")
+@click.option("--batch-size", type=int, help="Batch size for processing translations")
+@click.option(
+    "--output-dir", required=True, type=click.Path(file_okay=False), help="Directory to save histogram plots."
+)
 def evaluate_translations_cli(
     data_dir: str,
     gold_path: str,
     model_name: str,
     languages: str,
     batch_size: int,
+    output_dir: str,
 ):
     """CLI entry point for evaluating translation quality."""
     evaluate_translations(
@@ -778,6 +785,31 @@ def evaluate_translations_cli(
         languages=languages.split(","),
         model_name=model_name,
         batch_size=batch_size,
+        output_dir=Path(output_dir),
     )
 
 
+@main.command(name="plot_translation_quality")
+@click.option(
+    "--data-dir",
+    required=True,
+    type=click.Path(exists=True, file_okay=False),
+    help="Directory containing translation JSON files.",
+)
+@click.option(
+    "--gt-path", required=True, type=click.Path(exists=True, dir_okay=False), help="Path to ground truth JSONL file."
+)
+@click.option("--languages", required=True, type=str, help="Comma-separated list of supported language codes.")
+@click.option(
+    "--output-dir", required=True, type=click.Path(file_okay=False), help="Directory to save histogram plots."
+)
+def plot_translation_quality_cli(data_dir: str, gt_path: str, languages: str, output_dir: str):
+    """CLI entry point to plot translation quality histograms."""
+    save_human_eval_translation_quality_results(
+        data_dir=Path(data_dir),
+        gt_path=Path(gt_path),
+        languages=[lang.strip() for lang in languages.split(",")],
+        output_dir=Path(output_dir),
+    )
diff --git a/src/ml_filter/translation/translation_evaluation.py b/src/ml_filter/translation/translation_evaluation.py
index 4a827331..0bd9143b 100644
--- a/src/ml_filter/translation/translation_evaluation.py
+++ b/src/ml_filter/translation/translation_evaluation.py
@@ -1,3 +1,4 @@
+import csv
 import json
 import logging
 import os
@@ -5,6 +6,9 @@
 
 import numpy as np
 from comet import download_model, load_from_checkpoint
+from transformers import AutoTokenizer
+
+from constants import EUROPEAN_LANGUAGES, TRANSLATION_SCORE_CLASSES
 
 logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
@@ -26,7 +30,12 @@ def _load_gold_dict(gold_path: str) -> dict[str, str]:
     return gold_dict
 
 
-def _prepare_translation_input(file_path: str, gold_dict: dict[str, str]) -> list[dict[str, str]]:
+def _prepare_translation_input(
+    file_path: str,
+    gold_dict: dict[str, str],
+    tokenizer_name_or_path: str,
+    max_tokens_per_input: int,
+) -> list[dict[str, str]]:
     """Extract source and machine-translated texts from a JSONL file.
 
     Args:
@@ -37,6 +46,10 @@ def _prepare_translation_input(file_path: str, gold_dict: dict[str, str]) -> lis
     Returns:
         A list of dictionaries containing 'src' and 'mt' keys.
     """
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    max_src_tokens = max_mt_tokens = max_tokens_per_input // 2
+
     target_texts = []
     with open(file_path, "r") as f:
         for line_num, line in enumerate(f, 1):
@@ -51,7 +64,17 @@ def _prepare_translation_input(file_path: str, gold_dict: dict[str, str]) -> lis
                     logging.warning(f"doc_id {doc_id} not found in gold references.")
                     continue
 
-                target_texts.append({"src": gold_dict[doc_id], "mt": text})
+                # Tokenize and truncate source and MT
+                src_ids = tokenizer.encode(
+                    gold_dict[doc_id], truncation=True, max_length=max_src_tokens, add_special_tokens=False
+                )
+                mt_ids = tokenizer.encode(text, truncation=True, max_length=max_mt_tokens, add_special_tokens=False)
+
+                # Decode back to text
+                src_trunc = tokenizer.decode(src_ids, skip_special_tokens=True)
+                mt_trunc = tokenizer.decode(mt_ids, skip_special_tokens=True)
+
+                target_texts.append({"src": src_trunc, "mt": mt_trunc})
             except json.JSONDecodeError as e:
                 logging.warning(f"Skipping invalid line {line_num} in {file_path}: {e}")
                 continue
@@ -63,6 +86,7 @@ def evaluate_translations(
     gold_path: str,
     languages: list[str],
     batch_size: int,
+    output_dir: Path,
     model_name: str = "Unbabel/wmt22-cometkiwi-da",
 ) -> None:
     """Evaluate translation quality for a set of files using a COMET model.
@@ -75,6 +99,8 @@ def evaluate_translations(
     """
     model_path = download_model(model_name)
     model = load_from_checkpoint(model_path)
+    tokenizer_name_or_path = model.encoder.tokenizer.name_or_path
+    max_seq_length = model.encoder.max_positions
 
     gold_dict = _load_gold_dict(gold_path)
     quality_dict = {}
@@ -83,7 +109,7 @@ def evaluate_translations(
         if filename.endswith(".jsonl"):
             file_path = os.path.join(data_dir, filename)
             parts = filename.split("_")
-            if len(parts) != 8:
+            if len(parts) != 7:
                 logging.warning(f"Skipping file with unexpected format: {file_path}")
                 continue
             lang = parts[5]
@@ -92,7 +118,12 @@ def evaluate_translations(
                 logging.info(f"Skipping file with unsupported language: {file_path}")
                 continue
 
-            target_texts = _prepare_translation_input(file_path, gold_dict)
+            target_texts = _prepare_translation_input(
+                file_path,
+                gold_dict,
+                tokenizer_name_or_path=tokenizer_name_or_path,
+                max_tokens_per_input=max_seq_length,
+            )
 
             if target_texts:
                 # TODO: Multiple GPUs handling
@@ -102,13 +133,44 @@ def evaluate_translations(
                 model_output = model.predict(target_texts, batch_size=batch_size, gpus=1, accelerator="gpu")
                 quality_dict[lang] = model_output.scores
                 logging.info(f"Processed {len(target_texts)} documents for language '{lang}' in file {file_path}")
             else:
                 logging.info(f"No valid documents for language '{lang}' in file {file_path}")
 
-    logging.info("Translation quality scores:")
-    for lang, scores in quality_dict.items():
-        logging.info(f"Mean score for {lang}: {np.mean(scores):.4f}")
+    output_path = os.path.join(output_dir, "translation_quality_results.csv")
+    _save_to_csv(quality_dict, Path(output_path))
+
+
+def _save_to_csv(quality_dict: dict[str, list[float]], output_path: Path) -> None:
+    """Save translation quality statistics to a CSV file.
+
+    Args:
+        quality_dict: Dictionary mapping language code to list of quality scores.
+        output_path: Path to save the CSV file.
+    """
+    with open(output_path, mode="w", newline="") as csv_file:
+        writer = csv.writer(csv_file)
+        writer.writerow(
+            ["language", "num_documents", "mean_score", "median_score", "q25_score", "q75_score", "q100_score"]
+        )
+
+        for lang, scores in quality_dict.items():
+            scores_np = np.array(scores)
+            writer.writerow(
+                [
+                    EUROPEAN_LANGUAGES[lang],
+                    len(scores),
+                    f"{np.mean(scores_np):.4f}",
+                    f"{np.median(scores_np):.4f}",
+                    f"{np.quantile(scores_np, 0.25):.4f}",
+                    f"{np.quantile(scores_np, 0.75):.4f}",
+                    f"{np.max(scores_np):.4f}",
+                ]
+            )
 
 
 def _plot_translation_scores_histogram_relative_to_gt(
-    id_to_translation_score: dict[str, str], id_to_gt_quality_score: dict[str, float], lang: str, output_path: str
+    id_to_translation_score: dict[str, str],
+    id_to_gt_quality_score: dict[str, float],
+    lang: str,
+    output_path: str,
+    counts: dict[float, list[int]],
 ) -> None:
     import matplotlib.pyplot as plt
     import numpy as np
@@ -118,18 +180,6 @@ def _plot_translation_scores_histogram_relative_to_gt(
     # Step 1: Get all GT scores from the GT dict (not only those that appear in translation dict)
     all_gt_scores = sorted(set(float(v) for v in id_to_gt_quality_score.values()))
 
-    # Step 2: Build empty counts for all GT scores
-    counts = {gt: [0] * len(score_classes) for gt in all_gt_scores}
-
-    # Step 3: Count translation scores
-    for sample_id, trans_score in id_to_translation_score.items():
-        if sample_id in id_to_gt_quality_score:
-            gt_score = float(id_to_gt_quality_score[sample_id])
-            trans_score = trans_score.lower()
-            if trans_score in score_classes:
-                idx = score_classes.index(trans_score)
-                counts[gt_score][idx] += 1
-
     # Step 4: Plot
     x = np.array(all_gt_scores)
     bar_width = 0.12
@@ -147,16 +197,50 @@ def _plot_translation_scores_histogram_relative_to_gt(
     ax.set_xticks(tick_positions)
     ax.set_xticklabels(tick_labels)
 
-    ax.set_xlabel("Ground Truth Quality Score")
Score") + ax.set_xlabel("Ground Truth Document Quality Score") ax.set_ylabel("Number of Translations") - ax.set_title(f"Translation Quality vs Ground Truth Quality for {lang}") - ax.legend(title="Translation Score Class") + ax.set_title(f"Translation Quality vs Ground Truth Document Quality for {EUROPEAN_LANGUAGES[lang]}") + ax.legend(title="Translation Quality") ax.grid(axis="y", alpha=0.5) plt.tight_layout() plt.savefig(output_path) plt.close() +def _compute_score_distribution( + id_to_translation_score: dict[str, str], + id_to_gt_quality_score: dict[str, float], +) -> dict[float, list[int]]: + """ + Compute a distribution matrix of translation scores per GT quality score. + + Args: + id_to_translation_score: Mapping of document ID to translation score label. + id_to_gt_quality_score: Mapping of document ID to GT quality score (float). + score_classes: List of allowed translation score class labels in desired order. + + Returns: + A dict mapping GT quality score (float) to a list of counts aligned with score_classes. + """ + + # Step 1: Get all GT scores from the GT dict (not only those that appear in translation dict) + all_gt_scores = sorted(set(float(v) for v in id_to_gt_quality_score.values())) + + # Step 2: Build empty counts for all GT scores + counts = {gt: [0] * len(TRANSLATION_SCORE_CLASSES) for gt in all_gt_scores} + + # Step 3: Count translation scores + for sample_id, trans_score in id_to_translation_score.items(): + if sample_id in id_to_gt_quality_score: + gt_score = float(id_to_gt_quality_score[sample_id]) + trans_score = trans_score.lower() + if trans_score in TRANSLATION_SCORE_CLASSES: + idx = TRANSLATION_SCORE_CLASSES.index(trans_score) + counts[gt_score][idx] += 1 + + return counts + + def _plot_translation_scores_histogram(scores: list[str], lang: str, output_path: str) -> None: """Plot a histogram of translation quality scores and save it to a file. 
@@ -177,8 +261,8 @@ def _plot_translation_scores_histogram(scores: list[str], lang: str, output_path
     values = [counts[cls] for cls in score_classes]
 
     plt.bar(score_classes, values, alpha=0.7)
-    plt.title(f"Translation Quality Scores for {lang}")
-    plt.xlabel("Translation Score")
+    plt.title(f"Translation Quality Scores for {EUROPEAN_LANGUAGES[lang]}")
+    plt.xlabel("Translation Quality")
     plt.ylabel("Frequency")
     plt.grid(axis="y", alpha=0.75)
 
@@ -190,7 +274,7 @@ def _plot_translation_scores_histogram(scores: list[str], lang: str, output_path
     plt.tight_layout()
     plt.savefig(output_path)
     plt.close()
 
 
-def plot_translation_quality_results(
+def save_human_eval_translation_quality_results(
     data_dir: Path,
     gt_path: Path,
     languages: list[str],
@@ -207,6 +291,8 @@ def plot_translation_quality_results(
         os.makedirs(output_dir)
 
     id_to_gt_quality_score = {}
+    lang_to_eval_stats = {}
+    lang_to_counts = {}
 
     with open(gt_path, "r") as f:
         for line in f:
@@ -239,9 +325,65 @@ def plot_translation_quality_results(
             )
 
             output_path = os.path.join(output_dir, f"{lang}_translation_quality_vs_gt_histogram.png")
+
+            counts = _compute_score_distribution(
+                id_to_translation_score=id_to_translation_score,
+                id_to_gt_quality_score=id_to_gt_quality_score,
+            )
+            lang_to_counts[lang] = counts
+
             _plot_translation_scores_histogram_relative_to_gt(
                 id_to_translation_score=id_to_translation_score,
                 id_to_gt_quality_score=id_to_gt_quality_score,
                 lang=lang,
                 output_path=output_path,
+                counts=counts,
             )
+            lang_to_eval_stats[lang] = {
+                "num_documents": len(id_to_translation_score),
+                "Fine": list(id_to_translation_score.values()).count("fine"),
+                "Minor": list(id_to_translation_score.values()).count("minor"),
+                "Major": list(id_to_translation_score.values()).count("major"),
+                "Critical": list(id_to_translation_score.values()).count("critical"),
+            }
+    _save_language_eval_stats(
+        lang_to_eval_stats=lang_to_eval_stats,
+        output_path=os.path.join(output_dir, "language_eval_stats.csv"),
+    )
+
+    _save_detailed_score_distribution(
+        lang_to_counts=lang_to_counts,
+        output_path=os.path.join(output_dir, "detailed_score_distribution.csv"),
+    )
+
+
+def _save_language_eval_stats(lang_to_eval_stats: dict[str, dict], output_path: str) -> None:
+    with open(output_path, mode="w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(["language", "num_documents", "Fine", "Minor", "Major", "Critical"])
+
+        for lang, stats in lang_to_eval_stats.items():
+            writer.writerow(
+                [
+                    lang,
+                    stats.get("num_documents", 0),
+                    stats.get("Fine", 0),
+                    stats.get("Minor", 0),
+                    stats.get("Major", 0),
+                    stats.get("Critical", 0),
+                ]
+            )
+
+
+def _save_detailed_score_distribution(lang_to_counts: dict[str, dict[float, list[int]]], output_path: str) -> None:
+    with open(output_path, mode="w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(["language", "gt_score"] + TRANSLATION_SCORE_CLASSES)
+
+        for lang, counts in lang_to_counts.items():
+            for gt_score, count_list in sorted(counts.items()):
+                writer.writerow([lang, f"{gt_score:.2f}"] + count_list)
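
Reviewer note: the COMET path added by this series can be exercised directly from Python. The sketch below is hypothetical, not part of the patches: the paths, the batch size, and the file name are assumptions inferred from the parsing code, whose only real constraints are that translation files are JSONL objects with "document_id" and "text" fields, named with seven underscore-separated parts and the language code at index 5, and that a GPU is available (model.predict is called with accelerator="gpu").

# Hypothetical driver for the evaluate_translations entry point (reviewer sketch).
# Assumptions: each line of a translation file is {"document_id": ..., "text": ...},
# the gold file uses the same schema, and a name like
# "web_corpus_v1_sample_translated_de_out.jsonl" splits into seven "_"-parts
# with the language code "de" at index 5.
from pathlib import Path

from ml_filter.translation.translation_evaluation import evaluate_translations

output_dir = Path("results")
output_dir.mkdir(parents=True, exist_ok=True)  # the function writes its CSV here but does not create the directory

evaluate_translations(
    data_dir="data/translations",            # e.g. contains web_corpus_v1_sample_translated_de_out.jsonl
    gold_path="data/gold_references.jsonl",  # JSONL of {"document_id", "text"} reference texts
    languages=["de"],
    batch_size=16,                           # forwarded to COMET's model.predict
    output_dir=output_dir,                   # -> results/translation_quality_results.csv
)

The same call is exposed on the command line by PATCH 1/3 and 3/3, presumably as: python -m ml_filter evaluate_translations --data-dir data/translations --gold-path data/gold_references.jsonl --languages de --batch-size 16 --output-dir results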
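
The human-eval plotting path (added in PATCH 2/3, renamed in PATCH 3/3) consumes a different input format: plain .json files, each holding a full list of human judgments. A sketch under the same caveats; the eight-part file name and all paths are assumptions inferred from the parsing code, not a documented contract.

# Hypothetical driver for the plotting path (reviewer sketch). Assumed inputs:
# the ground-truth file is JSONL with {"document_id", "score"}; each translation
# file is a single .json list of {"document_id", "translation_score"} where
# translation_score is one of "fine"/"minor"/"major"/"critical"; file names
# split into eight "_"-parts with the language code at index 5, e.g.
# "web_corpus_v1_sample_translated_de_human_eval.json".
from pathlib import Path

from ml_filter.translation.translation_evaluation import (
    save_human_eval_translation_quality_results,
)

save_human_eval_translation_quality_results(
    data_dir=Path("data/human_eval"),
    gt_path=Path("data/gold_scores.jsonl"),
    languages=["de"],
    output_dir=Path("plots"),  # per-language histograms plus language_eval_stats.csv and detailed_score_distribution.csv
)

The equivalent CLI command registered in PATCH 3/3 is plot_translation_quality, with --data-dir, --gt-path, --languages, and --output-dir options mirroring the arguments above.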