Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Configs: `configs/train/vlm_debug_mot.toml` (1-GPU smoke) and `configs/train/vlm_7b_mot.toml` (4-GPU 7B).
- `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning.
- `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched.
- **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves).
- `kempnerforge/config/checkpoint.py`: new `DynamicCheckpointWindow` dataclass (`start: int = 0`, `stop: int = 512`, `strategy: str = "power2"`), new `CheckpointConfig.dyn_ckpt_window: DynamicCheckpointWindow | None = None`, `should_save(step)`, and `is_dynamic_milestone(step)`. Ships with `"power2"` registered by default; new strategies plug in via the registry without touching `CheckpointConfig`.
- `kempnerforge/config/registry.py`: `register_dyn_ckpt_strategy(name)` / `get_dyn_ckpt_strategy(name)` / `list_dyn_ckpt_strategies()` — a strategy is any `Callable[[DynamicCheckpointWindow, int], bool]` and registers via `@registry.register_dyn_ckpt_strategy("name")`.
- Milestone-aware retention: `CheckpointManager._cleanup` never prunes a step where the configured dynamic strategy fired, so `keep_last_n` rotates only the later interval checkpoints. `keep_last_n <= 0` keeps everything (the previous `keep_last_n >= 1` requirement is relaxed).
- `scripts/train.py`: the save gate now calls `config.checkpoint.should_save(step)`.
- Tests: `tests/unit/test_config.py` (defaults, power2 firing, offset-based `start > 0`, validation, unknown-strategy rejection, `is_dynamic_milestone`), `tests/unit/test_checkpoint.py::TestCheckpointRetention::test_cleanup_protects_dynamic_milestones`.

### Changed
- `docs/getting-started/install.md` Prerequisites: documents `.python-version` and uv's auto-fetch behavior.
Expand Down
2 changes: 1 addition & 1 deletion docs/checkpointing/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ contains two kinds of state:
[HF conversion](hf-conversion.md).
- **Config knobs**:
[Configuration § CheckpointConfig](../configuration/config-sections.md)
(search for `interval`, `async_mode`, `keep_last_n`, `load_path`).
(search for `interval`, `dyn_ckpt_window`, `async_mode`, `keep_last_n`, `load_path`).
53 changes: 51 additions & 2 deletions docs/configuration/config-sections.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,62 @@ DCP-based checkpointing.
| Field | Type | Default | Purpose |
|-------|------|---------|---------|
| `dir` | `str` | `"checkpoints"` | root directory for `step_N/` + `latest` symlink |
| `interval` | `int` | `1000` | save every N steps |
| `interval` | `int` | `1000` | save every N steps (outside any `dyn_ckpt_window`) |
| `dyn_ckpt_window` | `DynamicCheckpointWindow \| None` | `None` | optional dense save phase — see [`[checkpoint.dyn_ckpt_window]`](#checkpointdyn_ckpt_window--dynamiccheckpointwindow) |
| `async_mode` | `"disabled" \| "async" \| "async_with_pinned_mem"` | `"disabled"` | DCP async-save mode |
| `keep_last_n` | `int` | `3` | retain the most recent N checkpoints |
| `keep_last_n` | `int` | `3` | retain the most recent N checkpoints (`<= 0` keeps all); steps saved by `dyn_ckpt_window` are always kept |
| `load_path` | `str \| None` | `None` | explicit resume path (overrides `latest` symlink) |
| `export_dtype` | `"float32" \| "bfloat16"` | `"bfloat16"` | dtype for HF exports via `scripts/convert_checkpoint.py` |
| `exclude_from_loading` | `list[str]` | `[]` | FQN prefixes to skip on load (e.g. to reinit a head) |

### `[checkpoint.dyn_ckpt_window]` — `DynamicCheckpointWindow`

Opt-in dense save phase. Inside `[start, stop]` a registered strategy decides
which steps to save; outside the window the regular `interval` cadence applies.
Steps saved by the strategy are exempt from `keep_last_n` retention, so a
finite `keep_last_n` rotates only the later interval checkpoints — the dense
window checkpoints are never pruned. Useful for inspecting early-training
dynamics at fine granularity.

| Field | Type | Default | Purpose |
|-------|------|---------|---------|
| `start` | `int` | `0` | first step of the dense window; `0` captures the initial weights before any training step |
| `stop` | `int` | `512` | last step of the dense window |
| `strategy` | `str` | `"power2"` | name of a registered strategy (see below) |

Example TOML:

```toml
# Off by default — interval cadence only:
[checkpoint]
interval = 1000
keep_last_n = 3

# Opt in with defaults (start=0, stop=512, strategy="power2"):
[checkpoint.dyn_ckpt_window]

# Customize:
[checkpoint.dyn_ckpt_window]
start = 0
stop = 1024
strategy = "power2"
```

#### Dyn_ckpt strategies

A strategy is a `Callable[[DynamicCheckpointWindow, int], bool]` registered via
`@registry.register_dyn_ckpt_strategy("name")` (see
[`kempnerforge/config/registry.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/registry.py)).
The strategy receives the window and the current training step and returns
`True` iff the step should be saved. The registry must be populated before
config load — either by `kempnerforge` itself (the default `"power2"` is
registered in `kempnerforge/config/checkpoint.py`) or by an explicit `import`
of the user's module.

| Name | Behavior |
|------|----------|
| `"power2"` (default) | save at `start` and at every `start + 2^k` while `<= stop` (offset-based powers of two; tight near `start`, doubling thereafter) |

## `[metrics]` — `MetricsConfig`

Logging cadence and backend toggles.
Expand Down
7 changes: 6 additions & 1 deletion docs/configuration/validation-rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ File:
[`kempnerforge/config/checkpoint.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/checkpoint.py).

- `interval > 0`.
- `keep_last_n ≥ 1`.
- If `dyn_ckpt_window` is set:
- `dyn_ckpt_window.start >= 0`.
- `dyn_ckpt_window.stop >= dyn_ckpt_window.start`.
- `dyn_ckpt_window.strategy` must name a registered dyn_ckpt strategy
(`"power2"` ships by default; see
[`kempnerforge/config/registry.py`](https://github.com/KempnerInstitute/KempnerForge/blob/main/kempnerforge/config/registry.py)).

### `MetricsConfig`

Expand Down
20 changes: 15 additions & 5 deletions kempnerforge/checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,11 +608,12 @@ def _resolve_dcp_load_dir(self, resolved: Path, path: str | None) -> Path:
def _cleanup(self) -> None:
"""Remove old checkpoints beyond the retention limit.

Two directories are never removed regardless of retention: the
current ``latest`` target and the in-flight async checkpoint
(``_pending_finalize``). Pruning either would let a crash strand
resume with no loadable checkpoint — the exact failure this fix
exists to prevent.
Never removed regardless of retention: the current ``latest`` target
and the in-flight async checkpoint (``_pending_finalize``) — pruning
either would let a crash strand resume with no loadable checkpoint —
and any dynamic-window milestone (``CheckpointConfig.is_dynamic_milestone``),
so the dense early-dynamics checkpoints survive even with a finite
``keep_last_n``.
"""
keep = self.config.keep_last_n
if keep <= 0:
Expand All @@ -631,6 +632,15 @@ def _cleanup(self) -> None:
if self._pending_finalize is not None:
protected.add(self._pending_finalize[1].resolve())

# Dynamic-window milestones (whatever steps the configured
# dyn_ckpt_window strategy fires on) are the dense early-dynamics
# checkpoints the window exists to capture, so never prune them;
# retention then applies only to the later interval checkpoints.
# No-op when no dyn_ckpt_window is configured.
for d in ckpt_dirs:
if self.config.is_dynamic_milestone(int(d.name.split("_")[1])):
protected.add(d.resolve())

# Remove oldest beyond retention, but never a protected dir.
to_remove = ckpt_dirs[:-keep] if len(ckpt_dirs) > keep else []
for d in to_remove:
Expand Down
74 changes: 70 additions & 4 deletions kempnerforge/config/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,67 @@
from enum import StrEnum
from typing import Literal

from kempnerforge.config.registry import registry


class AsyncCheckpointMode(StrEnum):
disabled = "disabled"
async_ = "async"
async_pinned = "async_with_pinned_mem"


@dataclass
class DynamicCheckpointWindow:
"""A bounded step range with a registered checkpoint strategy.

Inside ``[start, stop]`` the strategy decides which steps to save, and every
such step is exempt from ``CheckpointConfig.keep_last_n`` retention. Outside
the window the regular ``CheckpointConfig.interval`` cadence applies.

``"power2"`` (default) saves at ``start`` and at every ``start + 2^k`` while
``<= stop`` -- tight at the start of the window, doubling thereafter. New
strategies register via ``@registry.register_dyn_ckpt_strategy(name)`` and
become selectable by setting ``strategy``.
"""

start: int = 0 # 0 = capture initial weights before any training step
stop: int = 512
strategy: str = "power2"

def __post_init__(self) -> None:
if self.start < 0:
raise ValueError("dyn_ckpt_window.start must be >= 0")
if self.stop < self.start:
raise ValueError("dyn_ckpt_window.stop must be >= start")
known = registry.list_dyn_ckpt_strategies()
if self.strategy not in known:
raise ValueError(
f"unknown dyn_ckpt_window.strategy {self.strategy!r}; registered: {known}"
)

def is_milestone(self, step: int) -> bool:
"""True iff the configured strategy fires at ``step``."""
return registry.get_dyn_ckpt_strategy(self.strategy)(self, step)


@registry.register_dyn_ckpt_strategy("power2")
def _power2_strategy(window: DynamicCheckpointWindow, step: int) -> bool:
"""Save at ``start`` and at every ``start + 2^k`` while ``<= stop``."""
if step < window.start or step > window.stop:
return False
offset = step - window.start
return offset == 0 or (offset & (offset - 1)) == 0


@dataclass
class CheckpointConfig:
"""Checkpointing settings."""

dir: str = "checkpoints"
interval: int = 1000 # Save every N steps
interval: int = 1000 # save every N steps; outside any dyn_ckpt_window
dyn_ckpt_window: DynamicCheckpointWindow | None = None # opt-in dense window
async_mode: AsyncCheckpointMode = AsyncCheckpointMode.disabled
keep_last_n: int = 3 # Number of checkpoints to retain
keep_last_n: int = 3 # recent ckpts kept (<=0 keeps all); dynamic milestones always kept
load_path: str | None = None # Path to load from (for resumption)
export_dtype: Literal["float32", "bfloat16"] = "bfloat16"
exclude_from_loading: list[str] = field(default_factory=list)
Expand All @@ -33,5 +79,25 @@ class CheckpointConfig:
def __post_init__(self) -> None:
if self.interval <= 0:
raise ValueError("interval must be positive")
if self.keep_last_n < 1:
raise ValueError("keep_last_n must be >= 1")

def should_save(self, step: int) -> bool:
"""Whether to write a checkpoint at ``step``.

Inside ``dyn_ckpt_window``: the registered strategy decides (default
``"power2"`` saves at ``start`` and each ``start + 2^k`` while
``<= stop``). Outside the window: every ``interval`` steps. Dynamic
milestones are exempt from ``keep_last_n`` (see
``CheckpointManager._cleanup``).
"""
w = self.dyn_ckpt_window
if w is not None and w.start <= step <= w.stop:
return w.is_milestone(step)
return step % self.interval == 0

def is_dynamic_milestone(self, step: int) -> bool:
"""True if ``step`` is a milestone of the configured ``dyn_ckpt_window``.

``CheckpointManager._cleanup`` excludes these from ``keep_last_n`` so
the dense early-window checkpoints survive a finite retention.
"""
return self.dyn_ckpt_window is not None and self.dyn_ckpt_window.is_milestone(step)
25 changes: 25 additions & 0 deletions kempnerforge/config/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,31 @@ def get_adapter(self, name: str) -> Callable:
def list_adapters(self) -> list[str]:
return self.list("adapter")

def register_dyn_ckpt_strategy(self, name: str) -> Callable:
"""Decorator to register a dynamic-checkpointing-window strategy.

Strategies take ``(window: DynamicCheckpointWindow, step: int)`` and
return ``True`` iff ``step`` should be saved by the dynamic window.
``DynamicCheckpointWindow.is_milestone`` looks the strategy up here by
name; ``CheckpointManager._cleanup`` exempts every step the strategy
fires on from ``keep_last_n`` retention.

Ships with ``"power2"`` registered by default in
``kempnerforge/config/checkpoint.py``.
"""

def decorator(fn: Callable) -> Callable:
self.register("dyn_ckpt_strategy", name, fn)
return fn

return decorator

def get_dyn_ckpt_strategy(self, name: str) -> Callable:
return self.get("dyn_ckpt_strategy", name)

def list_dyn_ckpt_strategies(self) -> list[str]:
return self.list("dyn_ckpt_strategy")


# Global registry instance
registry = Registry()
6 changes: 5 additions & 1 deletion kempnerforge/config/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Backward-compatible re-exports. Import from here or from submodules directly."""

from kempnerforge.config.adapter import AdapterConfig # noqa: F401
from kempnerforge.config.checkpoint import AsyncCheckpointMode, CheckpointConfig # noqa: F401
from kempnerforge.config.checkpoint import ( # noqa: F401
AsyncCheckpointMode,
CheckpointConfig,
DynamicCheckpointWindow,
)
from kempnerforge.config.data import DataConfig, DatasetSource, TrainingPhase # noqa: F401
from kempnerforge.config.distributed import DistributedConfig, PipelineSchedule # noqa: F401
from kempnerforge.config.eval import EvalConfig # noqa: F401
Expand Down
25 changes: 24 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,29 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if prof is not None:
prof.start()

# Capture the initial weights (step 0) on fresh start when the
# dyn_ckpt_window covers step 0 -- the per-step save gate only runs
# after a training step completes, so without this the random init
# is never persisted. Skipped on resume (step > 0).
if step == 0 and config.checkpoint.is_dynamic_milestone(0):
init_extra: dict = {"phase_idx": current_phase_idx} if active_phases else {}
if config.metrics.wandb_run_id:
init_extra["wandb_run_id"] = config.metrics.wandb_run_id
if is_vlm:
assert vlm_cfg is not None
valid_modules = set(vlm_cfg.module_patterns.keys())
init_extra["vlm_freeze"] = canonical_freeze_meta(
effective_freeze(0, vlm_cfg.freeze, vlm_cfg.freeze_schedule, valid_modules)
)
ckpt_mgr.save(
step=0,
tokens_seen=0,
scheduler=scheduler,
dataloader=dataloader,
extra=init_extra,
)
hook_runner.on_checkpoint_save(0, config.checkpoint.dir)

while step < tc.max_steps:
# Refresh data iterator at start / epoch boundary
if dataloader is not None and data_iter is None:
Expand Down Expand Up @@ -963,7 +986,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
ckpt_extra["vlm_freeze"] = canonical_freeze_meta(
effective_freeze(step, vlm_cfg.freeze, vlm_cfg.freeze_schedule, valid_modules)
)
if step % config.checkpoint.interval == 0:
if config.checkpoint.should_save(step):
ckpt_mgr.save(
step=step,
tokens_seen=tokens_seen,
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
restore_train_state,
set_rng_state,
)
from kempnerforge.config.schema import AsyncCheckpointMode, CheckpointConfig
from kempnerforge.config.schema import (
AsyncCheckpointMode,
CheckpointConfig,
DynamicCheckpointWindow,
)

# ---------------------------------------------------------------------------
# RNG state capture/restore
Expand Down Expand Up @@ -190,6 +194,33 @@ def test_cleanup_preserves_when_under_limit(self, tmp_path):
remaining = sorted(d.name for d in tmp_path.iterdir() if d.is_dir())
assert remaining == ["step_10", "step_20"]

def test_cleanup_protects_dynamic_milestones(self, tmp_path):
"""With a dyn_ckpt_window configured, the strategy's milestone steps
survive keep_last_n; retention applies only to the later interval
checkpoints."""
from kempnerforge.checkpoint.manager import CheckpointManager

config = CheckpointConfig(
dir=str(tmp_path),
interval=1000,
keep_last_n=2,
dyn_ckpt_window=DynamicCheckpointWindow(start=0, stop=512),
)
model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
mgr = CheckpointManager(config, model, opt)

milestones = [0, 1, 2, 4, 256, 512]
intervals = [1000, 2000, 3000]
for i in milestones + intervals:
(tmp_path / f"step_{i}").mkdir()

mgr._cleanup()

remaining = sorted(int(d.name.split("_")[1]) for d in tmp_path.iterdir() if d.is_dir())
# All dynamic milestones kept; only the last keep_last_n=2 interval ckpts kept.
assert remaining == [0, 1, 2, 4, 256, 512, 2000, 3000]


# ---------------------------------------------------------------------------
# AsyncCheckpointer
Expand Down
Loading
Loading