From 0aa1c8bfadcb4ad03ada491a942862c68a9d7af1 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 12:20:15 -0400 Subject: [PATCH 1/8] Add CRMOGMWeighting from NeurIPS 2022 --- CHANGELOG.md | 6 +- docs/source/docs/aggregation/cr_mogm.rst | 15 +++ docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_cr_mogm.py | 123 +++++++++++++++++++ tests/unit/aggregation/test_cr_mogm.py | 148 +++++++++++++++++++++++ 6 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 docs/source/docs/aggregation/cr_mogm.rst create mode 100644 src/torchjd/aggregation/_cr_mogm.py create mode 100644 tests/unit/aggregation/test_cr_mogm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fe52e1fe..3b20f2db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `CRMOGMWeighting` from + [Conflict-Reduction Multi-Objective Gradient Methods](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4e91f0648fb6e09f0156a7eaf6c4dfdb-Abstract-Conference.html). + It wraps an existing `Weighting` and stabilises its weights with an exponential moving average + across calls. - Added getters and setters for the constructor parameters of all aggregators and weightings, so that they can be changed after initialization. This includes: `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`; @@ -18,7 +22,7 @@ changelog does not include internal changes that do not affect the user. `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`; `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor - checks. Note that setters for `GradVac` and `GradVacWeighting` already existed. + checks. ## [0.10.0] - 2026-04-16 diff --git a/docs/source/docs/aggregation/cr_mogm.rst b/docs/source/docs/aggregation/cr_mogm.rst new file mode 100644 index 00000000..47e70f49 --- /dev/null +++ b/docs/source/docs/aggregation/cr_mogm.rst @@ -0,0 +1,15 @@ +:hide-toc: + +CR-MOGM +======= + +.. autoclass:: torchjd.aggregation.CRMOGMWeighting + :members: __call__, reset + +.. note:: + The usage example in the docstring above imports + ``WeightedAggregator`` / ``GramianWeightedAggregator`` from + ``torchjd.aggregation._aggregator_bases``, which is a private module. These two + aggregator base classes are not currently part of the public ``torchjd.aggregation`` + namespace, so this private-module import is the only path that works today. Promoting + them to the public namespace is a separate decision left to the maintainers. 
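+
+A minimal usage sketch (assuming the private-module import discussed in the note above;
+``UPGradWeighting`` stands in for any input-dependent inner weighting):
+
+.. code-block:: python
+
+    import torch
+
+    from torchjd.aggregation import UPGradWeighting
+    from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator
+    from torchjd.aggregation._cr_mogm import CRMOGMWeighting
+
+    # Wrap UPGrad's weights in the CR-MOGM exponential moving average.
+    weighting = CRMOGMWeighting(UPGradWeighting(), alpha=0.9)
+    aggregator = GramianWeightedAggregator(weighting)
+
+    jacobian = torch.randn(3, 8)
+    g1 = aggregator(jacobian)  # weights = 0.9 * uniform + 0.1 * UPGrad's weights
+    g2 = aggregator(jacobian)  # the EMA keeps drifting toward UPGrad's weights
+
+    weighting.reset()  # the next call starts again from uniform weights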
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 4d62f820..98725ef3 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -29,6 +29,7 @@ Abstract base classes cagrad.rst config.rst constant.rst + cr_mogm.rst dualproj.rst flattening.rst graddrop.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 400cfe27..d36a6d82 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -63,6 +63,7 @@ from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting from ._config import ConFIG from ._constant import Constant, ConstantWeighting +from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop @@ -89,6 +90,7 @@ "ConFIG", "Constant", "ConstantWeighting", + "CRMOGMWeighting", "DualProj", "DualProjWeighting", "Flattening", diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py new file mode 100644 index 00000000..0cb7e4d8 --- /dev/null +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import TypeVar, cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful + +from ._weighting_bases import Weighting + +_T = TypeVar("_T", contravariant=True, bound=Tensor) + + +class CRMOGMWeighting(Weighting[_T], Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another + :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it + produces with an exponential moving average (EMA) across calls. This is the weight-smoothing + modifier from `Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022) + `_. + + Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step + :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: + + .. math:: + + \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k + + with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top + \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first + forward call once :math:`m` is known and is reset automatically when ``m``, ``dtype`` or + ``device`` of the input changes. + + Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a + :class:`~torchjd.aggregation._weighting_bases.MatrixWeighting` or a + :class:`~torchjd.aggregation._weighting_bases.GramianWeighting`. The user composes it with + the appropriate aggregator base: + + .. code-block:: python + + from torchjd.aggregation import MeanWeighting, UPGradWeighting + from torchjd.aggregation._aggregator_bases import ( + GramianWeightedAggregator, WeightedAggregator, + ) + from torchjd.aggregation._cr_mogm import CRMOGMWeighting + + matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + + This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` + when restarting the smoothing from uniform weights. + + :param weighting: The wrapped weighting whose output is smoothed. + :param alpha: EMA coefficient on the previous weights. 
``alpha=0`` disables smoothing
+        (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
+        the weights at their initial uniform value. The default of ``0.1`` applies only light
+        smoothing; values closer to ``0.9`` (the usual EMA convention, analogous to Adam's
+        :math:`\beta_1`) smooth more aggressively.
+
+    .. note::
+        ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
+        schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
+        rate decays. Update ``alpha`` between forward calls via the public attribute on the
+        wrapping aggregator:
+
+        .. code-block:: python
+
+            # With WeightedAggregator
+            aggregator.weighting.alpha = 1 - current_lr / initial_lr
+
+            # With GramianWeightedAggregator
+            aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
+    """
+
+    def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
+        super().__init__()
+        self.weighting = weighting
+        self.alpha = alpha
+        self._lambda: Tensor | None = None
+        self._state_key: tuple[int, torch.dtype, torch.device] | None = None
+
+    @property
+    def alpha(self) -> float:
+        return self._alpha
+
+    @alpha.setter
+    def alpha(self, value: float) -> None:
+        if not (0.0 <= value <= 1.0):
+            raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
+        self._alpha = value
+
+    def reset(self) -> None:
+        """Clears the EMA state so the next forward starts from uniform weights."""
+
+        self._lambda = None
+        self._state_key = None
+
+    def forward(self, stat: _T, /) -> Tensor:
+        device = stat.device
+        dtype = stat.dtype
+        m = stat.shape[0]
+
+        self._ensure_state(m, dtype, device)
+        lambda_prev = cast(Tensor, self._lambda)
+
+        lambda_hat = self.weighting(stat)
+        lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
+
+        self._lambda = lambda_k.detach()
+        return lambda_k
+
+    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None:
+        key = (m, dtype, device)
+        if self._state_key != key or self._lambda is None:
+            if m > 0:
+                self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
+            else:
+                self._lambda = torch.zeros(0, dtype=dtype, device=device)
+            self._state_key = key
+
+    def __repr__(self) -> str:
+        return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})"
diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py
new file mode 100644
index 00000000..a7f520a3
--- /dev/null
+++ b/tests/unit/aggregation/test_cr_mogm.py
@@ -0,0 +1,148 @@
+from pytest import mark, raises
+from torch import Tensor
+from torch.testing import assert_close
+from utils.tensors import randn_, tensor_
+
+from torchjd.aggregation import MeanWeighting, UPGradWeighting
+from torchjd.aggregation._aggregator_bases import (
+    GramianWeightedAggregator,
+    WeightedAggregator,
+)
+from torchjd.aggregation._cr_mogm import CRMOGMWeighting
+
+from ._asserts import assert_expected_structure
+from ._inputs import scaled_matrices, typical_matrices
+
+# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
+# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
+matrix_pairs = [ + (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m) + for m in typical_matrices + scaled_matrices +] +gramian_pairs = [ + (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices +] + + +def test_representations() -> None: + W = CRMOGMWeighting(MeanWeighting(), alpha=0.9) + expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)" + # Weighting does not define __str__, so it falls back to __repr__. + assert repr(W) == expected + assert str(W) == expected + + +@mark.parametrize(["aggregator", "matrix"], matrix_pairs) +def test_expected_structure_matrix_weighting( + aggregator: WeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + """ + Use ``UPGradWeighting`` so the weights actually depend on the input — with + ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would + be trivial. + """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_alpha_setter_accepts_valid() -> None: + W = CRMOGMWeighting(MeanWeighting()) + W.alpha = 0.0 + assert W.alpha == 0.0 + W.alpha = 0.5 + assert W.alpha == 0.5 + W.alpha = 1.0 + assert W.alpha == 1.0 + + +def test_alpha_setter_rejects_out_of_range() -> None: + W = CRMOGMWeighting(MeanWeighting()) + with raises(ValueError, match="alpha"): + W.alpha = -0.1 + with raises(ValueError, match="alpha"): + W.alpha = 1.1 + + +def test_alpha_zero_reduces_to_bare_weighting() -> None: + """ + With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights + equal the bare weighting's output on every call — not just the first. + """ + + J = randn_((3, 8)) + G = J @ J.T + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0) + + expected = bare(G) + assert_close(smoothed(G), expected) + assert_close(smoothed(G), expected) + + +def test_alpha_one_freezes_weights() -> None: + """ + With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at + their initial uniform value forever. Note: the equality with uniform weights is a + consequence of the uniform initialisation, not a general property of CR-MOGM. + """ + + J = randn_((3, 8)) + m = J.shape[0] + W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0) + uniform = tensor_([1.0 / m] * m) + + assert_close(W(J @ J.T), uniform) + assert_close(W(J @ J.T), uniform) + + +def test_ema_is_applied() -> None: + """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand.""" + + alpha = 0.9 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha) + + lambda_hat_1 = bare(G1) + lambda_hat_2 = bare(G2) + uniform = tensor_([1.0 / m] * m) + + expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1 + expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2 + + assert_close(smoothed(G1), expected_1) + assert_close(smoothed(G2), expected_2) + + +def test_zero_columns() -> None: + """ + A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. 
Zero-row + inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would + raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility. + """ + + aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + out = aggregator(tensor_([]).reshape(2, 0)) + assert out.shape == (0,) From 53f3eb30ee2ab427816b3af8381e101a3fa02c36 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 12:54:21 -0400 Subject: [PATCH 2/8] fix(aggregation): Fix Sphinx cross-reference warnings in CRMOGMWeighting docstring --- src/torchjd/aggregation/_cr_mogm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 0cb7e4d8..0e34faa2 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -34,9 +34,8 @@ class CRMOGMWeighting(Weighting[_T], Stateful): ``device`` of the input changes. Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a - :class:`~torchjd.aggregation._weighting_bases.MatrixWeighting` or a - :class:`~torchjd.aggregation._weighting_bases.GramianWeighting`. The user composes it with - the appropriate aggregator base: + ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate + aggregator base: .. code-block:: python From 23c0f62fc79b531e37540bc77d9d41813d64f4ba Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 13:38:15 -0400 Subject: [PATCH 3/8] fix(aggregation): Remove broken NeurIPS URL from CRMOGMWeighting --- CHANGELOG.md | 3 +-- src/torchjd/aggregation/_cr_mogm.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b20f2db..05596289 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `CRMOGMWeighting` from - [Conflict-Reduction Multi-Objective Gradient Methods](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4e91f0648fb6e09f0156a7eaf6c4dfdb-Abstract-Conference.html). +- Added `CRMOGMWeighting` from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential moving average across calls. - Added getters and setters for the constructor parameters of all aggregators and weightings, so diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 0e34faa2..7c3512a8 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -18,8 +18,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful): :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing - modifier from `Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022) - `_. + modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step :math:`k`. 
The smoothed weights returned by ``CRMOGMWeighting`` are: From daf59f9aa484e33bf1e12b76e5fdab150f982577 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 13:49:20 -0400 Subject: [PATCH 4/8] test(aggregation): Cover zero-row branch in CRMOGMWeighting --- tests/unit/aggregation/test_cr_mogm.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index a7f520a3..2fcc1cea 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MeanWeighting, UPGradWeighting +from torchjd.aggregation import MeanWeighting, SumWeighting, UPGradWeighting from torchjd.aggregation._aggregator_bases import ( GramianWeightedAggregator, WeightedAggregator, @@ -146,3 +146,15 @@ def test_zero_columns() -> None: aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) out = aggregator(tensor_([]).reshape(2, 0)) assert out.shape == (0,) + + +def test_zero_rows() -> None: + """ + Exercises the ``m=0`` branch of ``_ensure_state``. ``SumWeighting`` is used because it + handles zero-row matrices cleanly (``torch.ones(0)``), unlike ``MeanWeighting`` which + would raise ``ZeroDivisionError``. + """ + + W = CRMOGMWeighting(SumWeighting()) + weights = W(tensor_([]).reshape(0, 8)) + assert weights.shape == (0,) From e846a9c2b7c5d23ea4289439fcc010ca13df4ae8 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 22:12:13 -0400 Subject: [PATCH 5/8] refactor(aggregation): Address review feedback on CRMOGMWeighting --- CHANGELOG.md | 10 +++++----- src/torchjd/aggregation/_cr_mogm.py | 27 +++++++++++++------------- tests/unit/aggregation/test_cr_mogm.py | 26 ++++++++++++++++--------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24aff524..5c62e64f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `CRMOGMWeighting` from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). - It wraps an existing `Weighting` and stabilises its weights with an exponential moving average - across calls. +- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient + Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf) + (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential + moving average across calls. - Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting` public. These abstract base classes are now importable from `torchjd.aggregation` and documented. They can be extended to easily implement custom `Weighting`s and `Aggregator`s. @@ -23,8 +24,7 @@ changelog does not include internal changes that do not affect the user. `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine` and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`; - `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor - checks. + `trim_number` in `TrimmedMean`. 
Setters validate their inputs matching the existing constructor checks. Note that setters for `GradVac` and `GradVacWeighting` already existed. ## [0.10.0] - 2026-04-16 diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 7c3512a8..6a0e7bc4 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -18,7 +18,9 @@ class CRMOGMWeighting(Weighting[_T], Stateful): :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing - modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). + modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and + Beyond `_ + (NeurIPS 2022). Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: @@ -76,7 +78,6 @@ def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None: self.weighting = weighting self.alpha = alpha self._lambda: Tensor | None = None - self._state_key: tuple[int, torch.dtype, torch.device] | None = None @property def alpha(self) -> float: @@ -91,31 +92,29 @@ def alpha(self, value: float) -> None: def reset(self) -> None: """Clears the EMA state so the next forward starts from uniform weights.""" + if isinstance(self.weighting, Stateful): + self.weighting.reset() self._lambda = None - self._state_key = None def forward(self, stat: _T, /) -> Tensor: - device = stat.device - dtype = stat.dtype - m = stat.shape[0] + lambda_hat = self.weighting(stat) - self._ensure_state(m, dtype, device) + self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device) lambda_prev = cast(Tensor, self._lambda) - lambda_hat = self.weighting(stat) lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat self._lambda = lambda_k.detach() return lambda_k def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None: - key = (m, dtype, device) - if self._state_key != key or self._lambda is None: + if ( + self._lambda is None + or self._lambda.shape[0] != m + or self._lambda.dtype != dtype + or self._lambda.device != device + ): if m > 0: self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) else: self._lambda = torch.zeros(0, dtype=dtype, device=device) - self._state_key = key - - def __repr__(self) -> str: - return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})" diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 2fcc1cea..45a6fb82 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MeanWeighting, SumWeighting, UPGradWeighting +from torchjd.aggregation import GradVacWeighting, MeanWeighting, SumWeighting, UPGradWeighting from torchjd.aggregation._aggregator_bases import ( GramianWeightedAggregator, WeightedAggregator, @@ -24,14 +24,6 @@ ] -def test_representations() -> None: - W = CRMOGMWeighting(MeanWeighting(), alpha=0.9) - expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)" - # Weighting does not define __str__, so it falls back to __repr__. 
- assert repr(W) == expected - assert str(W) == expected - - @mark.parametrize(["aggregator", "matrix"], matrix_pairs) def test_expected_structure_matrix_weighting( aggregator: WeightedAggregator, matrix: Tensor @@ -62,6 +54,22 @@ def test_reset_restores_first_step_behavior() -> None: assert_close(first, W(G)) +def test_reset_propagates_to_stateful_weighting() -> None: + """ + Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is + :class:`~torchjd.aggregation.Stateful`. Uses ``GradVacWeighting`` as the inner weighting + because it is both stateful and produces weights that depend on its internal state. + """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(GradVacWeighting(), alpha=0.5) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + def test_alpha_setter_accepts_valid() -> None: W = CRMOGMWeighting(MeanWeighting()) W.alpha = 0.0 From e16cf489456341e9b1b4effbe48976603102031d Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 8 May 2026 09:56:16 -0400 Subject: [PATCH 6/8] refactor(aggregation): Simplify CRMOGMWeighting state logic and improve docstring --- src/torchjd/aggregation/_cr_mogm.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 6a0e7bc4..7ee086e5 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar, cast +from typing import TypeVar import torch from torch import Tensor @@ -31,11 +31,10 @@ class CRMOGMWeighting(Weighting[_T], Stateful): with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first - forward call once :math:`m` is known and is reset automatically when ``m``, ``dtype`` or - ``device`` of the input changes. + forward call once :math:`m` is known and is reset automatically when ``m`` changes. Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a - ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate + ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate aggregator base: .. code-block:: python @@ -50,7 +49,8 @@ class CRMOGMWeighting(Weighting[_T], Stateful): gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` - when restarting the smoothing from uniform weights. + when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also + reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. :param weighting: The wrapped weighting whose output is smoothed. :param alpha: EMA coefficient on the previous weights. 
``alpha=0`` disables smoothing
@@ -99,22 +99,17 @@ def reset(self) -> None:
     def forward(self, stat: _T, /) -> Tensor:
         lambda_hat = self.weighting(stat)
 
-        self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
-        lambda_prev = cast(Tensor, self._lambda)
+        lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
 
         lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
 
         self._lambda = lambda_k.detach()
         return lambda_k
 
-    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None:
-        if (
-            self._lambda is None
-            or self._lambda.shape[0] != m
-            or self._lambda.dtype != dtype
-            or self._lambda.device != device
-        ):
+    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
+        if self._lambda is None or self._lambda.shape[0] != m:
             if m > 0:
                 self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
             else:
                 self._lambda = torch.zeros(0, dtype=dtype, device=device)
+        return self._lambda

From 1b74974331733792928c84dab0b81c628f5724ce Mon Sep 17 00:00:00 2001
From: Khush
Date: Fri, 8 May 2026 10:00:03 -0400
Subject: [PATCH 7/8] refactor(aggregation): Raise on shape change in CRMOGMWeighting._ensure_state

---
 src/torchjd/aggregation/_cr_mogm.py    | 15 +++++++++------
 tests/unit/aggregation/test_cr_mogm.py | 14 +-------------
 2 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py
index 7ee086e5..a1d8bd8b 100644
--- a/src/torchjd/aggregation/_cr_mogm.py
+++ b/src/torchjd/aggregation/_cr_mogm.py
@@ -32,5 +32,6 @@ class CRMOGMWeighting(Weighting[_T], Stateful):
     with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top
     \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first
-    forward call once :math:`m` is known and is reset automatically when ``m`` changes.
+    forward call once :math:`m` is known. Changing ``m`` afterwards raises a ``ValueError``;
+    call :meth:`reset` before changing the number of objectives.
 
     Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a
@@ -107,9 +108,11 @@ def forward(self, stat: _T, /) -> Tensor:
         return lambda_k
 
     def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
-        if self._lambda is None or self._lambda.shape[0] != m:
-            if m > 0:
-                self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
-            else:
-                self._lambda = torch.zeros(0, dtype=dtype, device=device)
+        if self._lambda is None:
+            self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
+        elif self._lambda.shape[0] != m:
+            raise ValueError(
+                f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
+                f"`reset()` before changing the number of objectives."
+            )
         return self._lambda
diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py
index 45a6fb82..9c402d1b 100644
--- a/tests/unit/aggregation/test_cr_mogm.py
+++ b/tests/unit/aggregation/test_cr_mogm.py
@@ -3,7 +3,7 @@
 from torch.testing import assert_close
 from utils.tensors import randn_, tensor_
 
-from torchjd.aggregation import GradVacWeighting, MeanWeighting, SumWeighting, UPGradWeighting
+from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
 from torchjd.aggregation._aggregator_bases import (
     GramianWeightedAggregator,
     WeightedAggregator,
@@ -154,15 +154,3 @@ def test_zero_columns() -> None:
 
     aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
     out = aggregator(tensor_([]).reshape(2, 0))
     assert out.shape == (0,)
-
-
-def test_zero_rows() -> None:
-    """
-    Exercises the ``m=0`` branch of ``_ensure_state``. ``SumWeighting`` is used because it
-    handles zero-row matrices cleanly (``torch.ones(0)``), unlike ``MeanWeighting`` which
-    would raise ``ZeroDivisionError``.
- """ - - W = CRMOGMWeighting(SumWeighting()) - weights = W(tensor_([]).reshape(0, 8)) - assert weights.shape == (0,) From 1cb9953dd4d6fd345ac02f6799d43a1917add11d Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 8 May 2026 11:23:36 -0400 Subject: [PATCH 8/8] test(aggregation): Fix reset propagation test and cover shape-change error --- tests/unit/aggregation/test_cr_mogm.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 9c402d1b..80fadcc0 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -57,17 +57,26 @@ def test_reset_restores_first_step_behavior() -> None: def test_reset_propagates_to_stateful_weighting() -> None: """ Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is - :class:`~torchjd.aggregation.Stateful`. Uses ``GradVacWeighting`` as the inner weighting - because it is both stateful and produces weights that depend on its internal state. + :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal + state is cleared after ``reset()``. """ + inner = GradVacWeighting() + W = CRMOGMWeighting(inner, alpha=0.5) J = randn_((3, 8)) - G = J @ J.T - W = CRMOGMWeighting(GradVacWeighting(), alpha=0.5) - first = W(G) - W(G) + W(J @ J.T) + assert inner._phi_t is not None W.reset() - assert_close(first, W(G)) + assert inner._phi_t is None + + +def test_changing_m_raises() -> None: + """Verify that changing the number of objectives after the first call raises a ValueError.""" + + W = CRMOGMWeighting(MeanWeighting()) + W(randn_((3, 8)) @ randn_((3, 8)).T) + with raises(ValueError, match="number of objectives"): + W(randn_((2, 8)) @ randn_((2, 8)).T) def test_alpha_setter_accepts_valid() -> None:
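
A quick numeric check of the recurrence these patches implement (a standalone sketch; the
lambda-hat values below are made up, only the arithmetic is exact):

    import torch

    # lambda_k = alpha * lambda_{k-1} + (1 - alpha) * lambda_hat_k, lambda_0 uniform (m = 2)
    alpha = 0.9
    lam = torch.tensor([0.5, 0.5])
    for lam_hat in (torch.tensor([0.8, 0.2]), torch.tensor([0.1, 0.9])):
        lam = alpha * lam + (1 - alpha) * lam_hat
    # Step 1: 0.9 * [0.5, 0.5] + 0.1 * [0.8, 0.2] = [0.53, 0.47]
    # Step 2: 0.9 * [0.53, 0.47] + 0.1 * [0.1, 0.9] = [0.487, 0.513]
    print(lam)  # tensor([0.4870, 0.5130])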