diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a8aedb5..6103d2f8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,8 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
+- Added `MoCo` and `MoCoWeighting` from
+  [Mitigating Gradient Bias in Multi-objective Learning: A Provably Convergent Approach (ICLR 2023)](https://openreview.net/forum?id=dLAYGdKTi2).
 - Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes
   are now importable from `torchjd.aggregation` and documented. They can be extended to easily
   implement custom `Aggregator`s.
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst
index ff6e1811..1b5acba9 100644
--- a/docs/source/docs/aggregation/index.rst
+++ b/docs/source/docs/aggregation/index.rst
@@ -43,6 +43,7 @@ Abstract base classes
     krum.rst
     mean.rst
     mgda.rst
+    moco.rst
     nash_mtl.rst
     pcgrad.rst
     random.rst
diff --git a/docs/source/docs/aggregation/moco.rst b/docs/source/docs/aggregation/moco.rst
new file mode 100644
index 00000000..5a7f9565
--- /dev/null
+++ b/docs/source/docs/aggregation/moco.rst
@@ -0,0 +1,10 @@
+:hide-toc:
+
+MoCo
+====
+
+.. autoclass:: torchjd.aggregation.MoCo
+    :members: __call__, reset
+
+.. autoclass:: torchjd.aggregation.MoCoWeighting
+    :members: __call__, reset
diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py
index bb8892ff..f30e783f 100644
--- a/src/torchjd/aggregation/__init__.py
+++ b/src/torchjd/aggregation/__init__.py
@@ -73,6 +73,7 @@
 from ._mean import Mean, MeanWeighting
 from ._mgda import MGDA, MGDAWeighting
 from ._mixins import Stateful
+from ._moco import MoCo, MoCoWeighting
 from ._pcgrad import PCGrad, PCGradWeighting
 from ._random import Random, RandomWeighting
 from ._sum import Sum, SumWeighting
@@ -106,6 +107,8 @@
     "MeanWeighting",
     "MGDA",
     "MGDAWeighting",
+    "MoCo",
+    "MoCoWeighting",
     "PCGrad",
     "PCGradWeighting",
     "Random",
diff --git a/src/torchjd/aggregation/_moco.py b/src/torchjd/aggregation/_moco.py
new file mode 100644
index 00000000..c564f16c
--- /dev/null
+++ b/src/torchjd/aggregation/_moco.py
@@ -0,0 +1,258 @@
+from typing import cast
+
+import torch
+from torch import Tensor
+
+from torchjd.aggregation._mixins import Stateful
+from torchjd.linalg import Matrix
+
+from ._aggregator_bases import Aggregator
+from ._utils.non_differentiable import raise_non_differentiable_error
+from ._weighting_bases import _MatrixWeighting
+
+
+class MoCoWeighting(_MatrixWeighting, Stateful):
+    r"""
+    :class:`~torchjd.aggregation._mixins.Stateful`
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] giving the weights of
+    :class:`~torchjd.aggregation.MoCo`.
+
+    This weighting is stateful: it keeps the moving gradient estimate :math:`Y` and the task
+    weights :math:`\lambda` across calls. Use :meth:`reset` between independent runs.
+
+    .. warning::
+        MoCo aggregates the moving estimate :math:`Y`, not the current matrix. Therefore, using
+        these weights directly on the current matrix does not generally reproduce
+        :class:`~torchjd.aggregation.MoCo`.
+
+    :param beta: Learning rate of the moving gradient estimate.
+    :param beta_sigma: Decay exponent of ``beta``.
+    :param gamma: Learning rate of the task weights.
+    :param gamma_sigma: Decay exponent of ``gamma``.
+    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
+    """
+
+    def __init__(
+        self,
+        beta: float = 0.5,
+        beta_sigma: float = 0.5,
+        gamma: float = 0.1,
+        gamma_sigma: float = 0.5,
+        rho: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.beta = beta
+        self.beta_sigma = beta_sigma
+        self.gamma = gamma
+        self.gamma_sigma = gamma_sigma
+        self.rho = rho
+        self.reset()
+
+    @property
+    def beta(self) -> float:
+        return self._beta
+
+    @beta.setter
+    def beta(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(f"Attribute `beta` must be non-negative. Found beta={value!r}.")
+        self._beta = value
+
+    @property
+    def beta_sigma(self) -> float:
+        return self._beta_sigma
+
+    @beta_sigma.setter
+    def beta_sigma(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(
+                f"Attribute `beta_sigma` must be non-negative. Found beta_sigma={value!r}."
+            )
+        self._beta_sigma = value
+
+    @property
+    def gamma(self) -> float:
+        return self._gamma
+
+    @gamma.setter
+    def gamma(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(f"Attribute `gamma` must be non-negative. Found gamma={value!r}.")
+        self._gamma = value
+
+    @property
+    def gamma_sigma(self) -> float:
+        return self._gamma_sigma
+
+    @gamma_sigma.setter
+    def gamma_sigma(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(
+                f"Attribute `gamma_sigma` must be non-negative. Found gamma_sigma={value!r}."
+            )
+        self._gamma_sigma = value
+
+    @property
+    def rho(self) -> float:
+        return self._rho
+
+    @rho.setter
+    def rho(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.")
+        self._rho = value
+
+    def reset(self) -> None:
+        """Clears the moving gradient estimate and resets the task weights."""
+
+        self.step = 0
+        self._y: Tensor | None = None
+        self._lambd: Tensor | None = None
+        self._state_key: tuple[int, int, torch.device, torch.dtype] | None = None
+
+    def forward(self, matrix: Matrix, /) -> Tensor:
+        if matrix.shape[0] == 0:
+            self.reset()
+            self._y = matrix.detach().clone()
+            self._state_key = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
+            return matrix.new_empty((0,))
+
+        self._ensure_state(matrix)
+        self.step += 1
+
+        y = cast(Tensor, self._y)
+        lambd = cast(Tensor, self._lambd)
+
+        beta_step = self.beta / (self.step**self.beta_sigma)
+        gamma_step = self.gamma / (self.step**self.gamma_sigma)
+
+        with torch.no_grad():
+            y = y - beta_step * (y - matrix.detach())
+            yy_t = y @ y.T
+            if self.rho != 0.0:
+                eye = torch.eye(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
+                yy_t = yy_t + self.rho * eye
+            lambd = torch.softmax(lambd - gamma_step * (yy_t @ lambd), dim=-1)
+
+        self._y = y
+        self._lambd = lambd
+
+        return lambd
+
+    @property
+    def y(self) -> Tensor:
+        if self._y is None:
+            raise RuntimeError("The moving gradient estimate is not initialized yet.")
+        return self._y
+
+    def _ensure_state(self, matrix: Matrix) -> None:
+        key = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
+        if self._state_key == key and self._y is not None and self._lambd is not None:
+            return
+
+        self._y = torch.zeros_like(matrix)
+        self._lambd = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0])
+        self._state_key = key
+
+
+class MoCo(Aggregator, Stateful):
+    r"""
+    :class:`~torchjd.aggregation._mixins.Stateful`
+    :class:`~torchjd.aggregation.Aggregator` implementing MoCo from `Mitigating Gradient Bias in
+    Multi-objective Learning: A Provably Convergent Approach (ICLR 2023)
+    <https://openreview.net/forum?id=dLAYGdKTi2>`_.
+
+    This aggregator is stateful: it keeps the moving gradient estimate :math:`Y` and the task
+    weights :math:`\lambda` across calls. Use :meth:`reset` between independent runs.
+
+    .. warning::
+        The output depends on previously seen matrices. Call :meth:`reset` between independent
+        experiments.
+
+    :param beta: Learning rate of the moving gradient estimate.
+    :param beta_sigma: Decay exponent of ``beta``.
+    :param gamma: Learning rate of the task weights.
+    :param gamma_sigma: Decay exponent of ``gamma``.
+    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
+    """
+
+    weighting: MoCoWeighting
+
+    def __init__(
+        self,
+        beta: float = 0.5,
+        beta_sigma: float = 0.5,
+        gamma: float = 0.1,
+        gamma_sigma: float = 0.5,
+        rho: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.weighting = MoCoWeighting(
+            beta=beta,
+            beta_sigma=beta_sigma,
+            gamma=gamma,
+            gamma_sigma=gamma_sigma,
+            rho=rho,
+        )
+        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+
+    def forward(self, matrix: Matrix, /) -> Tensor:
+        weights = self.weighting(matrix)
+        if matrix.shape[0] == 0:
+            return matrix.sum(dim=0)
+
+        vector = weights @ self.weighting.y
+        if matrix.requires_grad:
+            vector = vector + 0.0 * matrix.sum(dim=0)
+        return vector
+
+    @property
+    def beta(self) -> float:
+        return self.weighting.beta
+
+    @beta.setter
+    def beta(self, value: float) -> None:
+        self.weighting.beta = value
+
+    @property
+    def beta_sigma(self) -> float:
+        return self.weighting.beta_sigma
+
+    @beta_sigma.setter
+    def beta_sigma(self, value: float) -> None:
+        self.weighting.beta_sigma = value
+
+    @property
+    def gamma(self) -> float:
+        return self.weighting.gamma
+
+    @gamma.setter
+    def gamma(self, value: float) -> None:
+        self.weighting.gamma = value
+
+    @property
+    def gamma_sigma(self) -> float:
+        return self.weighting.gamma_sigma
+
+    @gamma_sigma.setter
+    def gamma_sigma(self, value: float) -> None:
+        self.weighting.gamma_sigma = value
+
+    @property
+    def rho(self) -> float:
+        return self.weighting.rho
+
+    @rho.setter
+    def rho(self, value: float) -> None:
+        self.weighting.rho = value
+
+    def reset(self) -> None:
+        """Clears the moving gradient estimate and resets the task weights."""
+
+        self.weighting.reset()
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(beta={self.beta!r}, beta_sigma={self.beta_sigma!r}, "
+            f"gamma={self.gamma!r}, gamma_sigma={self.gamma_sigma!r}, rho={self.rho!r})"
+        )
diff --git a/tests/unit/aggregation/test_moco.py b/tests/unit/aggregation/test_moco.py
new file mode 100644
index 00000000..40bcf9d5
--- /dev/null
+++ b/tests/unit/aggregation/test_moco.py
@@ -0,0 +1,177 @@
+from pytest import mark, raises
+from torch import Tensor
+from torch.testing import assert_close
+from utils.tensors import ones_, randn_, tensor_
+
+from torchjd.aggregation import MoCo, MoCoWeighting
+
+from ._asserts import assert_expected_structure, assert_non_differentiable
+from ._inputs import scaled_matrices, typical_matrices
+
+scaled_pairs = [(MoCo(), matrix) for matrix in scaled_matrices]
+typical_pairs = [(MoCo(), matrix) for matrix in typical_matrices]
+requires_grad_pairs = [(MoCo(), ones_(3, 5, requires_grad=True))]
+PARAMETER_VALUES = [
+    ("beta", 0.25),
+    ("beta_sigma", 0.75),
+    ("gamma", 0.2),
+    ("gamma_sigma", 0.6),
+    ("rho", 0.1),
+]
+
+
+def test_representations() -> None:
+    A = MoCo(beta=0.25, beta_sigma=0.75, gamma=0.2, gamma_sigma=0.6, rho=0.1)
+    assert repr(A) == "MoCo(beta=0.25, beta_sigma=0.75, gamma=0.2, gamma_sigma=0.6, rho=0.1)"
rho=0.1)" + assert str(A) == "MoCo" + + +def test_zero_rows_returns_zero_vector() -> None: + out = MoCo()(tensor_([]).reshape(0, 3)) + assert_close(out, tensor_([0.0, 0.0, 0.0])) + + +def test_zero_columns_returns_zero_vector() -> None: + out = MoCo()(tensor_([]).reshape(2, 0)) + assert out.shape == (0,) + + +@mark.parametrize("matrix", typical_matrices) +def test_reset_restores_first_step_behavior(matrix: Tensor) -> None: + A = MoCo() + first = A(matrix) + A(matrix) + A.reset() + assert_close(first, A(matrix)) + + +@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) +def test_expected_structure(aggregator: MoCo, matrix: Tensor) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) +def test_non_differentiable(aggregator: MoCo, matrix: Tensor) -> None: + assert_non_differentiable(aggregator, matrix) + + +def test_weighting_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + W = MoCoWeighting() + first = W(J) + W(J) + W.reset() + assert_close(first, W(J)) + + +def test_weighting_reset_clears_state() -> None: + J = randn_((3, 8)) + W = MoCoWeighting() + W(J) + + W.reset() + + assert W.step == 0 + with raises(RuntimeError, match="moving gradient estimate"): + _ = W.y + + +def test_aggregator_reset_clears_weighting_state() -> None: + J = randn_((3, 8)) + A = MoCo() + A(J) + + A.reset() + + assert A.weighting.step == 0 + with raises(RuntimeError, match="moving gradient estimate"): + _ = A.weighting.y + + +def test_y_getter_returns_current_moving_gradient_estimate() -> None: + J = randn_((3, 8)) + W = MoCoWeighting() + + W(J) + + assert_close(W.y, 0.5 * J) + + +def test_weighting_matches_aggregator_state_update() -> None: + J = randn_((3, 8)) + + A = MoCo(beta=0.3, beta_sigma=0.4, gamma=0.2, gamma_sigma=0.6, rho=0.1) + expected = A(J) + + W = MoCoWeighting(beta=0.3, beta_sigma=0.4, gamma=0.2, gamma_sigma=0.6, rho=0.1) + weights = W(J) + result = weights @ W.y + + assert_close(result, expected) + + +@mark.parametrize(["attribute", "value"], PARAMETER_VALUES) +def test_getters_return_constructor_values(attribute: str, value: float) -> None: + A = MoCo(**{attribute: value}) + + assert getattr(A, attribute) == value + assert getattr(A.weighting, attribute) == value + + +@mark.parametrize(["attribute", "value"], PARAMETER_VALUES) +def test_weighting_getters_return_constructor_values(attribute: str, value: float) -> None: + W = MoCoWeighting(**{attribute: value}) + + assert getattr(W, attribute) == value + + +def test_aggregator_setters_update_values() -> None: + A = MoCo() + A.beta = 0.25 + A.beta_sigma = 0.75 + A.gamma = 0.2 + A.gamma_sigma = 0.6 + A.rho = 0.1 + assert A.beta == 0.25 + assert A.beta_sigma == 0.75 + assert A.gamma == 0.2 + assert A.gamma_sigma == 0.6 + assert A.rho == 0.1 + assert A.weighting.beta == 0.25 + assert A.weighting.beta_sigma == 0.75 + assert A.weighting.gamma == 0.2 + assert A.weighting.gamma_sigma == 0.6 + assert A.weighting.rho == 0.1 + + +@mark.parametrize(["attribute", "value"], PARAMETER_VALUES) +def test_aggregator_setter_updates_matching_weighting_value(attribute: str, value: float) -> None: + A = MoCo() + + setattr(A, attribute, value) + + assert getattr(A, attribute) == value + assert getattr(A.weighting, attribute) == value + + +@mark.parametrize(["attribute", "value"], PARAMETER_VALUES) +def test_weighting_setter_updates_value(attribute: str, value: float) -> None: + W = MoCoWeighting() + + setattr(W, attribute, value) + + assert getattr(W, attribute) 
+
+
+@mark.parametrize("attribute", ["beta", "beta_sigma", "gamma", "gamma_sigma", "rho"])
+def test_aggregator_setters_reject_negative(attribute: str) -> None:
+    A = MoCo()
+    with raises(ValueError, match=attribute):
+        setattr(A, attribute, -1e-9)
+
+
+@mark.parametrize("attribute", ["beta", "beta_sigma", "gamma", "gamma_sigma", "rho"])
+def test_weighting_setters_reject_negative(attribute: str) -> None:
+    W = MoCoWeighting()
+    with raises(ValueError, match=attribute):
+        setattr(W, attribute, -1e-9)
diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py
index f468dc44..67894682 100644
--- a/tests/unit/aggregation/test_values.py
+++ b/tests/unit/aggregation/test_values.py
@@ -22,6 +22,8 @@
     Mean,
     MeanWeighting,
     MGDAWeighting,
+    MoCo,
+    MoCoWeighting,
     PCGrad,
     PCGradWeighting,
     Random,
@@ -64,6 +66,7 @@
     (Krum(n_byzantine=1, n_selected=4), J_Krum, tensor([1.2500, 0.7500, 1.5000])),
     (Mean(), J_base, tensor([1.0, 1.0, 1.0])),
     (MGDA(), J_base, tensor([0.0, 1.0, 1.0])),
+    (MoCo(), J_base, tensor([0.1891, 0.5000, 0.5000])),
     (PCGrad(), J_base, tensor([0.5848, 3.8012, 3.8012])),
     (Random(), J_base, tensor([-2.6229, 1.0000, 1.0000])),
     (Sum(), J_base, tensor([2.0, 2.0, 2.0])),
@@ -83,6 +86,7 @@
     (GradVacWeighting(), G_base, tensor([2.2222, 1.5789])),
     (MeanWeighting(), G_base, tensor([0.5000, 0.5000])),
     (MGDAWeighting(), G_base, tensor([0.6000, 0.4000])),
+    (MoCoWeighting(), J_base, tensor([0.5622, 0.4378])),
     (PCGradWeighting(), G_base, tensor([2.2222, 1.5789])),
     (RandomWeighting(), G_base, tensor([0.8623, 0.1377])),
     (SumWeighting(), G_base, tensor([1.0, 1.0])),
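
For reviewers who want to try the new aggregator by hand, here is a minimal usage sketch. It is
based only on the API exercised by the tests in this PR; the seed and the 3x8 Jacobian shape are
illustrative, and it assumes a plain `torch.Tensor` is accepted wherever `Matrix` is annotated
(as the tests suggest):

    import torch
    from torchjd.aggregation import MoCo, MoCoWeighting

    torch.manual_seed(0)
    J = torch.randn(3, 8)  # 3 objectives, 8 parameters

    A = MoCo()  # defaults: beta=0.5, beta_sigma=0.5, gamma=0.1, gamma_sigma=0.5, rho=0.0
    g1 = A(J)   # first call: the moving estimate y becomes beta * J, task weights start uniform
    g2 = A(J)   # later calls refine y and the task weights
    assert g1.shape == (8,)

    A.reset()                        # clear state before an independent run
    assert torch.allclose(A(J), g1)  # reset restores the first-step output

    # The stateful weighting can also be used on its own. Note that the aggregation it
    # induces is weights @ W.y (the moving estimate), not weights @ J, per the warning
    # in the MoCoWeighting docstring.
    W = MoCoWeighting()
    weights = W(J)
    assert torch.allclose(weights @ W.y, MoCo()(J))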