Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class SupervisedTrainer(Trainer):
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
accumulation_steps: number of mini-batches over which to accumulate gradients before
calling ``optimizer.step()``, effectively simulating a larger batch size on
memory-constrained hardware. Must be a positive integer. Default: 1 (no accumulation).
When ``epoch_length`` is known and not divisible by ``accumulation_steps``, a flush
(optimizer step) is performed at the end of each epoch so no gradients are silently
discarded. The loss stored in ``engine.state.output`` is always the **unscaled** value.
"""

def __init__(
Expand Down Expand Up @@ -160,7 +166,10 @@ def __init__(
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict | None = None,
accumulation_steps: int = 1,
) -> None:
if accumulation_steps < 1:
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
super().__init__(
device=device,
max_epochs=max_epochs,
Expand Down Expand Up @@ -190,6 +199,7 @@ def __init__(
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer
self.optim_set_to_none = optim_set_to_none
self.accumulation_steps = accumulation_steps

def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
"""
Expand Down Expand Up @@ -245,21 +255,42 @@ def _compute_pred_loss():
engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
engine.fire_event(IterationEvents.LOSS_COMPLETED)

# Determine gradient accumulation state
acc = engine.accumulation_steps
if acc > 1:
epoch_length = engine.state.epoch_length
if epoch_length is not None:
local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch
should_zero_grad = local_iter % acc == 0
should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
else:
local_iter = engine.state.iteration - 1 # 0-indexed global
should_zero_grad = local_iter % acc == 0
should_step = (local_iter + 1) % acc == 0
else:
should_zero_grad = True
should_step = True

engine.network.train()
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
if should_zero_grad:
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

if engine.amp and engine.scaler is not None:
with torch.autocast("cuda", **engine.amp_kwargs):
_compute_pred_loss()
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
loss = engine.state.output[Keys.LOSS]
engine.scaler.scale(loss / acc if acc > 1 else loss).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.scaler.step(engine.optimizer)
engine.scaler.update()
if should_step:
engine.scaler.step(engine.optimizer)
engine.scaler.update()
else:
_compute_pred_loss()
engine.state.output[Keys.LOSS].backward()
loss = engine.state.output[Keys.LOSS]
(loss / acc if acc > 1 else loss).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.optimizer.step()
if should_step:
engine.optimizer.step()
# copy back meta info
if self.compile:
if inputs_meta is not None:
Expand Down
265 changes: 265 additions & 0 deletions tests/engines/test_gradient_accumulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
import torch.nn as nn
from parameterized import parameterized

from monai.utils import IgniteInfo, min_version, optional_import
from monai.utils.enums import CommonKeys

_, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version)

# Parameter sets for `parameterized.expand`: values that the `accumulation_steps`
# constructor validation must reject (it requires a positive integer, i.e. >= 1).
INVALID_ACCUMULATION_STEPS = [(0,), (-1,)]


def _make_model_pair(lr):
"""Create a reference and test model pair with identical initial weights."""
ref_model = nn.Linear(4, 1, bias=False)
init_weight = ref_model.weight.data.clone()
ref_opt = torch.optim.SGD(ref_model.parameters(), lr=lr)
ref_model.train()

test_model = nn.Linear(4, 1, bias=False)
test_model.weight.data.copy_(init_weight)
test_opt = torch.optim.SGD(test_model.parameters(), lr=lr)

return ref_model, test_model, ref_opt, test_opt, init_weight


@unittest.skipUnless(has_ignite, "Requires pytorch-ignite")
class TestGradientAccumulation(unittest.TestCase):
    """Test gradient accumulation integrated into SupervisedTrainer.

    Each equivalence test trains a `test_model` through ``SupervisedTrainer`` and a
    `ref_model` through a hand-written loop implementing the expected accumulation
    semantics (loss scaled by ``1/accumulation_steps``, one ``optimizer.step()`` per
    cycle), then compares final parameters.
    """

    # ---- input validation ----

    @parameterized.expand(INVALID_ACCUMULATION_STEPS)
    def test_invalid_accumulation_steps(self, value) -> None:
        """Non-positive accumulation_steps must raise ValueError at construction."""
        from monai.engines import SupervisedTrainer

        with self.assertRaises(ValueError) as cm:
            SupervisedTrainer(
                device=torch.device("cpu"),
                max_epochs=1,
                train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
                network=nn.Linear(4, 1),
                optimizer=torch.optim.SGD(nn.Linear(4, 1).parameters(), lr=0.1),
                loss_function=nn.MSELoss(),
                accumulation_steps=value,
            )
        self.assertIn("positive integer", str(cm.exception))

    # ---- passthrough (accumulation_steps=1) ----

    def test_passthrough_when_accumulation_steps_1(self) -> None:
        """With accumulation_steps=1, behaviour is identical to default training."""
        from monai.engines import SupervisedTrainer

        torch.manual_seed(42)
        lr = 0.1
        batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)]

        ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)

        # Reference: standard training loop, one step per batch, unscaled loss.
        for batch in batches:
            ref_opt.zero_grad()
            loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean()
            loss.backward()
            ref_opt.step()

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=batches,
            network=test_model,
            optimizer=test_opt,
            loss_function=nn.MSELoss(),
            accumulation_steps=1,
        )
        trainer.run()

        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
            torch.testing.assert_close(p_test.data, p_ref.data)

    # ---- gradient equivalence ----

    def test_gradient_equivalence(self) -> None:
        """Accumulated gradients over N mini-batches equal one large-batch step."""
        from monai.engines import SupervisedTrainer

        torch.manual_seed(42)
        acc_steps, lr = 4, 0.1
        batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

        ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)

        # Reference: manual accumulation — sum of per-batch losses scaled by 1/acc_steps,
        # a single optimizer step at the end of the cycle.
        ref_opt.zero_grad()
        for batch in batches:
            loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
            loss.backward()
        ref_opt.step()

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=batches,
            network=test_model,
            optimizer=test_opt,
            loss_function=nn.MSELoss(),
            accumulation_steps=acc_steps,
        )
        trainer.run()

        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
            torch.testing.assert_close(p_test.data, p_ref.data)

    # ---- epoch boundary flush ----

    def test_epoch_boundary_flush(self) -> None:
        """When epoch_length is not divisible by acc_steps, flush at epoch end."""
        from monai.engines import SupervisedTrainer

        torch.manual_seed(123)
        acc_steps, lr = 3, 0.1
        batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(5)]

        ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)

        # Reference: first 3 batches form one full cycle; the remaining 2 form a
        # partial cycle that must be flushed (stepped) at epoch end, still using
        # the 1/acc_steps scaling.
        for cycle_batches in [batches[:3], batches[3:]]:
            ref_opt.zero_grad()
            for batch in cycle_batches:
                loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
                loss.backward()
            ref_opt.step()

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=batches,
            network=test_model,
            optimizer=test_opt,
            loss_function=nn.MSELoss(),
            accumulation_steps=acc_steps,
        )
        trainer.run()

        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
            torch.testing.assert_close(p_test.data, p_ref.data)

    # ---- multi-epoch ----

    def test_multi_epoch(self) -> None:
        """Verify gradient accumulation is correct across multiple epochs."""
        from monai.engines import SupervisedTrainer

        torch.manual_seed(42)
        acc_steps, lr, num_epochs = 2, 0.1, 3
        batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)]

        ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)

        # Reference: manual multi-epoch accumulation; with 4 batches and acc_steps=2,
        # each epoch contains exactly two full cycles.
        for _epoch in range(num_epochs):
            for cycle_batches in [batches[:2], batches[2:]]:
                ref_opt.zero_grad()
                for batch in cycle_batches:
                    loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
                    loss.backward()
                ref_opt.step()

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=num_epochs,
            train_data_loader=batches,
            network=test_model,
            optimizer=test_opt,
            loss_function=nn.MSELoss(),
            accumulation_steps=acc_steps,
        )
        trainer.run()

        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
            torch.testing.assert_close(p_test.data, p_ref.data)

    # ---- loss output is unscaled ----

    def test_loss_output_is_unscaled(self) -> None:
        """engine.state.output[LOSS] should be the unscaled loss, not loss/acc."""
        from monai.engines import SupervisedTrainer

        torch.manual_seed(42)
        acc_steps = 4
        batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

        model = nn.Linear(4, 1, bias=False)
        init_weight = model.weight.data.clone()  # snapshot to reconstruct the last forward pass
        opt = torch.optim.SGD(model.parameters(), lr=0.1)

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=batches,
            network=model,
            optimizer=opt,
            loss_function=nn.MSELoss(),
            accumulation_steps=acc_steps,
            decollate=False,
        )
        trainer.run()

        # With acc_steps == epoch_length, the single optimizer.step() happens only
        # AFTER the final loss is computed, so the last forward pass still used the
        # initial weights. Recompute that exact loss: a correct implementation stores
        # the unscaled value; a scaled one would store expected / acc_steps and fail.
        ref = nn.Linear(4, 1, bias=False)
        ref.weight.data.copy_(init_weight)
        with torch.no_grad():
            expected = nn.MSELoss()(ref(batches[-1][CommonKeys.IMAGE]), batches[-1][CommonKeys.LABEL]).mean().item()

        output_loss = trainer.state.output[CommonKeys.LOSS].item()
        self.assertGreater(output_loss, 0.0)
        self.assertAlmostEqual(output_loss, expected, places=5)

    # ---- accumulation_steps attribute ----

    def test_accumulation_steps_stored(self) -> None:
        """Verify the accumulation_steps attribute is accessible on the trainer."""
        from monai.engines import SupervisedTrainer

        model = nn.Linear(4, 1)
        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
            network=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
            loss_function=nn.MSELoss(),
            accumulation_steps=8,
        )
        self.assertEqual(trainer.accumulation_steps, 8)

    # ---- default is no accumulation ----

    def test_default_no_accumulation(self) -> None:
        """Default accumulation_steps=1 means no accumulation."""
        from monai.engines import SupervisedTrainer

        model = nn.Linear(4, 1)
        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
            network=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
            loss_function=nn.MSELoss(),
        )
        self.assertEqual(trainer.accumulation_steps, 1)


# Allow running this test module directly, e.g.
# `python tests/engines/test_gradient_accumulation.py`.
if __name__ == "__main__":
    unittest.main()
Loading