-
Notifications
You must be signed in to change notification settings - Fork 16
feat(aggregation): Add MoCo #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,7 @@ Abstract base classes | |
| krum.rst | ||
| mean.rst | ||
| mgda.rst | ||
| moco.rst | ||
| nash_mtl.rst | ||
| pcgrad.rst | ||
| random.rst | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| :hide-toc: | ||
|
|
||
| MoCo | ||
| ==== | ||
|
|
||
| .. autoclass:: torchjd.aggregation.MoCo | ||
| :members: __call__, reset | ||
|
|
||
| .. autoclass:: torchjd.aggregation.MoCoWeighting | ||
| :members: __call__, reset |
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,258 @@ | ||||||||
| from typing import cast | ||||||||
|
|
||||||||
| import torch | ||||||||
| from torch import Tensor | ||||||||
|
|
||||||||
| from torchjd.aggregation._mixins import Stateful | ||||||||
| from torchjd.linalg import Matrix | ||||||||
|
|
||||||||
| from ._aggregator_bases import Aggregator | ||||||||
| from ._utils.non_differentiable import raise_non_differentiable_error | ||||||||
|
Check failure on line 10 in src/torchjd/aggregation/_moco.py
|
||||||||
| from ._weighting_bases import _MatrixWeighting | ||||||||
|
|
||||||||
|
|
||||||||
class MoCoWeighting(_MatrixWeighting, Stateful):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] giving the weights of
    :class:`~torchjd.aggregation.MoCo`.

    This weighting is stateful: the moving gradient estimate :math:`Y` and the task weights
    :math:`\lambda` are carried over from one call to the next. Use :meth:`reset` between
    independent runs.

    .. warning::
        MoCo aggregates the moving estimate :math:`Y`, not the current matrix. Applying these
        weights directly to the current matrix therefore does not generally reproduce
        :class:`~torchjd.aggregation.MoCo`.

    :param beta: Learning rate of the moving gradient estimate.
    :param beta_sigma: Decay exponent of ``beta``.
    :param gamma: Learning rate of the task weights.
    :param gamma_sigma: Decay exponent of ``gamma``.
    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
    """

    def __init__(
        self,
        beta: float = 0.5,
        beta_sigma: float = 0.5,
        gamma: float = 0.1,
        gamma_sigma: float = 0.5,
        rho: float = 0.0,
    ) -> None:
        super().__init__()
        self.beta = beta
        self.beta_sigma = beta_sigma
        self.gamma = gamma
        self.gamma_sigma = gamma_sigma
        self.rho = rho
        self.reset()

    @staticmethod
    def _validated(name: str, value: float) -> float:
        # All hyper-parameters share the same non-negativity constraint, so the
        # validation (and the exact error message) is centralized here.
        if value < 0.0:
            raise ValueError(f"Attribute `{name}` must be non-negative. Found {name}={value!r}.")
        return value

    @property
    def beta(self) -> float:
        return self._beta

    @beta.setter
    def beta(self, value: float) -> None:
        self._beta = self._validated("beta", value)

    @property
    def beta_sigma(self) -> float:
        return self._beta_sigma

    @beta_sigma.setter
    def beta_sigma(self, value: float) -> None:
        self._beta_sigma = self._validated("beta_sigma", value)

    @property
    def gamma(self) -> float:
        return self._gamma

    @gamma.setter
    def gamma(self, value: float) -> None:
        self._gamma = self._validated("gamma", value)

    @property
    def gamma_sigma(self) -> float:
        return self._gamma_sigma

    @gamma_sigma.setter
    def gamma_sigma(self, value: float) -> None:
        self._gamma_sigma = self._validated("gamma_sigma", value)

    @property
    def rho(self) -> float:
        return self._rho

    @rho.setter
    def rho(self, value: float) -> None:
        self._rho = self._validated("rho", value)

    def reset(self) -> None:
        """Clears the moving gradient estimate and resets the task weights."""

        self.step = 0
        self._y: Tensor | None = None
        self._lambd: Tensor | None = None
        self._state_key: tuple[int, int, torch.device, torch.dtype] | None = None

    def forward(self, matrix: Matrix, /) -> Tensor:
        n_rows = matrix.shape[0]

        if n_rows == 0:
            # Degenerate case: remember an empty estimate and return no weights.
            self.reset()
            self._y = matrix.detach().clone()
            self._state_key = (n_rows, matrix.shape[1], matrix.device, matrix.dtype)
            return matrix.new_empty((0,))

        self._init_state_if_needed(matrix)
        self.step += 1

        estimate = cast(Tensor, self._y)
        weights = cast(Tensor, self._lambd)

        # Both step sizes decay polynomially with the step counter.
        beta_t = self.beta / self.step**self.beta_sigma
        gamma_t = self.gamma / self.step**self.gamma_sigma

        with torch.no_grad():
            # Moving average update of the gradient estimate Y.
            estimate = estimate - beta_t * (estimate - matrix.detach())

            # Gradient step on λ w.r.t. the (optionally ρ-regularized) Gram matrix of Y,
            # followed by a softmax to map back onto the probability simplex.
            gram = estimate @ estimate.T
            if self.rho != 0.0:
                identity = torch.eye(n_rows, device=matrix.device, dtype=matrix.dtype)
                gram = gram + self.rho * identity
            weights = torch.softmax(weights - gamma_t * (gram @ weights), dim=-1)

        self._y = estimate
        self._lambd = weights

        return weights

    @property
    def y(self) -> Tensor:
        """The moving gradient estimate, available once a matrix has been seen."""
        if self._y is None:
            raise RuntimeError("The moving gradient estimate is not initialized yet.")
        return self._y

    def _init_state_if_needed(self, matrix: Matrix) -> None:
        # (Re)initialize Y and λ whenever the matrix signature (shape/device/dtype)
        # differs from the one the current state was built for.
        signature = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
        if self._state_key == signature and self._y is not None and self._lambd is not None:
            return

        self._y = torch.zeros_like(matrix)
        self._lambd = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0])
        self._state_key = signature
|
|
||||||||
|
|
||||||||
class MoCo(Aggregator, Stateful):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation.Aggregator` implementing MoCo from `Mitigating Gradient Bias in
    Multi-objective Learning: A Provably Convergent Approach (ICLR 2023)
    <https://openreview.net/forum?id=dLAYGdKTi2>`_.

    This aggregator is stateful: it keeps the moving gradient estimate :math:`Y` and the task
    weights :math:`\lambda` across calls. Use :meth:`reset` between independent runs.

    .. warning::
        The output depends on previously seen matrices. Call :meth:`reset` between independent
        experiments.

    .. note::
        MoCo is not differentiable: its state is updated in-place, so attempting to backpropagate
        through its output raises an error.

    :param beta: Learning rate of the moving gradient estimate.
    :param beta_sigma: Decay exponent of ``beta``.
    :param gamma: Learning rate of the task weights.
    :param gamma_sigma: Decay exponent of ``gamma``.
    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
    """

    weighting: MoCoWeighting

    def __init__(
        self,
        beta: float = 0.5,
        beta_sigma: float = 0.5,
        gamma: float = 0.1,
        gamma_sigma: float = 0.5,
        rho: float = 0.0,
    ) -> None:
        super().__init__()
        self.weighting = MoCoWeighting(
            beta=beta,
            beta_sigma=beta_sigma,
            gamma=gamma,
            gamma_sigma=gamma_sigma,
            rho=rho,
        )
        # Differentiating through the in-place state updates is not supported, so any
        # backward pass through this aggregator's output raises an explicit error.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    def forward(self, matrix: Matrix, /) -> Tensor:
        # The weights are applied to the moving estimate Y held by the weighting, not to
        # the current matrix itself. For an empty matrix this naturally yields a zero
        # vector of length matrix.shape[1], so no special case is needed. Since the
        # aggregator is non-differentiable (see __init__), the output is intentionally
        # not connected to `matrix` in the autograd graph.
        weights = self.weighting(matrix)
        return weights @ self.weighting.y

    @property
    def beta(self) -> float:
        return self.weighting.beta

    @beta.setter
    def beta(self, value: float) -> None:
        self.weighting.beta = value

    @property
    def beta_sigma(self) -> float:
        return self.weighting.beta_sigma

    @beta_sigma.setter
    def beta_sigma(self, value: float) -> None:
        self.weighting.beta_sigma = value

    @property
    def gamma(self) -> float:
        return self.weighting.gamma

    @gamma.setter
    def gamma(self, value: float) -> None:
        self.weighting.gamma = value

    @property
    def gamma_sigma(self) -> float:
        return self.weighting.gamma_sigma

    @gamma_sigma.setter
    def gamma_sigma(self, value: float) -> None:
        self.weighting.gamma_sigma = value

    @property
    def rho(self) -> float:
        return self.weighting.rho

    @rho.setter
    def rho(self, value: float) -> None:
        self.weighting.rho = value

    def reset(self) -> None:
        """Clears the moving gradient estimate and resets the task weights."""

        self.weighting.reset()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(beta={self.beta!r}, beta_sigma={self.beta_sigma!r}, "
            f"gamma={self.gamma!r}, gamma_sigma={self.gamma_sigma!r}, rho={self.rho!r})"
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should never happen (I think we can safely assume that
`matrix.shape[0] > 0` in the forward call to an aggregator). Maybe we want to specify that somewhere @ValerianRey
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, let's get rid of this. And obviously matrices should have at least 1 row, it's not even like autojac can output jacobians with 0 rows. So it would be when a user would really want to aggregate their own empty matrix with this specific aggregator. This will not happen.