From 2600a0ee6eb920f162d29bafd1be253a8f8ff000 Mon Sep 17 00:00:00 2001 From: amazloumi Date: Tue, 26 May 2026 19:47:50 -0400 Subject: [PATCH] Add configurable data_seed independent of training seed --- kempnerforge/config/training.py | 15 ++++++++++++++- scripts/train.py | 10 +++++----- tests/unit/test_config.py | 12 ++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/kempnerforge/config/training.py b/kempnerforge/config/training.py index c19ccff..bb621ff 100644 --- a/kempnerforge/config/training.py +++ b/kempnerforge/config/training.py @@ -25,7 +25,8 @@ class TrainConfig: max_steps: int = 100000 grad_accum_steps: int = 1 grad_clip_norm: float = 1.0 - seed: int = 42 + seed: int = 42 # Seeds parameter init, dropout, and (by default) data order + data_seed: int | None = None # Overrides data/batch order only; None falls back to seed compile_model: bool = True mixed_precision: Literal["bf16", "fp16", "fp32", "fp8"] = "bf16" activation_checkpointing: ActivationCheckpointing = ActivationCheckpointing.none @@ -67,3 +68,15 @@ def param_dtype(self) -> torch.dtype: def is_fp8(self) -> bool: """Whether FP8 mixed precision is enabled.""" return self.mixed_precision == "fp8" + + @property + def effective_data_seed(self) -> int: + """Seed for data shuffling / batch composition. + + Falls back to ``seed`` when ``data_seed`` is unset, so existing configs + reproduce their current trajectory. Kept independent from ``seed`` + (parameter init) so stability studies can vary batch order while holding + initialization fixed. Must stay identical across data-parallel ranks so + the global shuffle is consistent before rank partitioning. + """ + return self.data_seed if self.data_seed is not None else self.seed diff --git a/scripts/train.py b/scripts/train.py index 1090331..e0a8dde 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -319,7 +319,7 @@ def main() -> None: _pad_id = _tok.eos_token_id if _tok.eos_token_id is not None else 0 collator = VLMCollator(pad_id=int(_pad_id), max_text_len=vlm_cfg.max_text_len) sampler = DistributedSampler( - dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=tc.seed + dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=tc.effective_data_seed ) dataloader = StatefulDataLoader( dataset, @@ -373,7 +373,7 @@ def main() -> None: num_replicas=dp_size, rank=dp_rank, shuffle=True, - seed=tc.seed, + seed=tc.effective_data_seed, temperature=config.data.mix_temperature, ) dataloader = StatefulDataLoader( @@ -401,7 +401,7 @@ def main() -> None: num_replicas=dp_size, rank=dp_rank, shuffle=True, - seed=tc.seed, + seed=tc.effective_data_seed, ) dataloader = StatefulDataLoader( dataset, @@ -429,7 +429,7 @@ def main() -> None: dataset_config=config.data.hf_dataset_config, rank=dp_rank, world_size=dp_size, - seed=tc.seed, + seed=tc.effective_data_seed, pack_sequences=config.data.pack_sequences, ) dataloader = TorchDataLoader( @@ -463,7 +463,7 @@ def main() -> None: num_replicas=dp_size, rank=dp_rank, shuffle=True, - seed=tc.seed, + seed=tc.effective_data_seed, ) dataloader = StatefulDataLoader( dataset, diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 0b5a3f0..749b455 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -169,6 +169,18 @@ def test_rejects_zero_grad_accum(self): with pytest.raises(ValueError, match="grad_accum_steps must be positive"): TrainConfig(grad_accum_steps=0) + def test_data_seed_defaults_to_none(self): + assert TrainConfig().data_seed is None + + def test_effective_data_seed_falls_back_to_seed(self): + t = TrainConfig(seed=123) + assert t.effective_data_seed == 123 + + def test_effective_data_seed_overrides_when_set(self): + t = TrainConfig(seed=123, data_seed=456) + assert t.effective_data_seed == 456 + assert t.seed == 123 # init seed stays independent + # --------------------------------------------------------------------------- # OptimizerConfig