diff --git a/CHANGELOG.md b/CHANGELOG.md index 1746bda..3ea5fdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Configs: `configs/train/vlm_debug_mot.toml` (1-GPU smoke) and `configs/train/vlm_7b_mot.toml` (4-GPU 7B). - `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning. - `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched. +- **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves). + - `kempnerforge/config/checkpoint.py`: new `DynamicCheckpointWindow` dataclass (`start: int = 0`, `stop: int = 512`, `strategy: str = "power2"`), new `CheckpointConfig.dyn_ckpt_window: DynamicCheckpointWindow | None = None`, `should_save(step)`, and `is_dynamic_milestone(step)`. Ships with `"power2"` registered by default; new strategies plug in via the registry without touching `CheckpointConfig`. + - `kempnerforge/config/registry.py`: `register_dyn_ckpt_strategy(name)` / `get_dyn_ckpt_strategy(name)` / `list_dyn_ckpt_strategies()` — a strategy is any `Callable[[DynamicCheckpointWindow, int], bool]` and registers via `@registry.register_dyn_ckpt_strategy("name")`. + - Milestone-aware retention: `CheckpointManager._cleanup` never prunes a step where the configured dynamic strategy fired, so `keep_last_n` rotates only the later interval checkpoints. `keep_last_n <= 0` keeps everything (the previous `keep_last_n >= 1` requirement is relaxed). + - `scripts/train.py`: the save gate now calls `config.checkpoint.should_save(step)`. + - Tests: `tests/unit/test_config.py` (defaults, power2 firing, offset-based `start > 0`, validation, unknown-strategy rejection, `is_dynamic_milestone`), `tests/unit/test_checkpoint.py::TestCheckpointRetention::test_cleanup_protects_dynamic_milestones`. ### Changed - `docs/getting-started/install.md` Prerequisites: documents `.python-version` and uv's auto-fetch behavior. diff --git a/docs/checkpointing/index.md b/docs/checkpointing/index.md index ed52024..adf4ac7 100644 --- a/docs/checkpointing/index.md +++ b/docs/checkpointing/index.md @@ -51,4 +51,4 @@ contains two kinds of state: [HF conversion](hf-conversion.md). - **Config knobs**: [Configuration § CheckpointConfig](../configuration/config-sections.md) - (search for `interval`, `async_mode`, `keep_last_n`, `load_path`). + (search for `interval`, `dyn_ckpt_window`, `async_mode`, `keep_last_n`, `load_path`). diff --git a/docs/configuration/config-sections.md b/docs/configuration/config-sections.md index f0ba1d5..13aa339 100644 --- a/docs/configuration/config-sections.md +++ b/docs/configuration/config-sections.md @@ -236,13 +236,62 @@ DCP-based checkpointing. | Field | Type | Default | Purpose | |-------|------|---------|---------| | `dir` | `str` | `"checkpoints"` | root directory for `step_N/` + `latest` symlink | -| `interval` | `int` | `1000` | save every N steps | +| `interval` | `int` | `1000` | save every N steps (outside any `dyn_ckpt_window`) | +| `dyn_ckpt_window` | `DynamicCheckpointWindow \| None` | `None` | optional dense save phase — see [`[checkpoint.dyn_ckpt_window]`](#checkpointdyn_ckpt_window--dynamiccheckpointwindow) | | `async_mode` | `"disabled" \| "async" \| "async_with_pinned_mem"` | `"disabled"` | DCP async-save mode | -| `keep_last_n` | `int` | `3` | retain the most recent N checkpoints | +| `keep_last_n` | `int` | `3` | retain the most recent N checkpoints (`<= 0` keeps all); steps saved by `dyn_ckpt_window` are always kept | | `load_path` | `str \| None` | `None` | explicit resume path (overrides `latest` symlink) | | `export_dtype` | `"float32" \| "bfloat16"` | `"bfloat16"` | dtype for HF exports via `scripts/convert_checkpoint.py` | | `exclude_from_loading` | `list[str]` | `[]` | FQN prefixes to skip on load (e.g. to reinit a head) | +### `[checkpoint.dyn_ckpt_window]` — `DynamicCheckpointWindow` + +Opt-in dense save phase. Inside `[start, stop]` a registered strategy decides +which steps to save; outside the window the regular `interval` cadence applies. +Steps saved by the strategy are exempt from `keep_last_n` retention, so a +finite `keep_last_n` rotates only the later interval checkpoints — the dense +window checkpoints are never pruned. Useful for inspecting early-training +dynamics at fine granularity. + +| Field | Type | Default | Purpose | +|-------|------|---------|---------| +| `start` | `int` | `0` | first step of the dense window; `0` captures the initial weights before any training step | +| `stop` | `int` | `512` | last step of the dense window | +| `strategy` | `str` | `"power2"` | name of a registered strategy (see below) | + +Example TOML: + +```toml +# Off by default — interval cadence only: +[checkpoint] +interval = 1000 +keep_last_n = 3 + +# Opt in with defaults (start=0, stop=512, strategy="power2"): +[checkpoint.dyn_ckpt_window] + +# Customize: +[checkpoint.dyn_ckpt_window] +start = 0 +stop = 1024 +strategy = "power2" +``` + +#### Dyn_ckpt strategies + +A strategy is a `Callable[[DynamicCheckpointWindow, int], bool]` registered via +`@registry.register_dyn_ckpt_strategy("name")` (see +[`kempnerforge/config/registry.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/registry.py)). +The strategy receives the window and the current training step and returns +`True` iff the step should be saved. The registry must be populated before +config load — either by `kempnerforge` itself (the default `"power2"` is +registered in `kempnerforge/config/checkpoint.py`) or by an explicit `import` +of the user's module. + +| Name | Behavior | +|------|----------| +| `"power2"` (default) | save at `start` and at every `start + 2^k` while `<= stop` (offset-based powers of two; tight near `start`, doubling thereafter) | + ## `[metrics]` — `MetricsConfig` Logging cadence and backend toggles. diff --git a/docs/configuration/validation-rules.md b/docs/configuration/validation-rules.md index 24d385f..ea6d6cb 100644 --- a/docs/configuration/validation-rules.md +++ b/docs/configuration/validation-rules.md @@ -120,7 +120,12 @@ File: [`kempnerforge/config/checkpoint.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/checkpoint.py). - `interval > 0`. -- `keep_last_n ≥ 1`. +- If `dyn_ckpt_window` is set: + - `dyn_ckpt_window.start >= 0`. + - `dyn_ckpt_window.stop >= dyn_ckpt_window.start`. + - `dyn_ckpt_window.strategy` must name a registered dyn_ckpt strategy + (`"power2"` ships by default; see + [`kempnerforge/config/registry.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/registry.py)). ### `MetricsConfig` diff --git a/kempnerforge/checkpoint/manager.py b/kempnerforge/checkpoint/manager.py index 10bc07a..35ecc93 100644 --- a/kempnerforge/checkpoint/manager.py +++ b/kempnerforge/checkpoint/manager.py @@ -608,11 +608,12 @@ def _resolve_dcp_load_dir(self, resolved: Path, path: str | None) -> Path: def _cleanup(self) -> None: """Remove old checkpoints beyond the retention limit. - Two directories are never removed regardless of retention: the - current ``latest`` target and the in-flight async checkpoint - (``_pending_finalize``). Pruning either would let a crash strand - resume with no loadable checkpoint — the exact failure this fix - exists to prevent. + Never removed regardless of retention: the current ``latest`` target + and the in-flight async checkpoint (``_pending_finalize``) — pruning + either would let a crash strand resume with no loadable checkpoint — + and any dynamic-window milestone (``CheckpointConfig.is_dynamic_milestone``), + so the dense early-dynamics checkpoints survive even with a finite + ``keep_last_n``. """ keep = self.config.keep_last_n if keep <= 0: @@ -631,6 +632,15 @@ def _cleanup(self) -> None: if self._pending_finalize is not None: protected.add(self._pending_finalize[1].resolve()) + # Dynamic-window milestones (whatever steps the configured + # dyn_ckpt_window strategy fires on) are the dense early-dynamics + # checkpoints the window exists to capture, so never prune them; + # retention then applies only to the later interval checkpoints. + # No-op when no dyn_ckpt_window is configured. + for d in ckpt_dirs: + if self.config.is_dynamic_milestone(int(d.name.split("_")[1])): + protected.add(d.resolve()) + # Remove oldest beyond retention, but never a protected dir. to_remove = ckpt_dirs[:-keep] if len(ckpt_dirs) > keep else [] for d in to_remove: diff --git a/kempnerforge/config/checkpoint.py b/kempnerforge/config/checkpoint.py index be2bf28..727b5b8 100644 --- a/kempnerforge/config/checkpoint.py +++ b/kempnerforge/config/checkpoint.py @@ -6,6 +6,8 @@ from enum import StrEnum from typing import Literal +from kempnerforge.config.registry import registry + class AsyncCheckpointMode(StrEnum): disabled = "disabled" @@ -13,14 +15,58 @@ class AsyncCheckpointMode(StrEnum): async_pinned = "async_with_pinned_mem" +@dataclass +class DynamicCheckpointWindow: + """A bounded step range with a registered checkpoint strategy. + + Inside ``[start, stop]`` the strategy decides which steps to save, and every + such step is exempt from ``CheckpointConfig.keep_last_n`` retention. Outside + the window the regular ``CheckpointConfig.interval`` cadence applies. + + ``"power2"`` (default) saves at ``start`` and at every ``start + 2^k`` while + ``<= stop`` -- tight at the start of the window, doubling thereafter. New + strategies register via ``@registry.register_dyn_ckpt_strategy(name)`` and + become selectable by setting ``strategy``. + """ + + start: int = 0 # 0 = capture initial weights before any training step + stop: int = 512 + strategy: str = "power2" + + def __post_init__(self) -> None: + if self.start < 0: + raise ValueError("dyn_ckpt_window.start must be >= 0") + if self.stop < self.start: + raise ValueError("dyn_ckpt_window.stop must be >= start") + known = registry.list_dyn_ckpt_strategies() + if self.strategy not in known: + raise ValueError( + f"unknown dyn_ckpt_window.strategy {self.strategy!r}; registered: {known}" + ) + + def is_milestone(self, step: int) -> bool: + """True iff the configured strategy fires at ``step``.""" + return registry.get_dyn_ckpt_strategy(self.strategy)(self, step) + + +@registry.register_dyn_ckpt_strategy("power2") +def _power2_strategy(window: DynamicCheckpointWindow, step: int) -> bool: + """Save at ``start`` and at every ``start + 2^k`` while ``<= stop``.""" + if step < window.start or step > window.stop: + return False + offset = step - window.start + return offset == 0 or (offset & (offset - 1)) == 0 + + @dataclass class CheckpointConfig: """Checkpointing settings.""" dir: str = "checkpoints" - interval: int = 1000 # Save every N steps + interval: int = 1000 # save every N steps; outside any dyn_ckpt_window + dyn_ckpt_window: DynamicCheckpointWindow | None = None # opt-in dense window async_mode: AsyncCheckpointMode = AsyncCheckpointMode.disabled - keep_last_n: int = 3 # Number of checkpoints to retain + keep_last_n: int = 3 # recent ckpts kept (<=0 keeps all); dynamic milestones always kept load_path: str | None = None # Path to load from (for resumption) export_dtype: Literal["float32", "bfloat16"] = "bfloat16" exclude_from_loading: list[str] = field(default_factory=list) @@ -33,5 +79,25 @@ class CheckpointConfig: def __post_init__(self) -> None: if self.interval <= 0: raise ValueError("interval must be positive") - if self.keep_last_n < 1: - raise ValueError("keep_last_n must be >= 1") + + def should_save(self, step: int) -> bool: + """Whether to write a checkpoint at ``step``. + + Inside ``dyn_ckpt_window``: the registered strategy decides (default + ``"power2"`` saves at ``start`` and each ``start + 2^k`` while + ``<= stop``). Outside the window: every ``interval`` steps. Dynamic + milestones are exempt from ``keep_last_n`` (see + ``CheckpointManager._cleanup``). + """ + w = self.dyn_ckpt_window + if w is not None and w.start <= step <= w.stop: + return w.is_milestone(step) + return step % self.interval == 0 + + def is_dynamic_milestone(self, step: int) -> bool: + """True if ``step`` is a milestone of the configured ``dyn_ckpt_window``. + + ``CheckpointManager._cleanup`` excludes these from ``keep_last_n`` so + the dense early-window checkpoints survive a finite retention. + """ + return self.dyn_ckpt_window is not None and self.dyn_ckpt_window.is_milestone(step) diff --git a/kempnerforge/config/registry.py b/kempnerforge/config/registry.py index eab389e..5089f33 100644 --- a/kempnerforge/config/registry.py +++ b/kempnerforge/config/registry.py @@ -186,6 +186,31 @@ def get_adapter(self, name: str) -> Callable: def list_adapters(self) -> list[str]: return self.list("adapter") + def register_dyn_ckpt_strategy(self, name: str) -> Callable: + """Decorator to register a dynamic-checkpointing-window strategy. + + Strategies take ``(window: DynamicCheckpointWindow, step: int)`` and + return ``True`` iff ``step`` should be saved by the dynamic window. + ``DynamicCheckpointWindow.is_milestone`` looks the strategy up here by + name; ``CheckpointManager._cleanup`` exempts every step the strategy + fires on from ``keep_last_n`` retention. + + Ships with ``"power2"`` registered by default in + ``kempnerforge/config/checkpoint.py``. + """ + + def decorator(fn: Callable) -> Callable: + self.register("dyn_ckpt_strategy", name, fn) + return fn + + return decorator + + def get_dyn_ckpt_strategy(self, name: str) -> Callable: + return self.get("dyn_ckpt_strategy", name) + + def list_dyn_ckpt_strategies(self) -> list[str]: + return self.list("dyn_ckpt_strategy") + # Global registry instance registry = Registry() diff --git a/kempnerforge/config/schema.py b/kempnerforge/config/schema.py index 3202ce4..09be75f 100644 --- a/kempnerforge/config/schema.py +++ b/kempnerforge/config/schema.py @@ -1,7 +1,11 @@ """Backward-compatible re-exports. Import from here or from submodules directly.""" from kempnerforge.config.adapter import AdapterConfig # noqa: F401 -from kempnerforge.config.checkpoint import AsyncCheckpointMode, CheckpointConfig # noqa: F401 +from kempnerforge.config.checkpoint import ( # noqa: F401 + AsyncCheckpointMode, + CheckpointConfig, + DynamicCheckpointWindow, +) from kempnerforge.config.data import DataConfig, DatasetSource, TrainingPhase # noqa: F401 from kempnerforge.config.distributed import DistributedConfig, PipelineSchedule # noqa: F401 from kempnerforge.config.eval import EvalConfig # noqa: F401 diff --git a/scripts/train.py b/scripts/train.py index 1090331..be8c6a4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -626,6 +626,29 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: if prof is not None: prof.start() + # Capture the initial weights (step 0) on fresh start when the + # dyn_ckpt_window covers step 0 -- the per-step save gate only runs + # after a training step completes, so without this the random init + # is never persisted. Skipped on resume (step > 0). + if step == 0 and config.checkpoint.is_dynamic_milestone(0): + init_extra: dict = {"phase_idx": current_phase_idx} if active_phases else {} + if config.metrics.wandb_run_id: + init_extra["wandb_run_id"] = config.metrics.wandb_run_id + if is_vlm: + assert vlm_cfg is not None + valid_modules = set(vlm_cfg.module_patterns.keys()) + init_extra["vlm_freeze"] = canonical_freeze_meta( + effective_freeze(0, vlm_cfg.freeze, vlm_cfg.freeze_schedule, valid_modules) + ) + ckpt_mgr.save( + step=0, + tokens_seen=0, + scheduler=scheduler, + dataloader=dataloader, + extra=init_extra, + ) + hook_runner.on_checkpoint_save(0, config.checkpoint.dir) + while step < tc.max_steps: # Refresh data iterator at start / epoch boundary if dataloader is not None and data_iter is None: @@ -963,7 +986,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: ckpt_extra["vlm_freeze"] = canonical_freeze_meta( effective_freeze(step, vlm_cfg.freeze, vlm_cfg.freeze_schedule, valid_modules) ) - if step % config.checkpoint.interval == 0: + if config.checkpoint.should_save(step): ckpt_mgr.save( step=step, tokens_seen=tokens_seen, diff --git a/tests/unit/test_checkpoint.py b/tests/unit/test_checkpoint.py index fce4be2..15a3bd3 100644 --- a/tests/unit/test_checkpoint.py +++ b/tests/unit/test_checkpoint.py @@ -16,7 +16,11 @@ restore_train_state, set_rng_state, ) -from kempnerforge.config.schema import AsyncCheckpointMode, CheckpointConfig +from kempnerforge.config.schema import ( + AsyncCheckpointMode, + CheckpointConfig, + DynamicCheckpointWindow, +) # --------------------------------------------------------------------------- # RNG state capture/restore @@ -190,6 +194,33 @@ def test_cleanup_preserves_when_under_limit(self, tmp_path): remaining = sorted(d.name for d in tmp_path.iterdir() if d.is_dir()) assert remaining == ["step_10", "step_20"] + def test_cleanup_protects_dynamic_milestones(self, tmp_path): + """With a dyn_ckpt_window configured, the strategy's milestone steps + survive keep_last_n; retention applies only to the later interval + checkpoints.""" + from kempnerforge.checkpoint.manager import CheckpointManager + + config = CheckpointConfig( + dir=str(tmp_path), + interval=1000, + keep_last_n=2, + dyn_ckpt_window=DynamicCheckpointWindow(start=0, stop=512), + ) + model = torch.nn.Linear(4, 4) + opt = torch.optim.SGD(model.parameters(), lr=0.1) + mgr = CheckpointManager(config, model, opt) + + milestones = [0, 1, 2, 4, 256, 512] + intervals = [1000, 2000, 3000] + for i in milestones + intervals: + (tmp_path / f"step_{i}").mkdir() + + mgr._cleanup() + + remaining = sorted(int(d.name.split("_")[1]) for d in tmp_path.iterdir() if d.is_dir()) + # All dynamic milestones kept; only the last keep_last_n=2 interval ckpts kept. + assert remaining == [0, 1, 2, 4, 256, 512, 2000, 3000] + # --------------------------------------------------------------------------- # AsyncCheckpointer diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 0b5a3f0..d550487 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -14,6 +14,7 @@ CheckpointConfig, DataConfig, DistributedConfig, + DynamicCheckpointWindow, EvalConfig, MetricsConfig, OptimizerConfig, @@ -251,6 +252,73 @@ def test_rejects_zero_interval(self): with pytest.raises(ValueError, match="interval must be positive"): CheckpointConfig(interval=0) + def test_dyn_ckpt_window_defaults_to_none(self): + # Opt-in: no dyn_ckpt_window means pure interval cadence. + assert CheckpointConfig().dyn_ckpt_window is None + + def test_interval_only_saves_on_multiples(self): + c = CheckpointConfig(interval=100) + assert c.should_save(0) + assert c.should_save(100) + assert not c.should_save(150) + + def test_dyn_ckpt_window_power2_saves_powers_of_two(self): + c = CheckpointConfig( + interval=1000, + dyn_ckpt_window=DynamicCheckpointWindow(start=0, stop=512), + ) + for step in (0, 1, 2, 4, 8, 256, 512): + assert c.should_save(step), f"expected save at {step}" + for step in (3, 5, 100, 300): + assert not c.should_save(step), f"unexpected save at {step}" + + def test_dyn_ckpt_window_uses_interval_outside_window(self): + c = CheckpointConfig( + interval=1000, + dyn_ckpt_window=DynamicCheckpointWindow(start=0, stop=512), + ) + # Beyond stop, interval applies again. + assert c.should_save(1000) + assert c.should_save(2000) + assert not c.should_save(1500) + + def test_dyn_ckpt_window_with_start_offset(self): + # Offset-based power2: save at start and at start + 2^k while <= stop. + c = CheckpointConfig( + interval=10_000, # large so the interval rule doesn't shadow the test + dyn_ckpt_window=DynamicCheckpointWindow(start=1000, stop=1300), + ) + for step in (1000, 1001, 1002, 1004, 1008, 1016, 1128, 1256): + assert c.should_save(step), f"expected save at {step}" + for step in (999, 1003, 1005, 1100, 1300): # 1300-1000=300, not a power of two + assert not c.should_save(step), f"unexpected save at {step}" + + def test_is_dynamic_milestone(self): + c = CheckpointConfig(dyn_ckpt_window=DynamicCheckpointWindow(start=0, stop=512)) + assert c.is_dynamic_milestone(0) + assert c.is_dynamic_milestone(256) + assert not c.is_dynamic_milestone(300) + assert not c.is_dynamic_milestone(1000) + # No window configured -> never a milestone, even at power-of-two steps. + assert not CheckpointConfig().is_dynamic_milestone(4) + + def test_keep_last_n_allows_keep_all(self): + # <= 0 means retain all checkpoints (CheckpointManager._cleanup early-returns) + assert CheckpointConfig(keep_last_n=0).keep_last_n == 0 + assert CheckpointConfig(keep_last_n=-1).keep_last_n == -1 + + def test_dyn_ckpt_window_rejects_negative_start(self): + with pytest.raises(ValueError, match="dyn_ckpt_window.start must be >= 0"): + DynamicCheckpointWindow(start=-1) + + def test_dyn_ckpt_window_rejects_stop_below_start(self): + with pytest.raises(ValueError, match="dyn_ckpt_window.stop must be >= start"): + DynamicCheckpointWindow(start=100, stop=50) + + def test_dyn_ckpt_window_rejects_unknown_strategy(self): + with pytest.raises(ValueError, match="unknown dyn_ckpt_window.strategy"): + DynamicCheckpointWindow(strategy="not_registered") + # --------------------------------------------------------------------------- # ProfilingConfig