Add GradientAccumulation utility for SupervisedTrainer #8763
aymuos15 wants to merge 7 commits into Project-MONAI:dev from
Conversation
📝 Walkthrough

This pull request adds gradient accumulation functionality to MONAI's SupervisedTrainer. A new parameter
🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)
339-350: Drop the unused `init_weight` from the helper return.

It is not consumed by callers, so removing it tightens the helper contract and avoids dead unpacks downstream.

♻️ Proposed cleanup

```diff
 def _make_model_pair(lr):
 @@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
 @@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
```

(the call-site change applies at each of the three unpack sites)
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (3)
monai/engines/__init__.py, monai/engines/utils.py, tests/engines/test_gradient_accumulation.py
…ject-MONAI#6100) Closes Project-MONAI#6100 Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
597e086 to 1db8cc1
🧹 Nitpick comments (2)
tests/engines/test_gradient_accumulation.py (1)
105-105: Consider marking intentionally-unused bindings with `_` prefixes.

This keeps tests clear while avoiding lint noise.

🧹 Optional cleanup

```diff
-    def fake_iteration(eng, batch):
+    def fake_iteration(eng, _batch):
 @@
-    def check_scaler(eng, batch):
+    def check_scaler(eng, _batch):
 @@
-    def fake_iteration(*args, **kwargs):
+    def fake_iteration(*_args, **_kwargs):
 @@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
```

Also applies to: 188-188, 234-234, 257-257, 287-287, 318-318
monai/engines/utils.py (1)
366-368: Align new definitions with Google-style docstring sections.
`_noop`, `__init__`, and `__repr__` should include explicit `Args`/`Returns` (and `Raises` where applicable) sections to match repo docstring policy.

♻️ Suggested docstring adjustments

```diff
 def _noop(*args: Any, **kwargs: Any) -> None:
-    """No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""
+    """No-op callable used to suppress optimizer/scaler methods.
+
+    Args:
+        *args: Ignored positional arguments.
+        **kwargs: Ignored keyword arguments.
+
+    Returns:
+        None.
+    """

 class GradientAccumulation:
     def __init__(self, accumulation_steps: int = 2) -> None:
+        """Initialize gradient accumulation behavior.
+
+        Args:
+            accumulation_steps: Number of mini-batches to accumulate before stepping.
+
+        Raises:
+            ValueError: If `accumulation_steps` is not a positive integer.
+        """
         if not isinstance(accumulation_steps, int) or accumulation_steps < 1:

     def __repr__(self) -> str:
+        """Return a debug-friendly representation.
+
+        Returns:
+            String representation with configured accumulation steps.
+        """
         return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"
```

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Also applies to: 405-413
📒 Files selected for processing (3)
monai/engines/__init__.py, monai/engines/utils.py, tests/engines/test_gradient_accumulation.py
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Actionable comments posted: 3
🧹 Nitpick comments (1)
monai/engines/utils.py (1)
413-413: Widen the `batchdata` type hint in `__call__`.

`batchdata: dict[str, Any]` is tighter than common trainer inputs. Consider `Any` to avoid misleading static typing for tuple/list batch payloads.

Proposed fix

```diff
-    def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
+    def __call__(self, engine: Any, batchdata: Any) -> dict:
```
📒 Files selected for processing (2)
monai/engines/utils.py, tests/engines/test_gradient_accumulation.py
monai/engines/utils.py
Outdated
```python
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
    raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```
Reject boolean values for accumulation_steps.
True currently passes validation because bool is an int subclass, so invalid config can silently map to 1.
Proposed fix

```diff
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```

🧰 Tools
🪛 Ruff (0.15.2)
[warning] 407-407: Avoid specifying long messages outside the exception class (TRY003)
We shouldn't check the type of accumulation_steps at all since the type annotation is present, only the correct value range should be enforced.
@ericspod, understood — I'll withdraw that suggestion. The type annotation is the right place to express the type contract, and the runtime guard should only enforce the value range (accumulation_steps < 1).
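As a hedged sketch of where that resolution lands (the function name here is illustrative, not the merged MONAI code), the runtime guard reduces to a pure range check while the type contract stays in the annotation:

```python
def validate_accumulation_steps(accumulation_steps: int) -> int:
    """Enforce only the value range; the type is expressed by the annotation."""
    if accumulation_steps < 1:
        raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
    return accumulation_steps
```

Callers passing a wrong type are caught by static checkers rather than a runtime isinstance test.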
```python
acc_steps, lr = 4, 0.1
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
```
Remove unused init_weight plumbing from model-pair helper.
init_weight is returned/unpacked but never used by tests.
Proposed fix

```diff
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
```

and in the helper itself:

```diff
 def _make_model_pair(lr):
     """Create a reference and test model pair with identical initial weights."""
     ref_model = nn.Linear(4, 1, bias=False)
     init_weight = ref_model.weight.data.clone()
 @@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
```

Also applies to: 271-271, 303-303, 328-339
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 240-240: Unpacked variable `init_weight` is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/engines/utils.py (1)
406-407: ⚠️ Potential issue | 🟡 Minor

Reject `bool` for `accumulation_steps`.

`True` currently passes because `bool` is an `int` subclass, so invalid config can silently map to `1`.

Proposed fix

```diff
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```
📒 Files selected for processing (2)
monai/engines/utils.py, tests/engines/test_gradient_accumulation.py
```python
INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
```
Add explicit bool invalid-input coverage.
This suite misses True, which is the key edge case for the bool-as-int validation bug.
Proposed fix

```diff
-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,), (False,)]
```

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
Also applies to: 58-63
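A quick illustration of why `True` is the key edge case (plain Python semantics; the `*_check` helpers are hypothetical stand-ins for the validator):

```python
def naive_check(steps):
    # bool is a subclass of int, so True (== 1) slips through this check
    return isinstance(steps, int) and steps >= 1

def strict_check(steps):
    # explicitly reject bools, as the proposed fix does
    return isinstance(steps, int) and not isinstance(steps, bool) and steps >= 1

print(naive_check(True), strict_check(True))  # True False
```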
Hi @aymuos15, thanks for looking into this, which had fallen by the wayside a bit. I've looked over your solution and it's similar to what was proposed before. I feel that it's not quite the right way of going about it, as it relies on the members of a lot of objects and so leans heavily on assumptions about the structure of those objects. The way I would go about this feature is to modify
Replace external GradientAccumulation callable class with a native `accumulation_steps` constructor parameter on SupervisedTrainer, per reviewer feedback. This eliminates monkey-patching of optimizer/loss/scaler internals and instead uses simple conditionals in `_iteration()`. Based on feedback from @ericspod Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
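The conditional-stepping approach this commit describes can be sketched in plain PyTorch (an illustrative `run_epoch` helper, not the actual `SupervisedTrainer._iteration` code):

```python
import torch
from torch import nn

def run_epoch(model, optimizer, loss_fn, batches, accumulation_steps=2):
    """Accumulate gradients over mini-batches, stepping only on boundaries."""
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches, start=1):
        # divide so the accumulated gradient matches the large-batch mean
        loss = loss_fn(model(x), y) / accumulation_steps
        loss.backward()
        # step on each accumulation boundary, and flush at epoch end
        if i % accumulation_steps == 0 or i == len(batches):
            optimizer.step()
            optimizer.zero_grad()
```

With equal-sized mini-batches this produces the same parameter update as one large batch; a trailing partial window (the `epoch_length % accumulation_steps != 0` case the tests cover) is flushed but slightly underscaled in this naive sketch.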
@ericspod Apologies, I think I misunderstood the previous recommendations there. I have now reverted to what you said. Thank you very much. Please let me know if this is aligned now.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/engines/test_gradient_accumulation.py (1)
25-25: ⚠️ Potential issue | 🟡 Minor

Add `(True,)` to invalid input coverage.

`isinstance(True, int)` is `True` in Python. If the implementation doesn't guard against bools, `True` would pass validation but is semantically invalid.

```diff
-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,)]
```
🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)
94-95: Consider `strict=True` for zip calls.

This catches length mismatches between model parameters. Applies to lines 94, 127, 161, 196.

```diff
-    for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
+    for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters(), strict=True):
```
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fa94b287-b3d6-416f-9eed-3882a19399ca
📒 Files selected for processing (2)
monai/engines/trainer.py, tests/engines/test_gradient_accumulation.py
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Summary
- Adds a `GradientAccumulation` callable class in `monai.engines.utils` for use as `iteration_update` in `SupervisedTrainer`, enabling gradient accumulation over multiple mini-batches to simulate larger effective batch sizes on memory-constrained hardware
- Follows the `iteration_update` pattern established by `Interaction` in `monai.apps.deepedit` (as referenced by @wyli in "Add gradient accumulation logic to SupervisedTrainer" #6101)
- `IterationEvents` fire every mini-batch, so existing handlers are unaffected
- Handles `epoch_length % accumulation_steps != 0`
- AMP (`GradScaler`) support included

Closes #6100
Supersedes #6101
Usage
Types of changes
Test plan
- `accumulation_steps=1`
- `zero_grad`/`optimizer.step` suppression patterns verified across full epochs
- `epoch_length` not divisible by `accumulation_steps`
- `epoch_length=None`: no epoch flush
- `try/finally`
- `GradScaler` patching when step suppressed, not patched when stepping
- `scaler=None` edge cases
- `_iteration`
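The `GradScaler` items above can be sketched in plain PyTorch (a hedged approximation of the scaler interplay, not the PR's code; `enabled` is toggled so the sketch also runs on CPU):

```python
import torch
from torch import nn

def accumulate_with_scaler(model, optimizer, loss_fn, batches, accumulation_steps=2):
    """Gradient accumulation with AMP-style loss scaling (illustrative helper)."""
    # pass-through scaler on CPU-only machines keeps this runnable anywhere
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches, start=1):
        loss = loss_fn(model(x), y) / accumulation_steps
        scaler.scale(loss).backward()  # scaled backward runs every mini-batch
        if i % accumulation_steps == 0 or i == len(batches):
            scaler.step(optimizer)  # step/update only on accumulation boundaries
            scaler.update()
            optimizer.zero_grad()
```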