From 1f950357ac14298d3445552f3cbc66d8840df8ea Mon Sep 17 00:00:00 2001 From: amazloumi Date: Wed, 27 May 2026 12:12:56 -0400 Subject: [PATCH 1/3] fix(checkpoint): restore AdamW optimizer moments on resume via DCP get/set_state_dict --- CHANGELOG.md | 7 ++ docs/checkpointing/dcp-model.md | 47 ++++++-- kempnerforge/checkpoint/manager.py | 43 ++++++-- tests/distributed/test_checkpoint.py | 68 ++++++++++++ tests/e2e/test_training_e2e.py | 103 ++++++++++++++++++ .../integration/test_checkpoint_roundtrip.py | 51 +++++++++ 6 files changed, 297 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1746bda..f04d6c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `docs/claude-ready.md` first-run flow: `/kempnerforge:install-and-verify` runs before `/kempnerforge:cluster-config`. - `README.md` and `kempnerforge/README.md`: list `install-and-verify` in the skill catalog and drop the hardcoded skill count. +### Fixed +- **Resume silently reset AdamW optimizer momentum.** `CheckpointManager` round-tripped optimizer state through raw `optimizer.state_dict()` / `optimizer.load_state_dict()`. On resume the optimizer is freshly built, so its `state_dict()` is empty — `dcp.load` then had no `exp_avg` / `exp_avg_sq` tensors to fill, and the moments were silently dropped, resetting Adam momentum to zero at every resume point. Model weights, scheduler, dataloader position, and RNG all restored correctly; only the optimizer moments were lost, so resumed runs were not bit-exact. + - `kempnerforge/checkpoint/manager.py`: save and load now go through DCP's `get_model_state_dict` / `get_optimizer_state_dict` / `set_model_state_dict` / `set_optimizer_state_dict`. The getters build a load template with the optimizer moments allocated in the correct FSDP/DTensor layout, so `dcp.load` repopulates them; the setters write the loaded values back into the live optimizer. + - `docs/checkpointing/dcp-model.md`: updated the save/load snippets and the "shape to fill" explanation to the DCP-aware helpers. + - Tests (fail on the pre-fix code, pass after): `tests/integration/test_checkpoint_roundtrip.py::test_manager_restores_optimizer_moments_single_gpu` and `tests/distributed/test_checkpoint.py::test_resume_restores_optimizer_moments` assert `exp_avg` / `exp_avg_sq` are restored bit-exactly into a *fresh* optimizer (single-GPU + distributed); `tests/e2e/test_training_e2e.py::test_resume_determinism_single_gpu` / `test_resume_determinism_2gpu_fsdp` assert end-to-end bit-exact loss across an interrupt-and-resume on a learnable dataset. + - **On-disk format note:** optimizer state is now keyed by parameter fully-qualified name rather than positional index. Checkpoints written before this fix will not restore optimizer state on resume (training continues with a fresh optimizer); model state is unaffected. + ## [0.1.0] — 2026-04-16 Initial public release. diff --git a/docs/checkpointing/dcp-model.md b/docs/checkpointing/dcp-model.md index a6a03cf..ed49324 100644 --- a/docs/checkpointing/dcp-model.md +++ b/docs/checkpointing/dcp-model.md @@ -15,25 +15,33 @@ in `kempnerforge/checkpoint/manager.py`. ```python dcp_state = { - "model": self.model.state_dict(), - "optimizer": self.optimizer.state_dict(), + "model": get_model_state_dict(self.model), + "optimizer": get_optimizer_state_dict(self.model, self.optimizer), } self._async_ckpt.save( dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group ) ``` +`get_model_state_dict` / `get_optimizer_state_dict` are the DCP-aware +helpers from `torch.distributed.checkpoint.state_dict` — **not** raw +`model.state_dict()` / `optimizer.state_dict()`. They key the +optimizer state by parameter fully-qualified name (not positional +index) and keep the FSDP/DTensor sharding intact, which is what makes +load (and resharding) line up by name. See [Loading](#loading) for why +the raw calls break resume. + Two top-level keys — `"model"` and `"optimizer"`. DCP introspects the state dicts, finds `DTensor` / `ShardedTensor` parameters, and writes each shard to disk with enough metadata to reassemble. What's in each: -- **`model.state_dict()`** — every parameter and buffer: weights, +- **model state** — every parameter and buffer: weights, RMSNorm scales, learned RoPE frequencies (if present), and any registered buffer. Under FSDP2 these are `DTensor`s; under TP they're `DTensor`s on a 2D mesh. DCP handles both. -- **`optimizer.state_dict()`** — AdamW's `exp_avg`, `exp_avg_sq`, +- **optimizer state** — AdamW's `exp_avg`, `exp_avg_sq`, `step` counters; Lion's `exp_avg`; Muon's internal state. All per-parameter tensors live on the same device and parallelism shape as the parameter, so DCP saves them symmetrically. @@ -148,16 +156,33 @@ it through on every save/load. Load is the mirror image of save: ```python -dcp_state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()} +dcp_state = { + "model": get_model_state_dict(self.model), + "optimizer": get_optimizer_state_dict(self.model, self.optimizer), +} dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group) -self.model.load_state_dict(dcp_state["model"]) -self.optimizer.load_state_dict(dcp_state["optimizer"]) +set_model_state_dict(self.model, dcp_state["model"]) +set_optimizer_state_dict(self.model, self.optimizer, optim_state_dict=dcp_state["optimizer"]) ``` -The first `state_dict()` call gives DCP the **shape** to fill — it -doesn't contain the saved data, just the tensor metadata DCP needs -to know what to load where. `dcp.load` mutates the tensors in place -with the loaded values. Then `load_state_dict` consumes them. +The getter call gives DCP the **shape** to fill — it doesn't contain +the saved data, just the tensor metadata (and, crucially, the +*allocated* optimizer moment tensors) DCP needs to know what to load +where. `dcp.load` mutates those tensors in place with the loaded +values; the setters then write them back into the live model and +optimizer. + +> **Why the DCP-aware helpers, not `optimizer.state_dict()`?** On +> resume the optimizer is freshly built, so `optimizer.state_dict()` +> has *no* `exp_avg` / `exp_avg_sq` tensors yet — AdamW creates the +> per-parameter state lazily on the first `.step()`. Passing that +> empty dict as the load template gives `dcp.load` nothing to fill, so +> the saved moments are silently dropped and Adam momentum resets to +> zero at every resume (a non-bit-exact resume). `get_optimizer_state_dict` +> allocates the moment tensors up front in the right sharded layout, so +> `dcp.load` repopulates them. The model side would work with either +> call — its parameters are always allocated — but we use the matching +> getter/setter for symmetry. Loading with a different GPU count triggers DCP's automatic resharding — see [Resharding](resharding.md). diff --git a/kempnerforge/checkpoint/manager.py b/kempnerforge/checkpoint/manager.py index 10bc07a..e5c18e0 100644 --- a/kempnerforge/checkpoint/manager.py +++ b/kempnerforge/checkpoint/manager.py @@ -20,6 +20,12 @@ import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) from kempnerforge.checkpoint.async_save import AsyncCheckpointer from kempnerforge.checkpoint.state import build_train_state, restore_train_state @@ -240,10 +246,15 @@ def save( dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir dcp_dir.mkdir(parents=True, exist_ok=True) - # Save distributed state (model + optimizer) via DCP + # Save distributed state (model + optimizer) via DCP. Use the DCP-aware + # state-dict helpers, NOT raw optimizer.state_dict(): on load they build a + # template with the optimizer moment tensors allocated so dcp.load can + # repopulate them. A freshly-constructed optimizer's raw state_dict() is + # empty, so the moments would be silently dropped on resume (Adam momentum + # resets to zero -> non-bit-exact resume; see manager load()). dcp_state = { - "model": self.model.state_dict(), - "optimizer": self.optimizer.state_dict(), + "model": get_model_state_dict(self.model), + "optimizer": get_optimizer_state_dict(self.model, self.optimizer), } # Dispatch the DCP save. For async modes this returns immediately but # FIRST awaits the previous in-flight flush, so any deferred @@ -457,19 +468,29 @@ def load( # Load distributed state via DCP dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir + load_model = exclude_keys is None or "model" not in exclude_keys + load_optim = exclude_keys is None or "optimizer" not in exclude_keys + + # Build properly-structured (FSDP-sharded, moment-allocated) templates via + # the DCP-aware getters so dcp.load can repopulate them, then write them + # back with the setters. Raw optimizer.state_dict() would be empty on a + # fresh optimizer, so dcp.load would find no moment tensors to fill and the + # AdamW momentum would silently reset to zero on resume. dcp_state: dict[str, Any] = {} - if exclude_keys is None or "model" not in exclude_keys: - dcp_state["model"] = self.model.state_dict() - if exclude_keys is None or "optimizer" not in exclude_keys: - dcp_state["optimizer"] = self.optimizer.state_dict() + if load_model: + dcp_state["model"] = get_model_state_dict(self.model) + if load_optim: + dcp_state["optimizer"] = get_optimizer_state_dict(self.model, self.optimizer) if dcp_state: dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group) - if "model" in dcp_state: - self.model.load_state_dict(dcp_state["model"]) - if "optimizer" in dcp_state: - self.optimizer.load_state_dict(dcp_state["optimizer"]) + if load_model: + set_model_state_dict(self.model, dcp_state["model"]) + if load_optim: + set_optimizer_state_dict( + self.model, self.optimizer, optim_state_dict=dcp_state["optimizer"] + ) # Load non-distributed state. On NFS/Lustre, independent stat() # calls can disagree briefly across ranks; if some ranks enter diff --git a/tests/distributed/test_checkpoint.py b/tests/distributed/test_checkpoint.py index bad1ff0..f91324e 100644 --- a/tests/distributed/test_checkpoint.py +++ b/tests/distributed/test_checkpoint.py @@ -105,6 +105,74 @@ def test_save_load_fsdp(self, distributed_env, shared_tmp_dir): f"Restored output differs: max diff={(ref_out - restored_out).abs().max().item()}" ) + def test_resume_restores_optimizer_moments(self, distributed_env, shared_tmp_dir): + """Optimizer moments must be restored into a FRESH optimizer on resume. + + Regression for the bug where the manager used raw ``optimizer.state_dict()`` + with DCP: a freshly-constructed optimizer has empty state, so the load + template had no moment tensors and ``dcp.load`` silently dropped them -> + AdamW ``exp_avg``/``exp_avg_sq`` reset to zero on every resume. Loading + into the *same* already-stepped optimizer (as test_save_load_fsdp does) + hides this, so here we load into a fresh optimizer like a real resume. + """ + from kempnerforge.config.schema import OptimizerConfig + + mesh = distributed_env + ckpt_dir = shared_tmp_dir + + def snapshot_moments(optimizer): + """Per-rank (sharded) (exp_avg, exp_avg_sq) clones, in param order.""" + out = [] + for group in optimizer.param_groups: + for p in group["params"]: + st = optimizer.state.get(p, {}) + if "exp_avg" in st: + out.append( + (st["exp_avg"].detach().clone(), st["exp_avg_sq"].detach().clone()) + ) + return out + + # Build, step a few times so the moments are non-trivial, then save. + torch.manual_seed(42) + model = Transformer(SMALL_CONFIG).cuda() + apply_fsdp2(model, mesh) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + for _ in range(3): + tokens = torch.randint(0, 512, (2, 32), device="cuda") + model(tokens).sum().backward() + opt.step() + opt.zero_grad() + + ref_moments = snapshot_moments(opt) + assert ref_moments, "optimizer had no moment state to test" + assert any(ea.abs().sum().item() > 0 for ea, _ in ref_moments), ( + "reference moments are all zero" + ) + + CheckpointManager(CheckpointConfig(dir=ckpt_dir, keep_last_n=2), model, opt).save( + step=3, tokens_seen=192 + ) + + # FRESH model + FRESH optimizer (never stepped -> empty optimizer state), + # exactly the resume scenario. Then load. + torch.manual_seed(42) + model2 = Transformer(SMALL_CONFIG).cuda() + apply_fsdp2(model2, mesh) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + assert not snapshot_moments(opt2), "fresh optimizer should have empty state before load" + + CheckpointManager(CheckpointConfig(dir=ckpt_dir, keep_last_n=2), model2, opt2).load() + + loaded_moments = snapshot_moments(opt2) + assert loaded_moments, ( + "optimizer moments were not restored (state still empty after load) " + "-- AdamW momentum reset on resume" + ) + assert len(loaded_moments) == len(ref_moments) + for i, ((rea, rev), (lea, lev)) in enumerate(zip(ref_moments, loaded_moments, strict=True)): + assert torch.equal(rea, lea), f"param {i}: exp_avg not restored bit-exactly" + assert torch.equal(rev, lev), f"param {i}: exp_avg_sq not restored bit-exactly" + def test_latest_symlink(self, distributed_env, shared_tmp_dir): """The 'latest' symlink should point to the most recent checkpoint.""" mesh = distributed_env diff --git a/tests/e2e/test_training_e2e.py b/tests/e2e/test_training_e2e.py index 6349651..86983bc 100644 --- a/tests/e2e/test_training_e2e.py +++ b/tests/e2e/test_training_e2e.py @@ -110,6 +110,11 @@ def _parse_last_loss(output: str) -> float | None: return float(matches[-1]) if matches else None +def _parse_losses_by_step(output: str) -> dict[int, str]: + """Map step -> logged loss string (exact text, for bit-identical comparison).""" + return {int(s): loss for s, loss in re.findall(r"\[step (\d+)\] loss=([\d.]+)", output)} + + # ============================================================================ # Single GPU # ============================================================================ @@ -397,6 +402,104 @@ def test_checkpoint_save_and_resume(tmp_path): assert "step=10" in output, "Did not resume from step 10" +# ============================================================================ +# Resume Determinism (regression: optimizer state must be restored bit-exactly) +# ============================================================================ + + +def _make_learnable_shard(tmp_path) -> str: + """Write a small on-disk dataset with a LEARNABLE repeating pattern. + + The resume bug perturbs the weights slightly; random tokens give a + near-constant loss (a model can't fit noise) that masks it, so we need data + the model can actually learn for the loss to be sensitive to the exact + weights. A simple token cycle (next token is always predictable) does that. + """ + import numpy as np + + data_dir = tmp_path / "data" + data_dir.mkdir(parents=True, exist_ok=True) + tokens = (np.arange(8_000_000) % 256).astype(np.uint16) + np.save(str(data_dir / "shard.npy"), tokens) + return str(data_dir) + + +def _check_resume_determinism(tmp_path, nproc: int) -> None: + """An interrupted+resumed run must reproduce an uninterrupted reference + bit-for-bit on the post-resume steps. + + If the optimizer moments (or RNG / dataloader position) are not restored + exactly, the loss trajectories diverge. Regression test for the bug where the + manager used raw ``optimizer.state_dict()`` with DCP: a freshly-built + optimizer's state is empty, so ``dcp.load`` had no moment tensors to fill and + AdamW momentum silently reset to zero on every resume. + """ + data_dir = _make_learnable_shard(tmp_path) + common = [ + DEBUG_CONFIG, + "--metrics.log_interval=1", + "--train.seed=1234", + "--train.compile_model=false", + "--model.vocab_size=256", + f"--data.dataset_path={data_dir}", + "--data.file_pattern=*.npy", + ] + reference = _run_training( + common + + [ + "--train.max_steps=20", + f"--checkpoint.dir={tmp_path}/ref", + "--checkpoint.interval=1000", + ], + nproc=nproc, + timeout=300, + ) + _assert_training_complete(reference, expected_steps=20) + + phase1 = _run_training( + common + + ["--train.max_steps=10", f"--checkpoint.dir={tmp_path}/test", "--checkpoint.interval=10"], + nproc=nproc, + timeout=300, + ) + _assert_training_complete(phase1, expected_steps=10) + + resumed = _run_training( + common + + [ + "--train.max_steps=20", + f"--checkpoint.dir={tmp_path}/test", + "--checkpoint.interval=1000", + ], + nproc=nproc, + timeout=300, + ) + _assert_training_complete(resumed, expected_steps=20) + assert "step=10" in (resumed.stdout + resumed.stderr), "did not resume from step 10" + + ref_losses = _parse_losses_by_step(reference.stdout + reference.stderr) + res_losses = _parse_losses_by_step(resumed.stdout + resumed.stderr) + for step in range(11, 21): + assert step in res_losses, f"resumed run missing step {step}" + assert res_losses[step] == ref_losses.get(step), ( + f"step {step}: resumed loss {res_losses[step]} != reference {ref_losses.get(step)} " + "-- post-resume state (optimizer moments / RNG / dataloader) not restored bit-exactly" + ) + + +@pytest.mark.e2e +def test_resume_determinism_single_gpu(tmp_path): + """Single GPU: resumed run must match the uninterrupted reference bit-for-bit.""" + _check_resume_determinism(tmp_path, nproc=1) + + +@pytest.mark.e2e +@requires_gpus(2) +def test_resume_determinism_2gpu_fsdp(tmp_path): + """2 GPU FSDP: resumed run must match the uninterrupted reference bit-for-bit.""" + _check_resume_determinism(tmp_path, nproc=2) + + # ============================================================================ # Pipeline Parallel + Checkpoint # ============================================================================ diff --git a/tests/integration/test_checkpoint_roundtrip.py b/tests/integration/test_checkpoint_roundtrip.py index c397da6..b7c5a13 100644 --- a/tests/integration/test_checkpoint_roundtrip.py +++ b/tests/integration/test_checkpoint_roundtrip.py @@ -164,3 +164,54 @@ def test_full_training_resume(self, tmp_path): assert abs(ref_loss - resumed_loss) < 1e-4, ( f"Resumed loss differs: ref={ref_loss:.6f}, resumed={resumed_loss:.6f}" ) + + def test_manager_restores_optimizer_moments_single_gpu(self, tmp_path): + """Single-GPU: CheckpointManager must restore optimizer moments into a + FRESH optimizer (the resume scenario). + + Plain torch.save/load round-trips an optimizer fine, so the tests above + don't exercise the bug. The manager's DCP path filled the load template + from a freshly-built optimizer's *empty* state_dict, silently dropping the + moments. This guards that single-GPU manager path (the distributed + equivalent is in tests/distributed/test_checkpoint.py). + """ + from kempnerforge.checkpoint.manager import CheckpointManager + from kempnerforge.config.schema import CheckpointConfig + + def moments(optimizer): + out = [] + for group in optimizer.param_groups: + for p in group["params"]: + st = optimizer.state.get(p, {}) + if "exp_avg" in st: + out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) + return out + + torch.manual_seed(0) + model = Transformer(CONFIG).to(DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + for _ in range(3): + tokens = torch.randint(0, 256, (2, 32), device=DEVICE) + model(tokens).sum().backward() + opt.step() + opt.zero_grad() + + ref = moments(opt) + assert ref and any(ea.abs().sum().item() > 0 for ea, _ in ref), ( + "no non-zero moments to test" + ) + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + # Fresh model + fresh optimizer (empty state) -> the resume scenario. + model2 = Transformer(CONFIG).to(DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + assert not moments(opt2), "fresh optimizer should have empty state before load" + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load() + + loaded = moments(opt2) + assert loaded, "optimizer moments not restored via manager (momentum reset on resume)" + for (rea, rev), (lea, lev) in zip(ref, loaded, strict=True): + assert torch.equal(rea, lea), "exp_avg not restored bit-exactly" + assert torch.equal(rev, lev), "exp_avg_sq not restored bit-exactly" From 8ccfdabdad54114bed78d9321f453db620c9db90 Mon Sep 17 00:00:00 2001 From: amazloumi Date: Thu, 28 May 2026 12:46:53 -0400 Subject: [PATCH 2/3] Add exclude_keys coverage tests for CheckpointManager.load --- .../integration/test_checkpoint_roundtrip.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/tests/integration/test_checkpoint_roundtrip.py b/tests/integration/test_checkpoint_roundtrip.py index b7c5a13..438b280 100644 --- a/tests/integration/test_checkpoint_roundtrip.py +++ b/tests/integration/test_checkpoint_roundtrip.py @@ -215,3 +215,94 @@ def moments(optimizer): for (rea, rev), (lea, lev) in zip(ref, loaded, strict=True): assert torch.equal(rea, lea), "exp_avg not restored bit-exactly" assert torch.equal(rev, lev), "exp_avg_sq not restored bit-exactly" + + def test_manager_load_excludes_optimizer(self, tmp_path): + """exclude_keys=['optimizer'] (the scripts/eval.py / fine-tune flow): + load model state but leave the fresh optimizer untouched. Covers the + load_optim=False branch in CheckpointManager.load.""" + from kempnerforge.checkpoint.manager import CheckpointManager + from kempnerforge.config.schema import CheckpointConfig + + def moments(optimizer): + out = [] + for group in optimizer.param_groups: + for p in group["params"]: + st = optimizer.state.get(p, {}) + if "exp_avg" in st: + out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) + return out + + torch.manual_seed(0) + model = Transformer(CONFIG).to(DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + for _ in range(3): + tokens = torch.randint(0, 256, (2, 32), device=DEVICE) + model(tokens).sum().backward() + opt.step() + opt.zero_grad() + ref_weights = {n: p.clone() for n, p in model.named_parameters()} + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + # Fresh model (different init) + fresh optimizer. + torch.manual_seed(99) + model2 = Transformer(CONFIG).to(DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + pre_load = {n: p.clone() for n, p in model2.named_parameters()} + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load( + exclude_keys=["optimizer"] + ) + + # Model state restored (weights now match the saved model). + for n, p in model2.named_parameters(): + assert not torch.equal(pre_load[n], p), f"{n} was not loaded" + assert torch.equal(ref_weights[n], p), f"{n} mismatch vs saved" + # Optimizer was excluded -> still empty. + assert not moments(opt2), "optimizer should remain empty on exclude_keys=['optimizer']" + + def test_manager_load_excludes_model(self, tmp_path): + """exclude_keys=['model'] (symmetric case): load optimizer moments but + leave the existing model weights untouched. Covers the load_model=False + branch in CheckpointManager.load.""" + from kempnerforge.checkpoint.manager import CheckpointManager + from kempnerforge.config.schema import CheckpointConfig + + def moments(optimizer): + out = [] + for group in optimizer.param_groups: + for p in group["params"]: + st = optimizer.state.get(p, {}) + if "exp_avg" in st: + out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) + return out + + torch.manual_seed(0) + model = Transformer(CONFIG).to(DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + for _ in range(3): + tokens = torch.randint(0, 256, (2, 32), device=DEVICE) + model(tokens).sum().backward() + opt.step() + opt.zero_grad() + ref_moments = moments(opt) + assert ref_moments, "fixture: expected non-empty optimizer state" + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + torch.manual_seed(99) + model2 = Transformer(CONFIG).to(DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + pre_load = {n: p.clone() for n, p in model2.named_parameters()} + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load(exclude_keys=["model"]) + + # Model state was excluded -> weights unchanged from fresh init. + for n, p in model2.named_parameters(): + assert torch.equal(pre_load[n], p), f"{n} should not have changed" + # Optimizer moments restored from the saved checkpoint. + loaded = moments(opt2) + assert loaded, "optimizer moments should load with exclude_keys=['model']" + for (rea, rev), (lea, lev) in zip(ref_moments, loaded, strict=True): + assert torch.equal(rea, lea), "exp_avg not restored" + assert torch.equal(rev, lev), "exp_avg_sq not restored" From e62b1b849604d66ec5b4ea6eeb2078ab5794e977 Mon Sep 17 00:00:00 2001 From: amazloumi Date: Thu, 28 May 2026 14:43:41 -0400 Subject: [PATCH 3/3] moving test from integration to unit test for codcov --- .../integration/test_checkpoint_roundtrip.py | 142 ------------------ tests/unit/test_checkpoint.py | 134 ++++++++++++++++- 2 files changed, 133 insertions(+), 143 deletions(-) diff --git a/tests/integration/test_checkpoint_roundtrip.py b/tests/integration/test_checkpoint_roundtrip.py index 438b280..c397da6 100644 --- a/tests/integration/test_checkpoint_roundtrip.py +++ b/tests/integration/test_checkpoint_roundtrip.py @@ -164,145 +164,3 @@ def test_full_training_resume(self, tmp_path): assert abs(ref_loss - resumed_loss) < 1e-4, ( f"Resumed loss differs: ref={ref_loss:.6f}, resumed={resumed_loss:.6f}" ) - - def test_manager_restores_optimizer_moments_single_gpu(self, tmp_path): - """Single-GPU: CheckpointManager must restore optimizer moments into a - FRESH optimizer (the resume scenario). - - Plain torch.save/load round-trips an optimizer fine, so the tests above - don't exercise the bug. The manager's DCP path filled the load template - from a freshly-built optimizer's *empty* state_dict, silently dropping the - moments. This guards that single-GPU manager path (the distributed - equivalent is in tests/distributed/test_checkpoint.py). - """ - from kempnerforge.checkpoint.manager import CheckpointManager - from kempnerforge.config.schema import CheckpointConfig - - def moments(optimizer): - out = [] - for group in optimizer.param_groups: - for p in group["params"]: - st = optimizer.state.get(p, {}) - if "exp_avg" in st: - out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) - return out - - torch.manual_seed(0) - model = Transformer(CONFIG).to(DEVICE) - opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) - for _ in range(3): - tokens = torch.randint(0, 256, (2, 32), device=DEVICE) - model(tokens).sum().backward() - opt.step() - opt.zero_grad() - - ref = moments(opt) - assert ref and any(ea.abs().sum().item() > 0 for ea, _ in ref), ( - "no non-zero moments to test" - ) - - ckpt_dir = str(tmp_path / "ckpt") - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) - - # Fresh model + fresh optimizer (empty state) -> the resume scenario. - model2 = Transformer(CONFIG).to(DEVICE) - opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) - assert not moments(opt2), "fresh optimizer should have empty state before load" - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load() - - loaded = moments(opt2) - assert loaded, "optimizer moments not restored via manager (momentum reset on resume)" - for (rea, rev), (lea, lev) in zip(ref, loaded, strict=True): - assert torch.equal(rea, lea), "exp_avg not restored bit-exactly" - assert torch.equal(rev, lev), "exp_avg_sq not restored bit-exactly" - - def test_manager_load_excludes_optimizer(self, tmp_path): - """exclude_keys=['optimizer'] (the scripts/eval.py / fine-tune flow): - load model state but leave the fresh optimizer untouched. Covers the - load_optim=False branch in CheckpointManager.load.""" - from kempnerforge.checkpoint.manager import CheckpointManager - from kempnerforge.config.schema import CheckpointConfig - - def moments(optimizer): - out = [] - for group in optimizer.param_groups: - for p in group["params"]: - st = optimizer.state.get(p, {}) - if "exp_avg" in st: - out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) - return out - - torch.manual_seed(0) - model = Transformer(CONFIG).to(DEVICE) - opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) - for _ in range(3): - tokens = torch.randint(0, 256, (2, 32), device=DEVICE) - model(tokens).sum().backward() - opt.step() - opt.zero_grad() - ref_weights = {n: p.clone() for n, p in model.named_parameters()} - - ckpt_dir = str(tmp_path / "ckpt") - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) - - # Fresh model (different init) + fresh optimizer. - torch.manual_seed(99) - model2 = Transformer(CONFIG).to(DEVICE) - opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) - pre_load = {n: p.clone() for n, p in model2.named_parameters()} - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load( - exclude_keys=["optimizer"] - ) - - # Model state restored (weights now match the saved model). - for n, p in model2.named_parameters(): - assert not torch.equal(pre_load[n], p), f"{n} was not loaded" - assert torch.equal(ref_weights[n], p), f"{n} mismatch vs saved" - # Optimizer was excluded -> still empty. - assert not moments(opt2), "optimizer should remain empty on exclude_keys=['optimizer']" - - def test_manager_load_excludes_model(self, tmp_path): - """exclude_keys=['model'] (symmetric case): load optimizer moments but - leave the existing model weights untouched. Covers the load_model=False - branch in CheckpointManager.load.""" - from kempnerforge.checkpoint.manager import CheckpointManager - from kempnerforge.config.schema import CheckpointConfig - - def moments(optimizer): - out = [] - for group in optimizer.param_groups: - for p in group["params"]: - st = optimizer.state.get(p, {}) - if "exp_avg" in st: - out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) - return out - - torch.manual_seed(0) - model = Transformer(CONFIG).to(DEVICE) - opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) - for _ in range(3): - tokens = torch.randint(0, 256, (2, 32), device=DEVICE) - model(tokens).sum().backward() - opt.step() - opt.zero_grad() - ref_moments = moments(opt) - assert ref_moments, "fixture: expected non-empty optimizer state" - - ckpt_dir = str(tmp_path / "ckpt") - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) - - torch.manual_seed(99) - model2 = Transformer(CONFIG).to(DEVICE) - opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) - pre_load = {n: p.clone() for n, p in model2.named_parameters()} - CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load(exclude_keys=["model"]) - - # Model state was excluded -> weights unchanged from fresh init. - for n, p in model2.named_parameters(): - assert torch.equal(pre_load[n], p), f"{n} should not have changed" - # Optimizer moments restored from the saved checkpoint. - loaded = moments(opt2) - assert loaded, "optimizer moments should load with exclude_keys=['model']" - for (rea, rev), (lea, lev) in zip(ref_moments, loaded, strict=True): - assert torch.equal(rea, lea), "exp_avg not restored" - assert torch.equal(rev, lev), "exp_avg_sq not restored" diff --git a/tests/unit/test_checkpoint.py b/tests/unit/test_checkpoint.py index fce4be2..2bb496e 100644 --- a/tests/unit/test_checkpoint.py +++ b/tests/unit/test_checkpoint.py @@ -16,7 +16,14 @@ restore_train_state, set_rng_state, ) -from kempnerforge.config.schema import AsyncCheckpointMode, CheckpointConfig +from kempnerforge.config.schema import ( + AsyncCheckpointMode, + CheckpointConfig, + ModelConfig, + OptimizerConfig, +) +from kempnerforge.model.transformer import Transformer +from kempnerforge.training.optimizer import build_optimizer # --------------------------------------------------------------------------- # RNG state capture/restore @@ -1023,3 +1030,128 @@ def test_sync_barrier_uses_dedicated_group_not_default(self, tmp_path, monkeypat # The barrier must target the dedicated gloo group, never the # default group (group=None) that DCP's async thread uses. assert barrier_calls == ["GLOO_PG"] + + +# --------------------------------------------------------------------------- +# CheckpointManager.load -- moment restore + exclude_keys branches +# --------------------------------------------------------------------------- + + +class TestCheckpointManagerLoad: + """End-to-end ``load()`` tests on the real DCP save/load path + (single-process mode). Lives in tests/unit/ so it counts toward CI + coverage of ``manager.py`` -- the same single-GPU regression coverage + that previously lived in tests/integration/. + """ + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + MODEL_CONFIG = ModelConfig(dim=64, n_layers=2, n_heads=2, vocab_size=256, max_seq_len=64) + + @staticmethod + def _moments(optimizer): + out = [] + for group in optimizer.param_groups: + for p in group["params"]: + st = optimizer.state.get(p, {}) + if "exp_avg" in st: + out.append((st["exp_avg"].clone(), st["exp_avg_sq"].clone())) + return out + + def _train_few_steps(self, model, opt): + for _ in range(3): + tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE) + model(tokens).sum().backward() + opt.step() + opt.zero_grad() + + def test_restores_optimizer_moments_into_fresh_optimizer(self, tmp_path): + """Regression: the manager's DCP path used to fill the load template + from a freshly-built optimizer's *empty* state_dict, silently dropping + the moments. After the fix, moments must restore bit-exactly into a + fresh optimizer.""" + from kempnerforge.checkpoint.manager import CheckpointManager + + torch.manual_seed(0) + model = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + self._train_few_steps(model, opt) + + ref = self._moments(opt) + assert ref and any(ea.abs().sum().item() > 0 for ea, _ in ref), ( + "no non-zero moments to test" + ) + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + model2 = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + assert not self._moments(opt2), "fresh optimizer should have empty state before load" + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load() + + loaded = self._moments(opt2) + assert loaded, "optimizer moments not restored (momentum reset on resume)" + for (rea, rev), (lea, lev) in zip(ref, loaded, strict=True): + assert torch.equal(rea, lea), "exp_avg not restored bit-exactly" + assert torch.equal(rev, lev), "exp_avg_sq not restored bit-exactly" + + def test_load_excludes_optimizer(self, tmp_path): + """``exclude_keys=['optimizer']`` (the scripts/eval.py / fine-tune + flow): load model state but leave the fresh optimizer untouched. + Covers the inner ``if load_optim:`` False branch in + ``CheckpointManager.load``.""" + from kempnerforge.checkpoint.manager import CheckpointManager + + torch.manual_seed(0) + model = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + self._train_few_steps(model, opt) + ref_weights = {n: p.clone() for n, p in model.named_parameters()} + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + torch.manual_seed(99) + model2 = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + pre_load = {n: p.clone() for n, p in model2.named_parameters()} + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load( + exclude_keys=["optimizer"] + ) + + for n, p in model2.named_parameters(): + assert not torch.equal(pre_load[n], p), f"{n} was not loaded" + assert torch.equal(ref_weights[n], p), f"{n} mismatch vs saved" + assert not self._moments(opt2), ( + "optimizer should remain empty on exclude_keys=['optimizer']" + ) + + def test_load_excludes_model(self, tmp_path): + """``exclude_keys=['model']`` (symmetric case): load optimizer moments + but leave the existing model weights untouched. Covers the inner + ``if load_model:`` False branch in ``CheckpointManager.load``.""" + from kempnerforge.checkpoint.manager import CheckpointManager + + torch.manual_seed(0) + model = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) + self._train_few_steps(model, opt) + ref_moments = self._moments(opt) + assert ref_moments, "fixture: expected non-empty optimizer state" + + ckpt_dir = str(tmp_path / "ckpt") + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model, opt).save(step=3) + + torch.manual_seed(99) + model2 = Transformer(self.MODEL_CONFIG).to(self.DEVICE) + opt2 = build_optimizer(model2, OptimizerConfig(lr=1e-3, fused=False)) + pre_load = {n: p.clone() for n, p in model2.named_parameters()} + CheckpointManager(CheckpointConfig(dir=ckpt_dir), model2, opt2).load(exclude_keys=["model"]) + + for n, p in model2.named_parameters(): + assert torch.equal(pre_load[n], p), f"{n} should not have changed" + loaded = self._moments(opt2) + assert loaded, "optimizer moments should load with exclude_keys=['model']" + for (rea, rev), (lea, lev) in zip(ref_moments, loaded, strict=True): + assert torch.equal(rea, lea), "exp_avg not restored" + assert torch.equal(rev, lev), "exp_avg_sq not restored"