From 1db8cc1197521f9e040517fc37fb7b27bb04ab1e Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 2 Mar 2026 10:15:05 +0000 Subject: [PATCH 1/5] Feat: Implement GradientAccumulation for SupervisedTrainer (Issue #6100) Closes #6100 Signed-off-by: Soumya Snigdha Kundu --- monai/engines/__init__.py | 1 + monai/engines/utils.py | 119 +++++++ tests/engines/test_gradient_accumulation.py | 354 ++++++++++++++++++++ 3 files changed, 474 insertions(+) create mode 100644 tests/engines/test_gradient_accumulation.py diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 93cc40e292..ea87730f4a 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -15,6 +15,7 @@ from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( DiffusionPrepareBatch, + GradientAccumulation, IterationEvents, PrepareBatch, PrepareBatchDefault, diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 9095f8d943..990a13ae7f 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -41,6 +41,7 @@ "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", + "GradientAccumulation", ] @@ -360,3 +361,121 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: """ return current_metric > prev_best + + +def _noop(*args: Any, **kwargs: Any) -> None: + """No-op callable used to suppress optimizer/scaler methods during gradient accumulation.""" + + +class GradientAccumulation: + """ + Callable class implementing gradient accumulation for use with ``SupervisedTrainer``. + + Gradients are accumulated over ``accumulation_steps`` mini-batches before calling + ``optimizer.step()``, simulating a larger effective batch size on memory-constrained + hardware. + + Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``:: + + trainer = SupervisedTrainer( + ..., + iteration_update=GradientAccumulation(accumulation_steps=4), + ) + + All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``, + ``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so + existing handlers (checkpoint savers, metric loggers, etc.) are unaffected. + + When ``epoch_length`` is known, the optimizer is flushed at the end of each epoch + even if ``epoch_length % accumulation_steps != 0``, so no gradients are silently + discarded. For iterable datasets (``epoch_length=None``) this flush does not apply. + + The loss stored in ``engine.state.output[Keys.LOSS]`` is the **unscaled** + original loss value, so metrics and loggers report the true loss. Internally + the loss is divided by ``accumulation_steps`` for the backward pass only. + + Args: + accumulation_steps: number of mini-batches to accumulate before updating + weights. Must be a positive integer. Default: 2. + + Raises: + ValueError: when ``accumulation_steps`` is not a positive integer. + """ + + def __init__(self, accumulation_steps: int = 2) -> None: + if not isinstance(accumulation_steps, int) or accumulation_steps < 1: + raise ValueError( + f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}." + ) + self.accumulation_steps = accumulation_steps + + def __repr__(self) -> str: + return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})" + + def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict: + """ + Execute one iteration with gradient accumulation. + + Args: + engine: the Ignite engine (usually ``SupervisedTrainer``). + batchdata: batch data for this iteration. + + Returns: + the output dict from ``engine._iteration()``. + """ + acc = self.accumulation_steps + + if acc == 1: + return engine._iteration(engine, batchdata) + + # engine.state.iteration is 1-indexed and already incremented before __call__ + epoch_length = engine.state.epoch_length # None for iterable datasets + if epoch_length is not None: + local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length + else: + local_iter = engine.state.iteration - 1 # 0-indexed global + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 + + # Save and conditionally suppress zero_grad. Only clear gradients at the start of an accumulation cycle. + original_zero_grad = engine.optimizer.zero_grad + if not should_zero_grad: + engine.optimizer.zero_grad = _noop + + # Save and wrap loss_function to scale by 1/accumulation_steps. This ensures the per-mini-batch + # gradient contribution is correct: the scaled loss will be backpropagated, and accumulated gradients + # will average to the same value they would with the full batch. + original_loss_fn = engine.loss_function + engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc + + # Save and conditionally suppress optimizer.step. Only update weights at the end of an accumulation cycle. + # Also patch GradScaler.step and GradScaler.update when step is suppressed, for mixed-precision training. + original_step = engine.optimizer.step + original_scaler_step = None + original_scaler_update = None + if not should_step: + engine.optimizer.step = _noop + if hasattr(engine, "scaler") and engine.scaler is not None: + original_scaler_step = engine.scaler.step + original_scaler_update = engine.scaler.update + engine.scaler.step = _noop + engine.scaler.update = _noop + + try: + result = engine._iteration(engine, batchdata) + finally: + engine.optimizer.zero_grad = original_zero_grad + engine.loss_function = original_loss_fn + engine.optimizer.step = original_step + if original_scaler_step is not None: + engine.scaler.step = original_scaler_step + engine.scaler.update = original_scaler_update + + # Restore the unscaled loss for logging and metrics. The backward pass + # already used the scaled value, so this only affects what handlers see. + if CommonKeys.LOSS in result: + result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc + + return result diff --git a/tests/engines/test_gradient_accumulation.py b/tests/engines/test_gradient_accumulation.py new file mode 100644 index 0000000000..9dd01b0094 --- /dev/null +++ b/tests/engines/test_gradient_accumulation.py @@ -0,0 +1,354 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock + +import torch +import torch.nn as nn +from parameterized import parameterized + +from monai.engines import GradientAccumulation +from monai.utils import IgniteInfo, min_version, optional_import +from monai.utils.enums import CommonKeys + +_, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version) + +INVALID_ACCUMULATION_STEPS = [ + (0,), (-1,), (2.5,), ("2",), +] + +SUPPRESSION_CASES = [ + # (attr_name, acc, epoch_length, num_iters, expected) + ( + "zero_grad", 4, 12, 12, + [True, False, False, False, True, False, False, False, True, False, False, False], + ), + ( + "step", 4, 12, 12, + [False, False, False, True, False, False, False, True, False, False, False, True], + ), + # epoch_length=11 not divisible by 4 → flush at epoch end + ( + "step", 4, 11, 11, + [False, False, False, True, False, False, False, True, False, False, True], + ), + # epoch_length=None (iterable dataset) → no epoch flush + ( + "step", 4, None, 10, + [False, False, False, True, False, False, False, True, False, False], + ), +] + + +def _make_engine(epoch_length, iteration=1, scaler=None): + """Create a mock engine whose _iteration observes patched methods.""" + engine = MagicMock() + engine.state.epoch_length = epoch_length + engine.state.iteration = iteration + engine.scaler = scaler + engine.optimizer = MagicMock() + engine.loss_function = MagicMock(return_value=torch.tensor(1.0)) + engine._iteration.return_value = {CommonKeys.LOSS: torch.tensor(1.0)} + return engine + + +class TestGradientAccumulation(unittest.TestCase): + """Test cases for GradientAccumulation callable.""" + + # ---- input validation ---- + + @parameterized.expand(INVALID_ACCUMULATION_STEPS) + def test_invalid_accumulation_steps(self, value) -> None: + with self.assertRaises(ValueError) as cm: + GradientAccumulation(accumulation_steps=value) + self.assertIn("positive integer", str(cm.exception)) + + def test_repr(self) -> None: + ga = GradientAccumulation(accumulation_steps=8) + self.assertEqual(repr(ga), "GradientAccumulation(accumulation_steps=8)") + + # ---- passthrough ---- + + def test_passthrough_when_accumulation_steps_1(self) -> None: + grad_accum = GradientAccumulation(accumulation_steps=1) + engine = _make_engine(epoch_length=12, iteration=1) + expected_output = {CommonKeys.LOSS: torch.tensor(0.5), CommonKeys.PRED: torch.tensor([1.0])} + engine._iteration.return_value = expected_output + + result = grad_accum(engine, {}) + + engine._iteration.assert_called_once_with(engine, {}) + self.assertIs(result, expected_output) + + # ---- suppression logic ---- + + @parameterized.expand(SUPPRESSION_CASES) + def test_suppression(self, attr_name, acc, epoch_length, num_iters, expected) -> None: + grad_accum = GradientAccumulation(accumulation_steps=acc) + original = MagicMock(name=attr_name) + engine = _make_engine(epoch_length) + setattr(engine.optimizer, attr_name, original) + + saw_original: list[bool] = [] + + def fake_iteration(eng, batch): + saw_original.append(getattr(eng.optimizer, attr_name) is original) + return {CommonKeys.LOSS: torch.tensor(1.0)} + + engine._iteration.side_effect = fake_iteration + + for i in range(1, num_iters + 1): + engine.state.iteration = i + grad_accum(engine, {}) + + self.assertEqual(saw_original, expected) + + # ---- patching / restoration ---- + + def test_patching_and_restoration(self) -> None: + engine = _make_engine(epoch_length=4, iteration=1) + + original_zero_grad = MagicMock(name="original_zero_grad") + original_step = MagicMock(name="original_step") + original_loss_fn = MagicMock(return_value=torch.tensor(0.5), name="original_loss_fn") + + engine.optimizer.zero_grad = original_zero_grad + engine.optimizer.step = original_step + engine.loss_function = original_loss_fn + + GradientAccumulation(accumulation_steps=2)(engine, {}) + + self.assertIs(engine.optimizer.zero_grad, original_zero_grad) + self.assertIs(engine.optimizer.step, original_step) + self.assertIs(engine.loss_function, original_loss_fn) + + def test_restoration_after_exception(self) -> None: + """try/finally must restore all originals even when _iteration raises.""" + engine = _make_engine(epoch_length=8, iteration=2) + + original_zero_grad = MagicMock(name="zero_grad") + original_step = MagicMock(name="step") + original_loss_fn = MagicMock(return_value=torch.tensor(1.0), name="loss_fn") + original_scaler_step = MagicMock(name="scaler_step") + original_scaler_update = MagicMock(name="scaler_update") + + engine.optimizer.zero_grad = original_zero_grad + engine.optimizer.step = original_step + engine.loss_function = original_loss_fn + engine.scaler = MagicMock() + engine.scaler.step = original_scaler_step + engine.scaler.update = original_scaler_update + engine._iteration.side_effect = RuntimeError("boom") + + with self.assertRaises(RuntimeError): + GradientAccumulation(accumulation_steps=4)(engine, {}) + + self.assertIs(engine.optimizer.zero_grad, original_zero_grad) + self.assertIs(engine.optimizer.step, original_step) + self.assertIs(engine.loss_function, original_loss_fn) + self.assertIs(engine.scaler.step, original_scaler_step) + self.assertIs(engine.scaler.update, original_scaler_update) + + # ---- scaler ---- + + def test_scaler_not_patched_when_stepping(self) -> None: + engine = _make_engine(epoch_length=4, iteration=2) # acc=2 → should_step=True + original_scaler_step = MagicMock(name="scaler_step") + original_scaler_update = MagicMock(name="scaler_update") + engine.scaler = MagicMock() + engine.scaler.step = original_scaler_step + engine.scaler.update = original_scaler_update + + GradientAccumulation(accumulation_steps=2)(engine, {}) + + self.assertIs(engine.scaler.step, original_scaler_step) + self.assertIs(engine.scaler.update, original_scaler_update) + + def test_scaler_patched_and_restored_when_suppressed(self) -> None: + engine = _make_engine(epoch_length=8, iteration=2) # should_step=False for acc=4 + original_scaler_step = MagicMock(name="scaler_step") + original_scaler_update = MagicMock(name="scaler_update") + engine.scaler = MagicMock() + engine.scaler.step = original_scaler_step + engine.scaler.update = original_scaler_update + + scaler_was_patched = [] + + def check_scaler(eng, batch): + scaler_was_patched.append(eng.scaler.step is not original_scaler_step) + return {CommonKeys.LOSS: torch.tensor(0.5)} + + engine._iteration.side_effect = check_scaler + GradientAccumulation(accumulation_steps=4)(engine, {}) + + self.assertTrue(scaler_was_patched[0], "scaler.step should be patched during _iteration") + self.assertIs(engine.scaler.step, original_scaler_step) + self.assertIs(engine.scaler.update, original_scaler_update) + + def test_no_scaler_attribute(self) -> None: + """Engine without a scaler attribute at all should work (hasattr returns False).""" + engine = _make_engine(epoch_length=4, iteration=1) + del engine.scaler # MagicMock auto-creates attrs; delete to test hasattr branch + + result = GradientAccumulation(accumulation_steps=2)(engine, {}) + self.assertIn(CommonKeys.LOSS, result) + + def test_scaler_is_none(self) -> None: + engine = _make_engine(epoch_length=4, iteration=2) + engine.scaler = None + + result = GradientAccumulation(accumulation_steps=2)(engine, {}) + self.assertIn(CommonKeys.LOSS, result) + + # ---- batch / loss ---- + + def test_batch_data_passed_correctly(self) -> None: + engine = _make_engine(epoch_length=4, iteration=1) + test_batch = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)} + + GradientAccumulation(accumulation_steps=2)(engine, test_batch) + + engine._iteration.assert_called_once() + call_args = engine._iteration.call_args + self.assertEqual(call_args[0][0], engine) + self.assertEqual(call_args[0][1], test_batch) + + def test_loss_output_is_unscaled(self) -> None: + """Output loss should be rescaled to the original (unscaled) value.""" + engine = _make_engine(epoch_length=9, iteration=1) + engine.scaler = None + original_loss = torch.tensor(6.0) + engine.loss_function = MagicMock(return_value=original_loss) + + def fake_iteration(*args, **kwargs): + scaled = engine.loss_function() + return {CommonKeys.LOSS: scaled} + + engine._iteration.side_effect = fake_iteration + + result = GradientAccumulation(accumulation_steps=3)(engine, {}) + self.assertAlmostEqual(result[CommonKeys.LOSS].item(), 6.0, places=5) + + # ---- integration (require ignite) ---- + + @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") + def test_integration_gradient_equivalence(self) -> None: + """Accumulated gradients over N mini-batches equal one large-batch step.""" + from monai.engines import SupervisedTrainer + + torch.manual_seed(42) + acc_steps, lr = 4, 0.1 + batches = [ + {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} + for _ in range(acc_steps) + ] + + ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + + ref_opt.zero_grad() + for batch in batches: + loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps + loss.backward() + ref_opt.step() + + trainer = SupervisedTrainer( + device=torch.device("cpu"), max_epochs=1, train_data_loader=batches, + network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + ) + trainer.run() + + for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): + torch.testing.assert_close(p_test.data, p_ref.data) + + @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") + def test_integration_epoch_boundary_flush(self) -> None: + """When epoch_length is not divisible by acc_steps, flush at epoch end.""" + from monai.engines import SupervisedTrainer + + torch.manual_seed(123) + acc_steps, lr = 3, 0.1 + batches = [ + {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} + for _ in range(5) + ] + + ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + + for cycle_batches in [batches[:3], batches[3:]]: + ref_opt.zero_grad() + for batch in cycle_batches: + loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps + loss.backward() + ref_opt.step() + + trainer = SupervisedTrainer( + device=torch.device("cpu"), max_epochs=1, train_data_loader=batches, + network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + ) + trainer.run() + + for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): + torch.testing.assert_close(p_test.data, p_ref.data) + + @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") + def test_integration_multi_epoch(self) -> None: + """Verify gradient accumulation is correct across multiple epochs.""" + from monai.engines import SupervisedTrainer + + torch.manual_seed(42) + acc_steps, lr, num_epochs = 2, 0.1, 3 + batches = [ + {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} + for _ in range(4) + ] + + ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + + for _epoch in range(num_epochs): + for cycle_batches in [batches[:2], batches[2:]]: + ref_opt.zero_grad() + for batch in cycle_batches: + loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps + loss.backward() + ref_opt.step() + + trainer = SupervisedTrainer( + device=torch.device("cpu"), max_epochs=num_epochs, train_data_loader=batches, + network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + ) + trainer.run() + + for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): + torch.testing.assert_close(p_test.data, p_ref.data) + + +def _make_model_pair(lr): + """Create a reference and test model pair with identical initial weights.""" + ref_model = nn.Linear(4, 1, bias=False) + init_weight = ref_model.weight.data.clone() + ref_opt = torch.optim.SGD(ref_model.parameters(), lr=lr) + ref_model.train() + + test_model = nn.Linear(4, 1, bias=False) + test_model.weight.data.copy_(init_weight) + test_opt = torch.optim.SGD(test_model.parameters(), lr=lr) + + return ref_model, test_model, ref_opt, test_opt, init_weight + + +if __name__ == "__main__": + unittest.main() From a3eca14a77445eb29251bf03706b85a867096523 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 11:39:58 +0000 Subject: [PATCH 2/5] fix for Flake8-py3 codeformat error Signed-off-by: Soumya Snigdha Kundu --- monai/engines/utils.py | 4 +- tests/engines/test_gradient_accumulation.py | 63 +++++++++------------ 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 990a13ae7f..bbec6ba012 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -404,9 +404,7 @@ class GradientAccumulation: def __init__(self, accumulation_steps: int = 2) -> None: if not isinstance(accumulation_steps, int) or accumulation_steps < 1: - raise ValueError( - f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}." - ) + raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") self.accumulation_steps = accumulation_steps def __repr__(self) -> str: diff --git a/tests/engines/test_gradient_accumulation.py b/tests/engines/test_gradient_accumulation.py index 9dd01b0094..13b7e666f3 100644 --- a/tests/engines/test_gradient_accumulation.py +++ b/tests/engines/test_gradient_accumulation.py @@ -24,30 +24,16 @@ _, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version) -INVALID_ACCUMULATION_STEPS = [ - (0,), (-1,), (2.5,), ("2",), -] +INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)] SUPPRESSION_CASES = [ # (attr_name, acc, epoch_length, num_iters, expected) - ( - "zero_grad", 4, 12, 12, - [True, False, False, False, True, False, False, False, True, False, False, False], - ), - ( - "step", 4, 12, 12, - [False, False, False, True, False, False, False, True, False, False, False, True], - ), + ("zero_grad", 4, 12, 12, [True, False, False, False, True, False, False, False, True, False, False, False]), + ("step", 4, 12, 12, [False, False, False, True, False, False, False, True, False, False, False, True]), # epoch_length=11 not divisible by 4 → flush at epoch end - ( - "step", 4, 11, 11, - [False, False, False, True, False, False, False, True, False, False, True], - ), + ("step", 4, 11, 11, [False, False, False, True, False, False, False, True, False, False, True]), # epoch_length=None (iterable dataset) → no epoch flush - ( - "step", 4, None, 10, - [False, False, False, True, False, False, False, True, False, False], - ), + ("step", 4, None, 10, [False, False, False, True, False, False, False, True, False, False]), ] @@ -249,10 +235,7 @@ def test_integration_gradient_equivalence(self) -> None: torch.manual_seed(42) acc_steps, lr = 4, 0.1 - batches = [ - {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} - for _ in range(acc_steps) - ] + batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)] ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) @@ -263,8 +246,12 @@ def test_integration_gradient_equivalence(self) -> None: ref_opt.step() trainer = SupervisedTrainer( - device=torch.device("cpu"), max_epochs=1, train_data_loader=batches, - network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=batches, + network=test_model, + optimizer=test_opt, + loss_function=nn.MSELoss(), iteration_update=GradientAccumulation(accumulation_steps=acc_steps), ) trainer.run() @@ -279,10 +266,7 @@ def test_integration_epoch_boundary_flush(self) -> None: torch.manual_seed(123) acc_steps, lr = 3, 0.1 - batches = [ - {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} - for _ in range(5) - ] + batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(5)] ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) @@ -294,8 +278,12 @@ def test_integration_epoch_boundary_flush(self) -> None: ref_opt.step() trainer = SupervisedTrainer( - device=torch.device("cpu"), max_epochs=1, train_data_loader=batches, - network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=batches, + network=test_model, + optimizer=test_opt, + loss_function=nn.MSELoss(), iteration_update=GradientAccumulation(accumulation_steps=acc_steps), ) trainer.run() @@ -310,10 +298,7 @@ def test_integration_multi_epoch(self) -> None: torch.manual_seed(42) acc_steps, lr, num_epochs = 2, 0.1, 3 - batches = [ - {CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} - for _ in range(4) - ] + batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)] ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) @@ -326,8 +311,12 @@ def test_integration_multi_epoch(self) -> None: ref_opt.step() trainer = SupervisedTrainer( - device=torch.device("cpu"), max_epochs=num_epochs, train_data_loader=batches, - network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), + device=torch.device("cpu"), + max_epochs=num_epochs, + train_data_loader=batches, + network=test_model, + optimizer=test_opt, + loss_function=nn.MSELoss(), iteration_update=GradientAccumulation(accumulation_steps=acc_steps), ) trainer.run() From cad21055d598f73c1abb45f2b4f3d1eb8a6f6078 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 15:00:53 +0000 Subject: [PATCH 3/5] fix mypy error Signed-off-by: Soumya Snigdha Kundu --- monai/engines/utils.py | 5 ++++- tests/engines/test_gradient_accumulation.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index bbec6ba012..27e78d2e0a 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -423,8 +423,11 @@ def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict: """ acc = self.accumulation_steps + result: dict + if acc == 1: - return engine._iteration(engine, batchdata) + result = engine._iteration(engine, batchdata) + return result # engine.state.iteration is 1-indexed and already incremented before __call__ epoch_length = engine.state.epoch_length # None for iterable datasets diff --git a/tests/engines/test_gradient_accumulation.py b/tests/engines/test_gradient_accumulation.py index 13b7e666f3..eceff47c02 100644 --- a/tests/engines/test_gradient_accumulation.py +++ b/tests/engines/test_gradient_accumulation.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from typing import Any from unittest.mock import MagicMock import torch @@ -201,7 +202,7 @@ def test_scaler_is_none(self) -> None: def test_batch_data_passed_correctly(self) -> None: engine = _make_engine(epoch_length=4, iteration=1) - test_batch = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)} + test_batch: dict[str, Any] = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)} GradientAccumulation(accumulation_steps=2)(engine, test_batch) From 61d2a9fceb01224f62f4b0e8d0521eb309b71160 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 9 Mar 2026 09:44:13 +0000 Subject: [PATCH 4/5] Integrate gradient accumulation directly into SupervisedTrainer Replace external GradientAccumulation callable class with a native `accumulation_steps` constructor parameter on SupervisedTrainer, per reviewer feedback. This eliminates monkey-patching of optimizer/loss/scaler internals and instead uses simple conditionals in `_iteration()`. Based on feedback from @ericspod Signed-off-by: Soumya Snigdha Kundu --- monai/engines/__init__.py | 1 - monai/engines/trainer.py | 43 ++- monai/engines/utils.py | 120 ------- tests/engines/test_gradient_accumulation.py | 335 ++++++++------------ 4 files changed, 165 insertions(+), 334 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index ea87730f4a..93cc40e292 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -15,7 +15,6 @@ from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( DiffusionPrepareBatch, - GradientAccumulation, IterationEvents, PrepareBatch, PrepareBatchDefault, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index b69a5015bb..53e75eac40 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -131,6 +131,12 @@ class SupervisedTrainer(Trainer): `torch.Tensor` before forward pass, then converted back afterward with copied meta information. compile_kwargs: dict of the args for `torch.compile()` API, for more details: https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. + accumulation_steps: number of mini-batches over which to accumulate gradients before + calling ``optimizer.step()``, effectively simulating a larger batch size on + memory-constrained hardware. Must be a positive integer. Default: 1 (no accumulation). + When ``epoch_length`` is known and not divisible by ``accumulation_steps``, a flush + (optimizer step) is performed at the end of each epoch so no gradients are silently + discarded. The loss stored in ``engine.state.output`` is always the **unscaled** value. """ def __init__( @@ -160,7 +166,10 @@ def __init__( amp_kwargs: dict | None = None, compile: bool = False, compile_kwargs: dict | None = None, + accumulation_steps: int = 1, ) -> None: + if not isinstance(accumulation_steps, int) or accumulation_steps < 1: + raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") super().__init__( device=device, max_epochs=max_epochs, @@ -190,6 +199,7 @@ def __init__( self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer self.optim_set_to_none = optim_set_to_none + self.accumulation_steps = accumulation_steps def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict: """ @@ -245,21 +255,42 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) + # Determine gradient accumulation state + acc = engine.accumulation_steps + if acc > 1: + epoch_length = engine.state.epoch_length + if epoch_length is not None: + local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length + else: + local_iter = engine.state.iteration - 1 # 0-indexed global + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 + else: + should_zero_grad = True + should_step = True + engine.network.train() - engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + if should_zero_grad: + engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) if engine.amp and engine.scaler is not None: with torch.autocast("cuda", **engine.amp_kwargs): _compute_pred_loss() - engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() + loss = engine.state.output[Keys.LOSS] + engine.scaler.scale(loss / acc if acc > 1 else loss).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - engine.scaler.step(engine.optimizer) - engine.scaler.update() + if should_step: + engine.scaler.step(engine.optimizer) + engine.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() + loss = engine.state.output[Keys.LOSS] + (loss / acc if acc > 1 else loss).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - engine.optimizer.step() + if should_step: + engine.optimizer.step() # copy back meta info if self.compile: if inputs_meta is not None: diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 27e78d2e0a..9095f8d943 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -41,7 +41,6 @@ "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", - "GradientAccumulation", ] @@ -361,122 +360,3 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: """ return current_metric > prev_best - - -def _noop(*args: Any, **kwargs: Any) -> None: - """No-op callable used to suppress optimizer/scaler methods during gradient accumulation.""" - - -class GradientAccumulation: - """ - Callable class implementing gradient accumulation for use with ``SupervisedTrainer``. - - Gradients are accumulated over ``accumulation_steps`` mini-batches before calling - ``optimizer.step()``, simulating a larger effective batch size on memory-constrained - hardware. - - Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``:: - - trainer = SupervisedTrainer( - ..., - iteration_update=GradientAccumulation(accumulation_steps=4), - ) - - All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``, - ``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so - existing handlers (checkpoint savers, metric loggers, etc.) are unaffected. - - When ``epoch_length`` is known, the optimizer is flushed at the end of each epoch - even if ``epoch_length % accumulation_steps != 0``, so no gradients are silently - discarded. For iterable datasets (``epoch_length=None``) this flush does not apply. - - The loss stored in ``engine.state.output[Keys.LOSS]`` is the **unscaled** - original loss value, so metrics and loggers report the true loss. Internally - the loss is divided by ``accumulation_steps`` for the backward pass only. - - Args: - accumulation_steps: number of mini-batches to accumulate before updating - weights. Must be a positive integer. Default: 2. - - Raises: - ValueError: when ``accumulation_steps`` is not a positive integer. - """ - - def __init__(self, accumulation_steps: int = 2) -> None: - if not isinstance(accumulation_steps, int) or accumulation_steps < 1: - raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") - self.accumulation_steps = accumulation_steps - - def __repr__(self) -> str: - return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})" - - def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict: - """ - Execute one iteration with gradient accumulation. - - Args: - engine: the Ignite engine (usually ``SupervisedTrainer``). - batchdata: batch data for this iteration. - - Returns: - the output dict from ``engine._iteration()``. - """ - acc = self.accumulation_steps - - result: dict - - if acc == 1: - result = engine._iteration(engine, batchdata) - return result - - # engine.state.iteration is 1-indexed and already incremented before __call__ - epoch_length = engine.state.epoch_length # None for iterable datasets - if epoch_length is not None: - local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length - else: - local_iter = engine.state.iteration - 1 # 0-indexed global - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 - - # Save and conditionally suppress zero_grad. Only clear gradients at the start of an accumulation cycle. - original_zero_grad = engine.optimizer.zero_grad - if not should_zero_grad: - engine.optimizer.zero_grad = _noop - - # Save and wrap loss_function to scale by 1/accumulation_steps. This ensures the per-mini-batch - # gradient contribution is correct: the scaled loss will be backpropagated, and accumulated gradients - # will average to the same value they would with the full batch. - original_loss_fn = engine.loss_function - engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc - - # Save and conditionally suppress optimizer.step. Only update weights at the end of an accumulation cycle. - # Also patch GradScaler.step and GradScaler.update when step is suppressed, for mixed-precision training. - original_step = engine.optimizer.step - original_scaler_step = None - original_scaler_update = None - if not should_step: - engine.optimizer.step = _noop - if hasattr(engine, "scaler") and engine.scaler is not None: - original_scaler_step = engine.scaler.step - original_scaler_update = engine.scaler.update - engine.scaler.step = _noop - engine.scaler.update = _noop - - try: - result = engine._iteration(engine, batchdata) - finally: - engine.optimizer.zero_grad = original_zero_grad - engine.loss_function = original_loss_fn - engine.optimizer.step = original_step - if original_scaler_step is not None: - engine.scaler.step = original_scaler_step - engine.scaler.update = original_scaler_update - - # Restore the unscaled loss for logging and metrics. The backward pass - # already used the scaled value, so this only affects what handlers see. - if CommonKeys.LOSS in result: - result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc - - return result diff --git a/tests/engines/test_gradient_accumulation.py b/tests/engines/test_gradient_accumulation.py index eceff47c02..2578cba1f3 100644 --- a/tests/engines/test_gradient_accumulation.py +++ b/tests/engines/test_gradient_accumulation.py @@ -12,14 +12,11 @@ from __future__ import annotations import unittest -from typing import Any -from unittest.mock import MagicMock import torch import torch.nn as nn from parameterized import parameterized -from monai.engines import GradientAccumulation from monai.utils import IgniteInfo, min_version, optional_import from monai.utils.enums import CommonKeys @@ -27,210 +24,79 @@ INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)] -SUPPRESSION_CASES = [ - # (attr_name, acc, epoch_length, num_iters, expected) - ("zero_grad", 4, 12, 12, [True, False, False, False, True, False, False, False, True, False, False, False]), - ("step", 4, 12, 12, [False, False, False, True, False, False, False, True, False, False, False, True]), - # epoch_length=11 not divisible by 4 → flush at epoch end - ("step", 4, 11, 11, [False, False, False, True, False, False, False, True, False, False, True]), - # epoch_length=None (iterable dataset) → no epoch flush - ("step", 4, None, 10, [False, False, False, True, False, False, False, True, False, False]), -] - - -def _make_engine(epoch_length, iteration=1, scaler=None): - """Create a mock engine whose _iteration observes patched methods.""" - engine = MagicMock() - engine.state.epoch_length = epoch_length - engine.state.iteration = iteration - engine.scaler = scaler - engine.optimizer = MagicMock() - engine.loss_function = MagicMock(return_value=torch.tensor(1.0)) - engine._iteration.return_value = {CommonKeys.LOSS: torch.tensor(1.0)} - return engine +def _make_model_pair(lr): + """Create a reference and test model pair with identical initial weights.""" + ref_model = nn.Linear(4, 1, bias=False) + init_weight = ref_model.weight.data.clone() + ref_opt = torch.optim.SGD(ref_model.parameters(), lr=lr) + ref_model.train() + test_model = nn.Linear(4, 1, bias=False) + test_model.weight.data.copy_(init_weight) + test_opt = torch.optim.SGD(test_model.parameters(), lr=lr) + + return ref_model, test_model, ref_opt, test_opt, init_weight + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") class TestGradientAccumulation(unittest.TestCase): - """Test cases for GradientAccumulation callable.""" + """Test gradient accumulation integrated into SupervisedTrainer.""" # ---- input validation ---- @parameterized.expand(INVALID_ACCUMULATION_STEPS) def test_invalid_accumulation_steps(self, value) -> None: + from monai.engines import SupervisedTrainer + with self.assertRaises(ValueError) as cm: - GradientAccumulation(accumulation_steps=value) + SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}], + network=nn.Linear(4, 1), + optimizer=torch.optim.SGD(nn.Linear(4, 1).parameters(), lr=0.1), + loss_function=nn.MSELoss(), + accumulation_steps=value, + ) self.assertIn("positive integer", str(cm.exception)) - def test_repr(self) -> None: - ga = GradientAccumulation(accumulation_steps=8) - self.assertEqual(repr(ga), "GradientAccumulation(accumulation_steps=8)") - - # ---- passthrough ---- + # ---- passthrough (accumulation_steps=1) ---- def test_passthrough_when_accumulation_steps_1(self) -> None: - grad_accum = GradientAccumulation(accumulation_steps=1) - engine = _make_engine(epoch_length=12, iteration=1) - expected_output = {CommonKeys.LOSS: torch.tensor(0.5), CommonKeys.PRED: torch.tensor([1.0])} - engine._iteration.return_value = expected_output - - result = grad_accum(engine, {}) - - engine._iteration.assert_called_once_with(engine, {}) - self.assertIs(result, expected_output) - - # ---- suppression logic ---- - - @parameterized.expand(SUPPRESSION_CASES) - def test_suppression(self, attr_name, acc, epoch_length, num_iters, expected) -> None: - grad_accum = GradientAccumulation(accumulation_steps=acc) - original = MagicMock(name=attr_name) - engine = _make_engine(epoch_length) - setattr(engine.optimizer, attr_name, original) - - saw_original: list[bool] = [] - - def fake_iteration(eng, batch): - saw_original.append(getattr(eng.optimizer, attr_name) is original) - return {CommonKeys.LOSS: torch.tensor(1.0)} - - engine._iteration.side_effect = fake_iteration - - for i in range(1, num_iters + 1): - engine.state.iteration = i - grad_accum(engine, {}) - - self.assertEqual(saw_original, expected) - - # ---- patching / restoration ---- - - def test_patching_and_restoration(self) -> None: - engine = _make_engine(epoch_length=4, iteration=1) - - original_zero_grad = MagicMock(name="original_zero_grad") - original_step = MagicMock(name="original_step") - original_loss_fn = MagicMock(return_value=torch.tensor(0.5), name="original_loss_fn") - - engine.optimizer.zero_grad = original_zero_grad - engine.optimizer.step = original_step - engine.loss_function = original_loss_fn - - GradientAccumulation(accumulation_steps=2)(engine, {}) - - self.assertIs(engine.optimizer.zero_grad, original_zero_grad) - self.assertIs(engine.optimizer.step, original_step) - self.assertIs(engine.loss_function, original_loss_fn) - - def test_restoration_after_exception(self) -> None: - """try/finally must restore all originals even when _iteration raises.""" - engine = _make_engine(epoch_length=8, iteration=2) - - original_zero_grad = MagicMock(name="zero_grad") - original_step = MagicMock(name="step") - original_loss_fn = MagicMock(return_value=torch.tensor(1.0), name="loss_fn") - original_scaler_step = MagicMock(name="scaler_step") - original_scaler_update = MagicMock(name="scaler_update") - - engine.optimizer.zero_grad = original_zero_grad - engine.optimizer.step = original_step - engine.loss_function = original_loss_fn - engine.scaler = MagicMock() - engine.scaler.step = original_scaler_step - engine.scaler.update = original_scaler_update - engine._iteration.side_effect = RuntimeError("boom") - - with self.assertRaises(RuntimeError): - GradientAccumulation(accumulation_steps=4)(engine, {}) - - self.assertIs(engine.optimizer.zero_grad, original_zero_grad) - self.assertIs(engine.optimizer.step, original_step) - self.assertIs(engine.loss_function, original_loss_fn) - self.assertIs(engine.scaler.step, original_scaler_step) - self.assertIs(engine.scaler.update, original_scaler_update) - - # ---- scaler ---- - - def test_scaler_not_patched_when_stepping(self) -> None: - engine = _make_engine(epoch_length=4, iteration=2) # acc=2 → should_step=True - original_scaler_step = MagicMock(name="scaler_step") - original_scaler_update = MagicMock(name="scaler_update") - engine.scaler = MagicMock() - engine.scaler.step = original_scaler_step - engine.scaler.update = original_scaler_update - - GradientAccumulation(accumulation_steps=2)(engine, {}) - - self.assertIs(engine.scaler.step, original_scaler_step) - self.assertIs(engine.scaler.update, original_scaler_update) - - def test_scaler_patched_and_restored_when_suppressed(self) -> None: - engine = _make_engine(epoch_length=8, iteration=2) # should_step=False for acc=4 - original_scaler_step = MagicMock(name="scaler_step") - original_scaler_update = MagicMock(name="scaler_update") - engine.scaler = MagicMock() - engine.scaler.step = original_scaler_step - engine.scaler.update = original_scaler_update - - scaler_was_patched = [] - - def check_scaler(eng, batch): - scaler_was_patched.append(eng.scaler.step is not original_scaler_step) - return {CommonKeys.LOSS: torch.tensor(0.5)} - - engine._iteration.side_effect = check_scaler - GradientAccumulation(accumulation_steps=4)(engine, {}) - - self.assertTrue(scaler_was_patched[0], "scaler.step should be patched during _iteration") - self.assertIs(engine.scaler.step, original_scaler_step) - self.assertIs(engine.scaler.update, original_scaler_update) - - def test_no_scaler_attribute(self) -> None: - """Engine without a scaler attribute at all should work (hasattr returns False).""" - engine = _make_engine(epoch_length=4, iteration=1) - del engine.scaler # MagicMock auto-creates attrs; delete to test hasattr branch - - result = GradientAccumulation(accumulation_steps=2)(engine, {}) - self.assertIn(CommonKeys.LOSS, result) - - def test_scaler_is_none(self) -> None: - engine = _make_engine(epoch_length=4, iteration=2) - engine.scaler = None - - result = GradientAccumulation(accumulation_steps=2)(engine, {}) - self.assertIn(CommonKeys.LOSS, result) - - # ---- batch / loss ---- - - def test_batch_data_passed_correctly(self) -> None: - engine = _make_engine(epoch_length=4, iteration=1) - test_batch: dict[str, Any] = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)} - - GradientAccumulation(accumulation_steps=2)(engine, test_batch) + """With accumulation_steps=1, behaviour is identical to default training.""" + from monai.engines import SupervisedTrainer - engine._iteration.assert_called_once() - call_args = engine._iteration.call_args - self.assertEqual(call_args[0][0], engine) - self.assertEqual(call_args[0][1], test_batch) + torch.manual_seed(42) + lr = 0.1 + batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)] - def test_loss_output_is_unscaled(self) -> None: - """Output loss should be rescaled to the original (unscaled) value.""" - engine = _make_engine(epoch_length=9, iteration=1) - engine.scaler = None - original_loss = torch.tensor(6.0) - engine.loss_function = MagicMock(return_value=original_loss) + ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr) - def fake_iteration(*args, **kwargs): - scaled = engine.loss_function() - return {CommonKeys.LOSS: scaled} + # Reference: standard training loop + for batch in batches: + ref_opt.zero_grad() + loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() + loss.backward() + ref_opt.step() - engine._iteration.side_effect = fake_iteration + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=batches, + network=test_model, + optimizer=test_opt, + loss_function=nn.MSELoss(), + accumulation_steps=1, + ) + trainer.run() - result = GradientAccumulation(accumulation_steps=3)(engine, {}) - self.assertAlmostEqual(result[CommonKeys.LOSS].item(), 6.0, places=5) + for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): + torch.testing.assert_close(p_test.data, p_ref.data) - # ---- integration (require ignite) ---- + # ---- gradient equivalence ---- - @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") - def test_integration_gradient_equivalence(self) -> None: + def test_gradient_equivalence(self) -> None: """Accumulated gradients over N mini-batches equal one large-batch step.""" from monai.engines import SupervisedTrainer @@ -238,8 +104,9 @@ def test_integration_gradient_equivalence(self) -> None: acc_steps, lr = 4, 0.1 batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)] - ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr) + # Reference: manual accumulation ref_opt.zero_grad() for batch in batches: loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps @@ -253,15 +120,16 @@ def test_integration_gradient_equivalence(self) -> None: network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), - iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + accumulation_steps=acc_steps, ) trainer.run() for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): torch.testing.assert_close(p_test.data, p_ref.data) - @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") - def test_integration_epoch_boundary_flush(self) -> None: + # ---- epoch boundary flush ---- + + def test_epoch_boundary_flush(self) -> None: """When epoch_length is not divisible by acc_steps, flush at epoch end.""" from monai.engines import SupervisedTrainer @@ -269,8 +137,9 @@ def test_integration_epoch_boundary_flush(self) -> None: acc_steps, lr = 3, 0.1 batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(5)] - ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr) + # Reference: first 3 batches form one cycle, last 2 form a partial cycle flushed at epoch end for cycle_batches in [batches[:3], batches[3:]]: ref_opt.zero_grad() for batch in cycle_batches: @@ -285,15 +154,16 @@ def test_integration_epoch_boundary_flush(self) -> None: network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), - iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + accumulation_steps=acc_steps, ) trainer.run() for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): torch.testing.assert_close(p_test.data, p_ref.data) - @unittest.skipUnless(has_ignite, "Requires pytorch-ignite") - def test_integration_multi_epoch(self) -> None: + # ---- multi-epoch ---- + + def test_multi_epoch(self) -> None: """Verify gradient accumulation is correct across multiple epochs.""" from monai.engines import SupervisedTrainer @@ -301,8 +171,9 @@ def test_integration_multi_epoch(self) -> None: acc_steps, lr, num_epochs = 2, 0.1, 3 batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)] - ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr) + ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr) + # Reference: manual multi-epoch accumulation for _epoch in range(num_epochs): for cycle_batches in [batches[:2], batches[2:]]: ref_opt.zero_grad() @@ -318,26 +189,76 @@ def test_integration_multi_epoch(self) -> None: network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(), - iteration_update=GradientAccumulation(accumulation_steps=acc_steps), + accumulation_steps=acc_steps, ) trainer.run() for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()): torch.testing.assert_close(p_test.data, p_ref.data) + # ---- loss output is unscaled ---- -def _make_model_pair(lr): - """Create a reference and test model pair with identical initial weights.""" - ref_model = nn.Linear(4, 1, bias=False) - init_weight = ref_model.weight.data.clone() - ref_opt = torch.optim.SGD(ref_model.parameters(), lr=lr) - ref_model.train() + def test_loss_output_is_unscaled(self) -> None: + """engine.state.output[LOSS] should be the unscaled loss, not loss/acc.""" + from monai.engines import SupervisedTrainer - test_model = nn.Linear(4, 1, bias=False) - test_model.weight.data.copy_(init_weight) - test_opt = torch.optim.SGD(test_model.parameters(), lr=lr) + torch.manual_seed(42) + acc_steps = 4 + batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)] - return ref_model, test_model, ref_opt, test_opt, init_weight + model = nn.Linear(4, 1, bias=False) + opt = torch.optim.SGD(model.parameters(), lr=0.1) + + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=batches, + network=model, + optimizer=opt, + loss_function=nn.MSELoss(), + accumulation_steps=acc_steps, + decollate=False, + ) + trainer.run() + + # The output loss should be the full (unscaled) loss value, not divided by acc_steps + output_loss = trainer.state.output[CommonKeys.LOSS].item() + self.assertGreater(output_loss, 0.0) + + # ---- accumulation_steps attribute ---- + + def test_accumulation_steps_stored(self) -> None: + """Verify the accumulation_steps attribute is accessible on the trainer.""" + from monai.engines import SupervisedTrainer + + model = nn.Linear(4, 1) + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}], + network=model, + optimizer=torch.optim.SGD(model.parameters(), lr=0.1), + loss_function=nn.MSELoss(), + accumulation_steps=8, + ) + self.assertEqual(trainer.accumulation_steps, 8) + + # ---- default is no accumulation ---- + + def test_default_no_accumulation(self) -> None: + """Default accumulation_steps=1 means no accumulation.""" + from monai.engines import SupervisedTrainer + + model = nn.Linear(4, 1) + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}], + network=model, + optimizer=torch.optim.SGD(model.parameters(), lr=0.1), + loss_function=nn.MSELoss(), + ) + self.assertEqual(trainer.accumulation_steps, 1) if __name__ == "__main__": From 3272ff1a269049ff2fdfd8e44505c6429d6f7671 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 9 Mar 2026 09:56:36 +0000 Subject: [PATCH 5/5] Remove isinstance check for accumulation_steps per reviewer feedback Signed-off-by: Soumya Snigdha Kundu --- monai/engines/trainer.py | 2 +- tests/engines/test_gradient_accumulation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 53e75eac40..15033cabac 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -168,7 +168,7 @@ def __init__( compile_kwargs: dict | None = None, accumulation_steps: int = 1, ) -> None: - if not isinstance(accumulation_steps, int) or accumulation_steps < 1: + if accumulation_steps < 1: raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") super().__init__( device=device, diff --git a/tests/engines/test_gradient_accumulation.py b/tests/engines/test_gradient_accumulation.py index 2578cba1f3..bb4725e0f6 100644 --- a/tests/engines/test_gradient_accumulation.py +++ b/tests/engines/test_gradient_accumulation.py @@ -22,7 +22,7 @@ _, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version) -INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)] +INVALID_ACCUMULATION_STEPS = [(0,), (-1,)] def _make_model_pair(lr):