Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion kempnerforge/config/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class TrainConfig:
max_steps: int = 100000
grad_accum_steps: int = 1
grad_clip_norm: float = 1.0
seed: int = 42
seed: int = 42 # Seeds parameter init, dropout, and (by default) data order
data_seed: int | None = None # Overrides data/batch order only; None falls back to seed
compile_model: bool = True
mixed_precision: Literal["bf16", "fp16", "fp32", "fp8"] = "bf16"
activation_checkpointing: ActivationCheckpointing = ActivationCheckpointing.none
Expand Down Expand Up @@ -67,3 +68,15 @@ def param_dtype(self) -> torch.dtype:
def is_fp8(self) -> bool:
    """Report whether ``mixed_precision`` is set to ``"fp8"``."""
    precision = self.mixed_precision
    return precision == "fp8"

@property
def effective_data_seed(self) -> int:
    """Return the seed used for data shuffling / batch composition.

    An explicitly configured ``data_seed`` wins; when it is ``None`` the
    general ``seed`` is used instead, so configs that never set
    ``data_seed`` keep reproducing their current trajectory. Keeping this
    separate from ``seed`` (parameter init) lets stability studies vary
    batch order while initialization stays fixed. The value must stay
    identical across data-parallel ranks so the global shuffle is
    consistent before rank partitioning.
    """
    override = self.data_seed
    # Compare against None rather than truthiness: a data_seed of 0 is a
    # legitimate override and must not fall through to ``seed``.
    if override is None:
        return self.seed
    return override
10 changes: 5 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def main() -> None:
_pad_id = _tok.eos_token_id if _tok.eos_token_id is not None else 0
collator = VLMCollator(pad_id=int(_pad_id), max_text_len=vlm_cfg.max_text_len)
sampler = DistributedSampler(
dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=tc.seed
dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=tc.effective_data_seed
)
dataloader = StatefulDataLoader(
dataset,
Expand Down Expand Up @@ -373,7 +373,7 @@ def main() -> None:
num_replicas=dp_size,
rank=dp_rank,
shuffle=True,
seed=tc.seed,
seed=tc.effective_data_seed,
temperature=config.data.mix_temperature,
)
dataloader = StatefulDataLoader(
Expand Down Expand Up @@ -401,7 +401,7 @@ def main() -> None:
num_replicas=dp_size,
rank=dp_rank,
shuffle=True,
seed=tc.seed,
seed=tc.effective_data_seed,
)
dataloader = StatefulDataLoader(
dataset,
Expand Down Expand Up @@ -429,7 +429,7 @@ def main() -> None:
dataset_config=config.data.hf_dataset_config,
rank=dp_rank,
world_size=dp_size,
seed=tc.seed,
seed=tc.effective_data_seed,
pack_sequences=config.data.pack_sequences,
)
dataloader = TorchDataLoader(
Expand Down Expand Up @@ -463,7 +463,7 @@ def main() -> None:
num_replicas=dp_size,
rank=dp_rank,
shuffle=True,
seed=tc.seed,
seed=tc.effective_data_seed,
)
dataloader = StatefulDataLoader(
dataset,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ def test_rejects_zero_grad_accum(self):
with pytest.raises(ValueError, match="grad_accum_steps must be positive"):
TrainConfig(grad_accum_steps=0)

def test_data_seed_defaults_to_none(self):
    """A default-constructed TrainConfig leaves ``data_seed`` unset."""
    cfg = TrainConfig()
    assert cfg.data_seed is None

def test_effective_data_seed_falls_back_to_seed(self):
    """With no ``data_seed`` given, the effective data seed equals ``seed``."""
    assert TrainConfig(seed=123).effective_data_seed == 123

def test_effective_data_seed_overrides_when_set(self):
    """An explicit ``data_seed`` takes precedence without touching ``seed``."""
    cfg = TrainConfig(seed=123, data_seed=456)
    data_order_seed = cfg.effective_data_seed
    assert data_order_seed == 456
    # The init seed stays independent of the data-order override.
    init_seed = cfg.seed
    assert init_seed == 123


# ---------------------------------------------------------------------------
# OptimizerConfig
Expand Down
Loading