From a1ddeb61c44cdd3c855c74d63adae78625a7743a Mon Sep 17 00:00:00 2001
From: Arjun Krishnakumar
Date: Tue, 28 Apr 2026 22:03:41 +0200
Subject: [PATCH 1/2] fix: forward warmup_ratio correctly to TrainingArguments

warmup_ratio (float 0.03) was passed to warmup_steps (int), truncating to 0
and silently disabling warmup entirely. Add a warmup_steps config field
(default 0) and forward both fields to TrainingArguments so each reaches the
correct parameter.
---
 src/post_training/config.py         | 1 +
 src/post_training/methods/common.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/post_training/config.py b/src/post_training/config.py
index a0bde16..f36e00a 100644
--- a/src/post_training/config.py
+++ b/src/post_training/config.py
@@ -62,6 +62,7 @@ class TrainingConfig:
     effective_batch_size: int = 512
     per_device_train_batch_size: int = 4
     warmup_ratio: float = 0.03
+    warmup_steps: int = 0
     lr_scheduler_type: str = "cosine_with_min_lr"
     lr_scheduler_kwargs: LRSchedulerKwargs = field(default_factory=LRSchedulerKwargs)
     adam_beta1: float = 0.9
diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py
index 4cf0db4..bef90ea 100644
--- a/src/post_training/methods/common.py
+++ b/src/post_training/methods/common.py
@@ -102,7 +102,8 @@ def build_common_training_kwargs(
         weight_decay=t.weight_decay,
         adam_epsilon=t.adam_epsilon,
         gradient_accumulation_steps=grad_accum,
-        warmup_steps=t.warmup_ratio,
+        warmup_steps=t.warmup_steps,
+        warmup_ratio=t.warmup_ratio,
         lr_scheduler_type=t.lr_scheduler_type,
         lr_scheduler_kwargs={
             k: v for k, v in dataclasses.asdict(t.lr_scheduler_kwargs).items() if v is not None

From 396671419f387dfe3a533b769bc3591b83e1a5f8 Mon Sep 17 00:00:00 2001
From: Arjun Krishnakumar
Date: Thu, 30 Apr 2026 14:25:14 +0200
Subject: [PATCH 2/2] refactor: consolidate warmup_ratio and warmup_steps into a single float field

Drop warmup_ratio; warmup_steps is now a float where values in [0, 1) are
treated as a ratio and values >= 1 as absolute steps, matching HF's
TrainingArguments behaviour. Add a backward-compat shim that auto-migrates
warmup_ratio in old YAMLs with a deprecation warning.
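
The intended interpretation of the consolidated field is roughly equivalent to
the sketch below (illustrative only; resolve_warmup_steps and total_steps are
names made up for this message, not part of the series):

    import math

    def resolve_warmup_steps(warmup_steps: float, total_steps: int) -> int:
        # Hypothetical helper: values in [0, 1) are a ratio of the total number
        # of optimizer steps, values >= 1 are an absolute step count, and 0
        # disables warmup.
        if 0 <= warmup_steps < 1:
            return math.ceil(warmup_steps * total_steps)
        return int(warmup_steps)

    resolve_warmup_steps(0.03, 1000)  # -> 30 (ratio form)
    resolve_warmup_steps(500, 1000)   # -> 500 (absolute form)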
---
 README.md                             |  2 +-
 configs/trl/dpo.yaml                  |  2 +-
 configs/trl/sft.yaml                  |  2 +-
 src/post_training/config.py           | 28 +++++++++++++++++++++++++--
 src/post_training/methods/common.py   |  1 -
 src/post_training/utils/guardrails.py |  2 +-
 6 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 6ec9ede..fdb17fd 100644
--- a/README.md
+++ b/README.md
@@ -408,7 +408,7 @@ training:
   learning_rate: 2.0e-5
   effective_batch_size: 32 # per_device * grad_accum * world_size
   per_device_train_batch_size: 8
-  warmup_ratio: 0.03
+  warmup_steps: 0.03
   lr_scheduler_type: "cosine_with_min_lr"
   lr_scheduler_kwargs:
     min_lr_rate: 0.1
diff --git a/configs/trl/dpo.yaml b/configs/trl/dpo.yaml
index b388e0d..dde3263 100644
--- a/configs/trl/dpo.yaml
+++ b/configs/trl/dpo.yaml
@@ -35,7 +35,7 @@ training:
   learning_rate: 5.0e-7
   effective_batch_size: 4
   per_device_train_batch_size: 1
-  warmup_ratio: 0.1
+  warmup_steps: 0.1
   lr_scheduler_type: "cosine_with_min_lr"
   lr_scheduler_kwargs:
     min_lr_rate: 0.1
diff --git a/configs/trl/sft.yaml b/configs/trl/sft.yaml
index 5829762..32a7891 100644
--- a/configs/trl/sft.yaml
+++ b/configs/trl/sft.yaml
@@ -45,7 +45,7 @@ training:
   learning_rate: 2.0e-5
   effective_batch_size: 32 # per_device * grad_accum * world_size
   per_device_train_batch_size: 8
-  warmup_ratio: 0.03
+  warmup_steps: 0.03
   lr_scheduler_type: "cosine_with_min_lr"
   lr_scheduler_kwargs:
     min_lr_rate: 0.1
diff --git a/src/post_training/config.py b/src/post_training/config.py
index f36e00a..83e8dfb 100644
--- a/src/post_training/config.py
+++ b/src/post_training/config.py
@@ -6,6 +6,7 @@
 
 from __future__ import annotations
 
+import logging
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -13,6 +14,8 @@
 import yaml
 from omegaconf import MISSING, DictConfig, OmegaConf
 
+logger = logging.getLogger(__name__)
+
 # ---------------------------------------------------------------------------
 # Sub-configs
 # ---------------------------------------------------------------------------
@@ -61,8 +64,16 @@ class TrainingConfig:
     learning_rate: float = 2.0e-5
     effective_batch_size: int = 512
     per_device_train_batch_size: int = 4
-    warmup_ratio: float = 0.03
-    warmup_steps: int = 0
+    warmup_steps: float = field(
+        default=0.0,
+        metadata={
+            "help": (
+                "Linear warmup duration. Values in [0, 1) are interpreted as a "
+                "ratio of total training steps; values >= 1 are interpreted as an "
+                "absolute number of steps; 0 disables warmup."
+            )
+        },
+    )
     lr_scheduler_type: str = "cosine_with_min_lr"
     lr_scheduler_kwargs: LRSchedulerKwargs = field(default_factory=LRSchedulerKwargs)
     adam_beta1: float = 0.9
@@ -269,6 +280,19 @@ def load(
         """
         schema = OmegaConf.structured(cls)
         file_cfg = OmegaConf.load(yaml_path)
+
+        # Migrate deprecated training.warmup_ratio -> training.warmup_steps
+        file_dict = OmegaConf.to_container(file_cfg, resolve=False)
+        if isinstance(file_dict, dict):
+            training_dict = file_dict.get("training", {})
+            if isinstance(training_dict, dict) and "warmup_ratio" in training_dict:
+                logger.warning(
+                    "training.warmup_ratio is deprecated; use training.warmup_steps "
+                    "(values < 1 are interpreted as a ratio). Auto-migrating."
+                )
+                training_dict.setdefault("warmup_steps", training_dict.pop("warmup_ratio"))
+        file_cfg = OmegaConf.create(file_dict)
+
         merged: DictConfig = OmegaConf.merge(schema, file_cfg)
 
         if cli_overrides:
diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py
index bef90ea..af995cc 100644
--- a/src/post_training/methods/common.py
+++ b/src/post_training/methods/common.py
@@ -103,7 +103,6 @@ def build_common_training_kwargs(
         adam_epsilon=t.adam_epsilon,
         gradient_accumulation_steps=grad_accum,
         warmup_steps=t.warmup_steps,
-        warmup_ratio=t.warmup_ratio,
         lr_scheduler_type=t.lr_scheduler_type,
         lr_scheduler_kwargs={
             k: v for k, v in dataclasses.asdict(t.lr_scheduler_kwargs).items() if v is not None
diff --git a/src/post_training/utils/guardrails.py b/src/post_training/utils/guardrails.py
index 08db38f..d54be25 100644
--- a/src/post_training/utils/guardrails.py
+++ b/src/post_training/utils/guardrails.py
@@ -207,7 +207,7 @@ def run_guardrails(config: PostTrainingConfig, run_dir: Path, tokenize_only: boo
     min_lr = config.training.lr_scheduler_kwargs.min_lr_rate
     lr_sched_str = lr_sched if min_lr is None else f"{lr_sched} (min_lr_rate={min_lr})"
     _row("LR scheduler", lr_sched_str)
-    _row("Warmup ratio", str(config.training.warmup_ratio))
+    _row("Warmup steps", str(config.training.warmup_steps))
     batch_line, _ = _batch_summary(config, total_gpus)
     _row("Batch sizes", batch_line)
     _row("Grad checkpoint", str(config.training.gradient_checkpointing))
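
A quick way to exercise the migration shim by hand (illustrative sketch, not
part of the series; it assumes the import path post_training.config, the
PostTrainingConfig.load(yaml_path) entry point shown above, and that a YAML
carrying only the deprecated key merges cleanly against the schema defaults):

    import logging
    import tempfile

    from post_training.config import PostTrainingConfig

    logging.basicConfig(level=logging.WARNING)

    # Old-style config that still uses the deprecated key.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write("training:\n  warmup_ratio: 0.03\n")

    cfg = PostTrainingConfig.load(f.name)  # logs the deprecation warning
    print(cfg.training.warmup_steps)       # 0.03, i.e. a 3% warmup ratio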