diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a8aedb5..f74d8aaa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,10 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
+- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient
+  Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf)
+  (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential
+  moving average across calls.
 - Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes
   are now importable from `torchjd.aggregation` and documented. They can be extended to easily
   implement custom `Aggregator`s.
@@ -30,8 +34,8 @@ changelog does not include internal changes that do not affect the user.
   `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine`
   and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and
   `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
-  `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor
-  checks. Note that setters for `GradVac` and `GradVacWeighting` already existed.
+  `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor checks.
+  Note that setters for `GradVac` and `GradVacWeighting` already existed.
 
 ## [0.10.0] - 2026-04-16
 
diff --git a/docs/source/docs/aggregation/cr_mogm.rst b/docs/source/docs/aggregation/cr_mogm.rst
new file mode 100644
index 00000000..47e70f49
--- /dev/null
+++ b/docs/source/docs/aggregation/cr_mogm.rst
@@ -0,0 +1,12 @@
+:hide-toc:
+
+CR-MOGM
+=======
+
+.. autoclass:: torchjd.aggregation.CRMOGMWeighting
+    :members: __call__, reset
+
+.. note::
+    The usage example in the docstring above composes ``CRMOGMWeighting`` with
+    ``WeightedAggregator`` and ``GramianWeightedAggregator``, the public aggregator base classes
+    from ``torchjd.aggregation`` that turn any ``Weighting`` into an ``Aggregator``.
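For quick reference, a minimal sketch of how the new weighting is meant to be used once this patch is applied. The 3x5 Jacobian is a made-up example, and the `gramian_weighting` attribute name follows the docstring note in `_cr_mogm.py` below; this is an illustration, not part of the patch.

    import torch

    from torchjd.aggregation import (
        CRMOGMWeighting,
        GramianWeightedAggregator,
        UPGradWeighting,
    )

    # Wrap UPGrad's weights in the CR-MOGM exponential moving average (alpha defaults to 0.9).
    aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

    jacobian = torch.randn(3, 5)  # 3 objectives, 5 shared parameters (illustrative)
    for _ in range(10):
        aggregated = aggregator(jacobian)  # shape (5,); the weights are smoothed across calls

    aggregator.gramian_weighting.reset()  # restart the smoothing from uniform weights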
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst
index ff6e1811..3c0516b4 100644
--- a/docs/source/docs/aggregation/index.rst
+++ b/docs/source/docs/aggregation/index.rst
@@ -35,6 +35,7 @@ Abstract base classes
     cagrad.rst
     config.rst
     constant.rst
+    cr_mogm.rst
     dualproj.rst
     flattening.rst
     graddrop.rst
diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py
index bb8892ff..1a76bde4 100644
--- a/src/torchjd/aggregation/__init__.py
+++ b/src/torchjd/aggregation/__init__.py
@@ -64,6 +64,7 @@
 from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
 from ._config import ConFIG
 from ._constant import Constant, ConstantWeighting
+from ._cr_mogm import CRMOGMWeighting
 from ._dualproj import DualProj, DualProjWeighting
 from ._flattening import Flattening
 from ._graddrop import GradDrop
@@ -90,6 +91,7 @@
     "ConFIG",
     "Constant",
     "ConstantWeighting",
+    "CRMOGMWeighting",
     "DualProj",
     "DualProjWeighting",
     "Flattening",
diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py
new file mode 100644
index 00000000..a1d8bd8b
--- /dev/null
+++ b/src/torchjd/aggregation/_cr_mogm.py
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+from typing import TypeVar
+
+import torch
+from torch import Tensor
+
+from torchjd.aggregation._mixins import Stateful
+
+from ._weighting_bases import Weighting
+
+_T = TypeVar("_T", contravariant=True, bound=Tensor)
+
+
+class CRMOGMWeighting(Weighting[_T], Stateful):
+    r"""
+    :class:`~torchjd.aggregation._mixins.Stateful`
+    :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
+    :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
+    produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
+    modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and
+    Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_
+    (NeurIPS 2022).
+
+    Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step
+    :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are:
+
+    .. math::
+
+        \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k
+
+    with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top
+    \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first
+    forward call, once :math:`m` is known. If :math:`m` changes between calls, a ``ValueError``
+    is raised; call :meth:`reset` before switching to a different number of objectives.
+
+    Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a
+    ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding
+    :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate
+    aggregator base:
+
+    .. code-block:: python
+
+        from torchjd.aggregation import (
+            CRMOGMWeighting,
+            GramianWeightedAggregator,
+            MeanWeighting,
+            UPGradWeighting,
+            WeightedAggregator,
+        )
+
+        matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
+        gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))
+
+    This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
+    to restart the smoothing from uniform weights. Note that calling :meth:`reset` will also
+    reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.
+
+    :param weighting: The wrapped weighting whose output is smoothed.
+    :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
+        (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
+        the weights at their initial uniform value. The default of ``0.9`` follows the usual
+        EMA convention (analogous to Adam's :math:`\beta_1`).
+
+    .. note::
+        ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
+        schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
+        rate decays. Update ``alpha`` between forward calls via the public attribute on the
+        wrapping aggregator:
+
+        .. code-block:: python
+
+            # With WeightedAggregator
+            aggregator.weighting.alpha = 1 - current_lr / initial_lr
+
+            # With GramianWeightedAggregator
+            aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
+    """
+
+    def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None:
+        super().__init__()
+        self.weighting = weighting
+        self.alpha = alpha
+        self._lambda: Tensor | None = None
+
+    @property
+    def alpha(self) -> float:
+        return self._alpha
+
+    @alpha.setter
+    def alpha(self, value: float) -> None:
+        if not (0.0 <= value <= 1.0):
+            raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
+        self._alpha = value
+
+    def reset(self) -> None:
+        """Clears the EMA state so the next forward starts from uniform weights."""
+
+        if isinstance(self.weighting, Stateful):
+            self.weighting.reset()
+        self._lambda = None
+
+    def forward(self, stat: _T, /) -> Tensor:
+        lambda_hat = self.weighting(stat)
+
+        lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
+
+        lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
+
+        self._lambda = lambda_k.detach()
+        return lambda_k
+
+    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
+        if self._lambda is None:
+            self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
+        elif self._lambda.shape[0] != m:
+            raise ValueError(
+                f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
+                f"`reset()` before changing the number of objectives."
+            )
+        return self._lambda
diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py
new file mode 100644
index 00000000..80fadcc0
--- /dev/null
+++ b/tests/unit/aggregation/test_cr_mogm.py
@@ -0,0 +1,165 @@
+from pytest import mark, raises
+from torch import Tensor
+from torch.testing import assert_close
+from utils.tensors import randn_, tensor_
+
+from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
+from torchjd.aggregation._aggregator_bases import (
+    GramianWeightedAggregator,
+    WeightedAggregator,
+)
+from torchjd.aggregation._cr_mogm import CRMOGMWeighting
+
+from ._asserts import assert_expected_structure
+from ._inputs import scaled_matrices, typical_matrices
+
+# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
+# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
+matrix_pairs = [
+    (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m)
+    for m in typical_matrices + scaled_matrices
+]
+gramian_pairs = [
+    (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices
+]
+
+
+@mark.parametrize(["aggregator", "matrix"], matrix_pairs)
+def test_expected_structure_matrix_weighting(
+    aggregator: WeightedAggregator, matrix: Tensor
+) -> None:
+    assert_expected_structure(aggregator, matrix)
+
+
+@mark.parametrize(["aggregator", "matrix"], gramian_pairs)
+def test_expected_structure_gramian_weighting(
+    aggregator: GramianWeightedAggregator, matrix: Tensor
+) -> None:
+    assert_expected_structure(aggregator, matrix)
+
+
+def test_reset_restores_first_step_behavior() -> None:
+    """
+    Use ``UPGradWeighting`` so the weights actually depend on the input: with
+    ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would
+    be trivial.
+    """
+
+    J = randn_((3, 8))
+    G = J @ J.T
+    W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5)
+    first = W(G)
+    W(G)
+    W.reset()
+    assert_close(first, W(G))
+
+
+def test_reset_propagates_to_stateful_weighting() -> None:
+    """
+    Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is
+    :class:`~torchjd.aggregation.Stateful`, by checking that ``GradVacWeighting``'s internal
+    state is cleared after ``reset()``.
+    """
+
+    inner = GradVacWeighting()
+    W = CRMOGMWeighting(inner, alpha=0.5)
+    J = randn_((3, 8))
+    W(J @ J.T)
+    assert inner._phi_t is not None
+    W.reset()
+    assert inner._phi_t is None
+
+
+def test_changing_m_raises() -> None:
+    """Verify that changing the number of objectives after the first call raises a ValueError."""
+
+    W = CRMOGMWeighting(MeanWeighting())
+    W(randn_((3, 8)) @ randn_((3, 8)).T)
+    with raises(ValueError, match="number of objectives"):
+        W(randn_((2, 8)) @ randn_((2, 8)).T)
+
+
+def test_alpha_setter_accepts_valid() -> None:
+    W = CRMOGMWeighting(MeanWeighting())
+    W.alpha = 0.0
+    assert W.alpha == 0.0
+    W.alpha = 0.5
+    assert W.alpha == 0.5
+    W.alpha = 1.0
+    assert W.alpha == 1.0
+
+
+def test_alpha_setter_rejects_out_of_range() -> None:
+    W = CRMOGMWeighting(MeanWeighting())
+    with raises(ValueError, match="alpha"):
+        W.alpha = -0.1
+    with raises(ValueError, match="alpha"):
+        W.alpha = 1.1
+
+
+def test_alpha_zero_reduces_to_bare_weighting() -> None:
+    """
+    With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights
+    equal the bare weighting's output on every call, not just the first.
+    """
+
+    J = randn_((3, 8))
+    G = J @ J.T
+    bare = UPGradWeighting()
+    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0)
+
+    expected = bare(G)
+    assert_close(smoothed(G), expected)
+    assert_close(smoothed(G), expected)
+
+
+def test_alpha_one_freezes_weights() -> None:
+    """
+    With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at
+    their initial uniform value forever. Note: the equality with uniform weights is a
+    consequence of the uniform initialisation, not a general property of CR-MOGM.
+    """
+
+    J = randn_((3, 8))
+    m = J.shape[0]
+    W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0)
+    uniform = tensor_([1.0 / m] * m)
+
+    assert_close(W(J @ J.T), uniform)
+    assert_close(W(J @ J.T), uniform)
+
+
+def test_ema_is_applied() -> None:
+    """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand."""
+
+    alpha = 0.9
+    J1 = randn_((3, 8))
+    J2 = randn_((3, 8))
+    G1 = J1 @ J1.T
+    G2 = J2 @ J2.T
+    m = J1.shape[0]
+
+    bare = UPGradWeighting()
+    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha)
+
+    lambda_hat_1 = bare(G1)
+    lambda_hat_2 = bare(G2)
+    uniform = tensor_([1.0 / m] * m)
+
+    expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1
+    expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2
+
+    assert_close(smoothed(G1), expected_1)
+    assert_close(smoothed(G2), expected_2)
+
+
+def test_zero_columns() -> None:
+    """
+    A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row
+    inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would
+    raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility.
+    """
+
+    aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
+    out = aggregator(tensor_([]).reshape(2, 0))
+    assert out.shape == (0,)
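The docstring note in `_cr_mogm.py` mentions the alpha schedule from Corollary 1 of the paper (alpha starting near 0 and growing toward 1 as the learning rate decays) but leaves the wiring to the user. A sketch of that schedule follows; the optimizer and scheduler are illustrative and not part of the patch.

    import torch

    from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting

    model = torch.nn.Linear(5, 3)  # illustrative model with 3 objectives
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    initial_lr = optimizer.param_groups[0]["lr"]

    aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

    for step in range(100):
        # ... compute the per-objective Jacobian, aggregate it with `aggregator`, and step ...
        scheduler.step()
        current_lr = optimizer.param_groups[0]["lr"]
        # alpha goes from 0 toward 1 as the learning rate decays, as suggested in the docstring.
        aggregator.gramian_weighting.alpha = 1.0 - current_lr / initial_lr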