Merged
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -65,6 +65,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `docs/claude-ready.md` first-run flow: `/kempnerforge:install-and-verify` runs before `/kempnerforge:cluster-config`.
 - `README.md` and `kempnerforge/README.md`: list `install-and-verify` in the skill catalog and drop the hardcoded skill count.

+### Fixed
+- **Resume silently reset AdamW optimizer momentum.** `CheckpointManager` round-tripped optimizer state through raw `optimizer.state_dict()` / `optimizer.load_state_dict()`. On resume the optimizer is freshly built, so its `state_dict()` is empty — `dcp.load` then had no `exp_avg` / `exp_avg_sq` tensors to fill, and the moments were silently dropped, resetting Adam momentum to zero at every resume point. Model weights, scheduler, dataloader position, and RNG all restored correctly; only the optimizer moments were lost, so resumed runs were not bit-exact.
+- `kempnerforge/checkpoint/manager.py`: save and load now go through DCP's `get_model_state_dict` / `get_optimizer_state_dict` / `set_model_state_dict` / `set_optimizer_state_dict`. The getters build a load template with the optimizer moments allocated in the correct FSDP/DTensor layout, so `dcp.load` repopulates them; the setters write the loaded values back into the live optimizer.
+- `docs/checkpointing/dcp-model.md`: updated the save/load snippets and the "shape to fill" explanation to the DCP-aware helpers.
+- Tests (fail on the pre-fix code, pass after): `tests/integration/test_checkpoint_roundtrip.py::test_manager_restores_optimizer_moments_single_gpu` and `tests/distributed/test_checkpoint.py::test_resume_restores_optimizer_moments` assert `exp_avg` / `exp_avg_sq` are restored bit-exactly into a *fresh* optimizer (single-GPU + distributed); `tests/e2e/test_training_e2e.py::test_resume_determinism_single_gpu` / `test_resume_determinism_2gpu_fsdp` assert end-to-end bit-exact loss across an interrupt-and-resume on a learnable dataset.
+- **On-disk format note:** optimizer state is now keyed by parameter fully-qualified name rather than positional index. Checkpoints written before this fix will not restore optimizer state on resume (training continues with a fresh optimizer); model state is unaffected.
+
 ## [0.1.0] — 2026-04-16

 Initial public release.
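The on-disk format note can be pictured with plain dictionaries. This is only a sketch: the parameter names and the `rekey_by_name` helper are hypothetical, and real optimizer state holds tensors rather than lists. It assumes the positional layout that a raw `optimizer.state_dict()` produces (integer keys under `"state"`) and shows why a name-keyed lookup finds nothing in such a checkpoint.

```python
# Hypothetical parameter names, in the order the optimizer saw them.
param_names = ["tok_emb.weight", "layers.0.attn.wq.weight"]

# Positional layout: per-parameter state is keyed by integer index.
positional = {
    "state": {0: {"exp_avg": [0.1]}, 1: {"exp_avg": [0.2]}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}


def rekey_by_name(state_dict, names):
    """Re-key per-parameter state from positional index to parameter name."""
    return {names[idx]: st for idx, st in state_dict["state"].items()}


by_name = rekey_by_name(positional, param_names)

# A loader that looks entries up by name finds them only in the re-keyed form.
print("tok_emb.weight" in positional["state"])
print("tok_emb.weight" in by_name)
```

Keying by name is also what lets state follow a parameter when the parameter order or sharding changes, which an index cannot express.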
47 changes: 36 additions & 11 deletions docs/checkpointing/dcp-model.md
@@ -15,25 +15,33 @@ in `kempnerforge/checkpoint/manager.py`.

 ```python
 dcp_state = {
-    "model": self.model.state_dict(),
-    "optimizer": self.optimizer.state_dict(),
+    "model": get_model_state_dict(self.model),
+    "optimizer": get_optimizer_state_dict(self.model, self.optimizer),
 }
 self._async_ckpt.save(
     dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group
 )
 ```

+`get_model_state_dict` / `get_optimizer_state_dict` are the DCP-aware
+helpers from `torch.distributed.checkpoint.state_dict` — **not** raw
+`model.state_dict()` / `optimizer.state_dict()`. They key the
+optimizer state by parameter fully-qualified name (not positional
+index) and keep the FSDP/DTensor sharding intact, which is what makes
+load (and resharding) line up by name. See [Loading](#loading) for why
+the raw calls break resume.
+
 Two top-level keys — `"model"` and `"optimizer"`. DCP introspects the
 state dicts, finds `DTensor` / `ShardedTensor` parameters, and
 writes each shard to disk with enough metadata to reassemble.

 What's in each:

-- **`model.state_dict()`** — every parameter and buffer: weights,
+- **model state** — every parameter and buffer: weights,
 RMSNorm scales, learned RoPE frequencies (if present), and any
 registered buffer. Under FSDP2 these are `DTensor`s; under TP
 they're `DTensor`s on a 2D mesh. DCP handles both.
-- **`optimizer.state_dict()`** — AdamW's `exp_avg`, `exp_avg_sq`,
+- **optimizer state** — AdamW's `exp_avg`, `exp_avg_sq`,
 `step` counters; Lion's `exp_avg`; Muon's internal state. All
 per-parameter tensors live on the same device and parallelism
 shape as the parameter, so DCP saves them symmetrically.
@@ -148,16 +156,33 @@ it through on every save/load.
 Load is the mirror image of save:

 ```python
-dcp_state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}
+dcp_state = {
+    "model": get_model_state_dict(self.model),
+    "optimizer": get_optimizer_state_dict(self.model, self.optimizer),
+}
 dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group)
-self.model.load_state_dict(dcp_state["model"])
-self.optimizer.load_state_dict(dcp_state["optimizer"])
+set_model_state_dict(self.model, dcp_state["model"])
+set_optimizer_state_dict(self.model, self.optimizer, optim_state_dict=dcp_state["optimizer"])
 ```

-The first `state_dict()` call gives DCP the **shape** to fill — it
-doesn't contain the saved data, just the tensor metadata DCP needs
-to know what to load where. `dcp.load` mutates the tensors in place
-with the loaded values. Then `load_state_dict` consumes them.
+The getter call gives DCP the **shape** to fill — it doesn't contain
+the saved data, just the tensor metadata (and, crucially, the
+*allocated* optimizer moment tensors) DCP needs to know what to load
+where. `dcp.load` mutates those tensors in place with the loaded
+values; the setters then write them back into the live model and
+optimizer.
+
+> **Why the DCP-aware helpers, not `optimizer.state_dict()`?** On
+> resume the optimizer is freshly built, so `optimizer.state_dict()`
+> has *no* `exp_avg` / `exp_avg_sq` tensors yet — AdamW creates the
+> per-parameter state lazily on the first `.step()`. Passing that
+> empty dict as the load template gives `dcp.load` nothing to fill, so
+> the saved moments are silently dropped and Adam momentum resets to
+> zero at every resume (a non-bit-exact resume). `get_optimizer_state_dict`
+> allocates the moment tensors up front in the right sharded layout, so
+> `dcp.load` repopulates them. The model side would work with either
+> call — its parameters are always allocated — but we use the matching
+> getter/setter for symmetry.

 Loading with a different GPU count triggers DCP's automatic
 resharding — see [Resharding](resharding.md).
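The "shape to fill" behaviour described in the doc change can be modelled without PyTorch. The sketch below is a toy, not DCP itself: `fill_template` is a hypothetical stand-in that, like the in-place load described above, only populates keys already present in the template. Floats stand in for tensors, and the key names are made up for illustration.

```python
def fill_template(template, saved):
    """Toy in-place load: fill only the keys the template already has."""
    for key in template:
        if key in saved:
            template[key] = saved[key]
    return template


# What the checkpoint holds (hypothetical keys and values).
saved = {"w.exp_avg": 0.5, "w.exp_avg_sq": 0.25}

# A never-stepped optimizer exposes no moment entries, so nothing gets filled.
empty_template = fill_template({}, saved)

# A template with the moments pre-allocated is repopulated from the checkpoint.
allocated_template = fill_template({"w.exp_avg": 0.0, "w.exp_avg_sq": 0.0}, saved)

print(empty_template)
print(allocated_template)
```

Note that the empty case raises no error, which matches the "silently dropped" failure mode the fix addresses.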
43 changes: 32 additions & 11 deletions kempnerforge/checkpoint/manager.py
@@ -20,6 +20,12 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+)

 from kempnerforge.checkpoint.async_save import AsyncCheckpointer
 from kempnerforge.checkpoint.state import build_train_state, restore_train_state
@@ -240,10 +246,15 @@ def save(
         dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir
         dcp_dir.mkdir(parents=True, exist_ok=True)

-        # Save distributed state (model + optimizer) via DCP
+        # Save distributed state (model + optimizer) via DCP. Use the DCP-aware
+        # state-dict helpers, NOT raw optimizer.state_dict(): on load they build a
+        # template with the optimizer moment tensors allocated so dcp.load can
+        # repopulate them. A freshly-constructed optimizer's raw state_dict() is
+        # empty, so the moments would be silently dropped on resume (Adam momentum
+        # resets to zero -> non-bit-exact resume; see manager load()).
         dcp_state = {
-            "model": self.model.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
+            "model": get_model_state_dict(self.model),
+            "optimizer": get_optimizer_state_dict(self.model, self.optimizer),
         }
         # Dispatch the DCP save. For async modes this returns immediately but
         # FIRST awaits the previous in-flight flush, so any deferred
@@ -457,19 +468,29 @@ def load(
         # Load distributed state via DCP
         dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir

+        load_model = exclude_keys is None or "model" not in exclude_keys
+        load_optim = exclude_keys is None or "optimizer" not in exclude_keys
+
+        # Build properly-structured (FSDP-sharded, moment-allocated) templates via
+        # the DCP-aware getters so dcp.load can repopulate them, then write them
+        # back with the setters. Raw optimizer.state_dict() would be empty on a
+        # fresh optimizer, so dcp.load would find no moment tensors to fill and the
+        # AdamW momentum would silently reset to zero on resume.
         dcp_state: dict[str, Any] = {}
-        if exclude_keys is None or "model" not in exclude_keys:
-            dcp_state["model"] = self.model.state_dict()
-        if exclude_keys is None or "optimizer" not in exclude_keys:
-            dcp_state["optimizer"] = self.optimizer.state_dict()
+        if load_model:
+            dcp_state["model"] = get_model_state_dict(self.model)
+        if load_optim:
+            dcp_state["optimizer"] = get_optimizer_state_dict(self.model, self.optimizer)

         if dcp_state:
             dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group)

-        if "model" in dcp_state:
-            self.model.load_state_dict(dcp_state["model"])
-        if "optimizer" in dcp_state:
-            self.optimizer.load_state_dict(dcp_state["optimizer"])
+        if load_model:
+            set_model_state_dict(self.model, dcp_state["model"])
+        if load_optim:
+            set_optimizer_state_dict(
+                self.model, self.optimizer, optim_state_dict=dcp_state["optimizer"]
+            )

         # Load non-distributed state. On NFS/Lustre, independent stat()
         # calls can disagree briefly across ranks; if some ranks enter
68 changes: 68 additions & 0 deletions tests/distributed/test_checkpoint.py
@@ -105,6 +105,74 @@ def test_save_load_fsdp(self, distributed_env, shared_tmp_dir):
             f"Restored output differs: max diff={(ref_out - restored_out).abs().max().item()}"
         )

+    def test_resume_restores_optimizer_moments(self, distributed_env, shared_tmp_dir):
+        """Optimizer moments must be restored into a FRESH optimizer on resume.
+
+        Regression for the bug where the manager used raw ``optimizer.state_dict()``
+        with DCP: a freshly-constructed optimizer has empty state, so the load
+        template had no moment tensors and ``dcp.load`` silently dropped them ->
+        AdamW ``exp_avg``/``exp_avg_sq`` reset to zero on every resume. Loading
+        into the *same* already-stepped optimizer (as test_save_load_fsdp does)
+        hides this, so here we load into a fresh optimizer like a real resume.
+        """
+        from kempnerforge.config.schema import OptimizerConfig
+
+        mesh = distributed_env
+        ckpt_dir = shared_tmp_dir
+
+        def snapshot_moments(optimizer):
+            """Per-rank (sharded) (exp_avg, exp_avg_sq) clones, in param order."""
+            out = []
+            for group in optimizer.param_groups:
+                for p in group["params"]:
+                    st = optimizer.state.get(p, {})
+                    if "exp_avg" in st:
+                        out.append(
+                            (st["exp_avg"].detach().clone(), st["exp_avg_sq"].detach().clone())
+                        )
+            return out
+
+        # Build, step a few times so the moments are non-trivial, then save.
+        torch.manual_seed(42)
+        model = Transformer(SMALL_CONFIG).cuda()
+        apply_fsdp2(model, mesh)
+        opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False))
+        for _ in range(3):
+            tokens = torch.randint(0, 512, (2, 32), device="cuda")
+            model(tokens).sum().backward()
+            opt.step()
+            opt.zero_grad()
+
+        ref_moments = snapshot_moments(opt)
+        assert ref_moments, "optimizer had no moment state to test"
+        assert any(ea.abs().sum().item() > 0 for ea, _ in ref_moments), (
+            "reference moments are all zero"
+        )
+
+        CheckpointManager(CheckpointConfig(dir=ckpt_dir, keep_last_n=2), model, opt).save(
+            step=3, tokens_seen=192
+        )
+
+        # FRESH model + FRESH optimizer (never stepped -> empty optimizer state),
+        # exactly the resume scenario. Then load.
+        torch.manual_seed(42)
+        model2 = Transformer(SMALL_CONFIG).cuda()
+        apply_fsdp2(model2, mesh)
+        opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False))
+        assert not snapshot_moments(opt2), "fresh optimizer should have empty state before load"
+
+        CheckpointManager(CheckpointConfig(dir=ckpt_dir, keep_last_n=2), model2, opt2).load()
+
+        loaded_moments = snapshot_moments(opt2)
+        assert loaded_moments, (
+            "optimizer moments were not restored (state still empty after load) "
+            "-- AdamW momentum reset on resume"
+        )
+        assert len(loaded_moments) == len(ref_moments)
+        for i, ((rea, rev), (lea, lev)) in enumerate(zip(ref_moments, loaded_moments, strict=True)):
+            assert torch.equal(rea, lea), f"param {i}: exp_avg not restored bit-exactly"
+            assert torch.equal(rev, lev), f"param {i}: exp_avg_sq not restored bit-exactly"
+
     def test_latest_symlink(self, distributed_env, shared_tmp_dir):
         """The 'latest' symlink should point to the most recent checkpoint."""
         mesh = distributed_env
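Why dropped moments break bit-exact resume can be seen with a scalar Adam update in plain Python. This is a sketch under stated assumptions, not the project's optimizer: weight decay is omitted, the alternating `TARGETS` are made up so that the gradient keeps changing, and the buggy resume is assumed to lose the step counter along with the moments. The data position (global step) is restored correctly in both resumed runs, mirroring the bug report.

```python
import math

TARGETS = (4.0, -1.0)  # hypothetical alternating targets so the gradient keeps changing


def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One scalar Adam update with bias correction (weight decay omitted)."""
    t += 1
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps), m, v, t


def run(param, m, v, t, start, steps):
    """Run `steps` updates from global step `start`; returns (param, m, v, t)."""
    for step in range(start, start + steps):
        grad = param - TARGETS[step % 2]  # gradient of 0.5 * (param - target)^2
        param, m, v, t = adam_step(param, grad, m, v, t)
    return param, m, v, t


reference = run(0.0, 0.0, 0.0, 0, start=0, steps=20)
checkpoint = run(0.0, 0.0, 0.0, 0, start=0, steps=10)

# Correct resume: parameter, moments, and step count all restored.
resumed_ok = run(*checkpoint, start=10, steps=10)
# Buggy resume: parameter restored, optimizer state starts from scratch.
resumed_bad = run(checkpoint[0], 0.0, 0.0, 0, start=10, steps=10)

print(resumed_ok == reference)
print(resumed_bad[0] == reference[0])
```

The first comparison holds exactly because the resumed run replays the same floating-point operations; the second fails because the first post-resume update is computed from a fresh moment estimate.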
103 changes: 103 additions & 0 deletions tests/e2e/test_training_e2e.py
@@ -110,6 +110,11 @@ def _parse_last_loss(output: str) -> float | None:
     return float(matches[-1]) if matches else None


+def _parse_losses_by_step(output: str) -> dict[int, str]:
+    """Map step -> logged loss string (exact text, for bit-identical comparison)."""
+    return {int(s): loss for s, loss in re.findall(r"\[step (\d+)\] loss=([\d.]+)", output)}
+
+
 # ============================================================================
 # Single GPU
 # ============================================================================
@@ -397,6 +402,104 @@ def test_checkpoint_save_and_resume(tmp_path):
     assert "step=10" in output, "Did not resume from step 10"


+# ============================================================================
+# Resume Determinism (regression: optimizer state must be restored bit-exactly)
+# ============================================================================
+
+
+def _make_learnable_shard(tmp_path) -> str:
+    """Write a small on-disk dataset with a LEARNABLE repeating pattern.
+
+    The resume bug perturbs the weights slightly; random tokens give a
+    near-constant loss (a model can't fit noise) that masks it, so we need data
+    the model can actually learn for the loss to be sensitive to the exact
+    weights. A simple token cycle (next token is always predictable) does that.
+    """
+    import numpy as np
+
+    data_dir = tmp_path / "data"
+    data_dir.mkdir(parents=True, exist_ok=True)
+    tokens = (np.arange(8_000_000) % 256).astype(np.uint16)
+    np.save(str(data_dir / "shard.npy"), tokens)
+    return str(data_dir)
+
+
+def _check_resume_determinism(tmp_path, nproc: int) -> None:
+    """An interrupted+resumed run must reproduce an uninterrupted reference
+    bit-for-bit on the post-resume steps.
+
+    If the optimizer moments (or RNG / dataloader position) are not restored
+    exactly, the loss trajectories diverge. Regression test for the bug where the
+    manager used raw ``optimizer.state_dict()`` with DCP: a freshly-built
+    optimizer's state is empty, so ``dcp.load`` had no moment tensors to fill and
+    AdamW momentum silently reset to zero on every resume.
+    """
+    data_dir = _make_learnable_shard(tmp_path)
+    common = [
+        DEBUG_CONFIG,
+        "--metrics.log_interval=1",
+        "--train.seed=1234",
+        "--train.compile_model=false",
+        "--model.vocab_size=256",
+        f"--data.dataset_path={data_dir}",
+        "--data.file_pattern=*.npy",
+    ]
+    reference = _run_training(
+        common
+        + [
+            "--train.max_steps=20",
+            f"--checkpoint.dir={tmp_path}/ref",
+            "--checkpoint.interval=1000",
+        ],
+        nproc=nproc,
+        timeout=300,
+    )
+    _assert_training_complete(reference, expected_steps=20)
+
+    phase1 = _run_training(
+        common
+        + ["--train.max_steps=10", f"--checkpoint.dir={tmp_path}/test", "--checkpoint.interval=10"],
+        nproc=nproc,
+        timeout=300,
+    )
+    _assert_training_complete(phase1, expected_steps=10)
+
+    resumed = _run_training(
+        common
+        + [
+            "--train.max_steps=20",
+            f"--checkpoint.dir={tmp_path}/test",
+            "--checkpoint.interval=1000",
+        ],
+        nproc=nproc,
+        timeout=300,
+    )
+    _assert_training_complete(resumed, expected_steps=20)
+    assert "step=10" in (resumed.stdout + resumed.stderr), "did not resume from step 10"
+
+    ref_losses = _parse_losses_by_step(reference.stdout + reference.stderr)
+    res_losses = _parse_losses_by_step(resumed.stdout + resumed.stderr)
+    for step in range(11, 21):
+        assert step in res_losses, f"resumed run missing step {step}"
+        assert res_losses[step] == ref_losses.get(step), (
+            f"step {step}: resumed loss {res_losses[step]} != reference {ref_losses.get(step)} "
+            "-- post-resume state (optimizer moments / RNG / dataloader) not restored bit-exactly"
+        )
+
+
+@pytest.mark.e2e
+def test_resume_determinism_single_gpu(tmp_path):
+    """Single GPU: resumed run must match the uninterrupted reference bit-for-bit."""
+    _check_resume_determinism(tmp_path, nproc=1)
+
+
+@pytest.mark.e2e
+@requires_gpus(2)
+def test_resume_determinism_2gpu_fsdp(tmp_path):
+    """2 GPU FSDP: resumed run must match the uninterrupted reference bit-for-bit."""
+    _check_resume_determinism(tmp_path, nproc=2)
+
+
 # ============================================================================
 # Pipeline Parallel + Checkpoint
 # ============================================================================
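The loss comparison in the resume-determinism test can be tried in isolation. In the sketch below, `parse_losses_by_step` copies the regex from the `_parse_losses_by_step` helper added in this PR; the two log excerpts and the `first_divergence` helper are hypothetical and exist only to show how a string-level comparison flags the first post-resume step that differs.

```python
import re


def parse_losses_by_step(output: str) -> dict[int, str]:
    """Map step -> logged loss string (same regex as the helper in the diff)."""
    return {int(s): loss for s, loss in re.findall(r"\[step (\d+)\] loss=([\d.]+)", output)}


def first_divergence(ref: dict[int, str], res: dict[int, str], steps):
    """Return the first step whose loss text is missing or differs, else None."""
    for step in steps:
        if step not in res or res[step] != ref.get(step):
            return step
    return None


# Hypothetical log excerpts; the real format comes from the training script.
reference_log = "[step 11] loss=2.3456\n[step 12] loss=2.3011\n"
resumed_log = "[step 11] loss=2.3456\n[step 12] loss=2.3017\n"

ref = parse_losses_by_step(reference_log)
res = parse_losses_by_step(resumed_log)
print(first_divergence(ref, res, range(11, 13)))
print(first_divergence(ref, ref, range(11, 13)))
```

Comparing the logged strings rather than parsed floats keeps the check as strict as the log precision allows. One caveat of the pattern `[\d.]+` is that it does not match `nan`, `inf`, or scientific notation, so such lines would show up as missing steps.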