Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `MoCo` and `MoCoWeighting` from
[Mitigating Gradient Bias in Multi-objective Learning: A Provably Convergent Approach (ICLR 2023)](https://openreview.net/forum?id=dLAYGdKTi2).
- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
now importable from `torchjd.aggregation` and documented. They can be extended to easily implement
custom `Aggregator`s.
Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Abstract base classes
krum.rst
mean.rst
mgda.rst
moco.rst
nash_mtl.rst
pcgrad.rst
random.rst
Expand Down
10 changes: 10 additions & 0 deletions docs/source/docs/aggregation/moco.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
:hide-toc:

MoCo
====

.. autoclass:: torchjd.aggregation.MoCo
:members: __call__, reset

.. autoclass:: torchjd.aggregation.MoCoWeighting
:members: __call__, reset
3 changes: 3 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._moco import MoCo, MoCoWeighting
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
Expand Down Expand Up @@ -106,6 +107,8 @@
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"MoCo",
"MoCoWeighting",
"PCGrad",
"PCGradWeighting",
"Random",
Expand Down
258 changes: 258 additions & 0 deletions src/torchjd/aggregation/_moco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
from typing import cast

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error

Check failure on line 10 in src/torchjd/aggregation/_moco.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/aggregation/_moco.py:10:7: unresolved-import: Cannot resolve imported module `._utils.non_differentiable` info: Searched in the following paths during module resolution: info: 1. /home/runner/work/TorchJD/TorchJD/src (first-party code) info: 2. /home/runner/work/TorchJD/TorchJD (first-party code) info: 3. vendored://stdlib (stdlib typeshed stubs vendored by ty) info: 4. /home/runner/work/TorchJD/TorchJD/.venv/lib/python3.14/site-packages (site-packages) info: 5. /home/runner/work/TorchJD/TorchJD/.venv/lib64/python3.14/site-packages (site-packages) info: make sure your Python environment is properly configured: https://docs.astral.sh/ty/modules/#python-environment
from ._weighting_bases import _MatrixWeighting


class MoCoWeighting(_MatrixWeighting, Stateful):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] giving the weights of
    :class:`~torchjd.aggregation.MoCo`.

    This weighting keeps state across calls: the moving gradient estimate :math:`Y` and the task
    weights :math:`\lambda`. Call :meth:`reset` between independent runs.

    .. warning::
        MoCo aggregates the moving estimate :math:`Y`, not the current matrix. Applying these
        weights directly to the current matrix therefore does not generally reproduce
        :class:`~torchjd.aggregation.MoCo`.

    :param beta: Learning rate of the moving gradient estimate.
    :param beta_sigma: Decay exponent of ``beta``.
    :param gamma: Learning rate of the task weights.
    :param gamma_sigma: Decay exponent of ``gamma``.
    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
    """

    def __init__(
        self,
        beta: float = 0.5,
        beta_sigma: float = 0.5,
        gamma: float = 0.1,
        gamma_sigma: float = 0.5,
        rho: float = 0.0,
    ) -> None:
        super().__init__()
        # Assign through the property setters so that every value is validated.
        self.beta = beta
        self.beta_sigma = beta_sigma
        self.gamma = gamma
        self.gamma_sigma = gamma_sigma
        self.rho = rho
        self.reset()

@property
def beta(self) -> float:
return self._beta

@beta.setter
def beta(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `beta` must be non-negative. Found beta={value!r}.")
self._beta = value

@property
def beta_sigma(self) -> float:
return self._beta_sigma

@beta_sigma.setter
def beta_sigma(self, value: float) -> None:
if value < 0.0:
raise ValueError(
f"Attribute `beta_sigma` must be non-negative. Found beta_sigma={value!r}."
)
self._beta_sigma = value

@property
def gamma(self) -> float:
return self._gamma

@gamma.setter
def gamma(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `gamma` must be non-negative. Found gamma={value!r}.")
self._gamma = value

@property
def gamma_sigma(self) -> float:
return self._gamma_sigma

@gamma_sigma.setter
def gamma_sigma(self, value: float) -> None:
if value < 0.0:
raise ValueError(
f"Attribute `gamma_sigma` must be non-negative. Found gamma_sigma={value!r}."
)
self._gamma_sigma = value

@property
def rho(self) -> float:
return self._rho

@rho.setter
def rho(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.")
self._rho = value

def reset(self) -> None:
"""Clears the moving gradient estimate and resets the task weights."""

self.step = 0
self._y: Tensor | None = None
self._lambd: Tensor | None = None
self._state_key: tuple[int, int, torch.device, torch.dtype] | None = None

def forward(self, matrix: Matrix, /) -> Tensor:
if matrix.shape[0] == 0:
self.reset()
self._y = matrix.detach().clone()
self._state_key = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
return matrix.new_empty((0,))
Comment on lines +114 to +118
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never happen (I think we can safely assume that matrix.shape[0] > 0 in the forward call to an aggregator. Maybe we want to specify that somewhere @ValerianRey

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey May 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, let's get rid of this. And obviously matrices should have at least 1 row, it's not even like autojac can output jacobians with 0 rows. So it would be when a user would really want to aggregate their own empty matrix with this specific aggregator. This will not happen.


self._ensure_state(matrix)
self.step += 1

y = cast(Tensor, self._y)
lambd = cast(Tensor, self._lambd)

beta_step = self.beta / (self.step**self.beta_sigma)
gamma_step = self.gamma / (self.step**self.gamma_sigma)

with torch.no_grad():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit weird, why not having a no_grad around the whole function instead? Also what is the current philosphy on grads of outputs of aggregators @ValerianRey ? Should we unify this?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are just two cases:

  1. The aggregator/weighting is differentiable: and we do not want to use any torch.no_grad(), so that the graph is correctly built and stored

  2. The aggregator/weighting is non-differentiable (could be for example because some operations are made on numpy arrays, like for UPGrad, DualProj, and a few others, or for some other niche reasons, like PCGrad, and I think similarly GradVac): we want to raise an error when we try to backward through it (we do that already) and I even think we would like to prevent graph construction by wrapping the forward in a torch.no_grad() (which we don't do currently). And now that I think of it, I don't even think we need to raise an error when calling backward on a non-differentiable module if its forward is wrapped in a torch.no_grad(). No graph will ever be created to begin with, so autograd will never try to backward through the module.

So I think that we should improve on that:

  • Differentiable aggregators/weighting should inherit from a Differentiable mixin. It will not do anything, and it wont be public, but it will serve as internal documentation.
  • Non-differentiable aggregators/weighting should inherit from a NonDifferentiable mixin (similarly protected). It could maybe wrap the forward pass in a torch.no_grad, (and maybe still make it raise an error if we try to differentiate through it).

About MoCo, idk if it's differentiable or not. I guess that changing the value of y inplace will lead to runtime error when we try to differentiate through it (the same reason why we made PCGrad non-differentiable). Need to verify that. If it's not, I think we should make it inherit NonDifferentiable. If it is, let's remove the torch.no_grad and matrix.detach. In any case, let's remove those two things because NonDifferentiable will handle it itself if we add it.

I'll ask claude to work on that.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #677

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fantastic, upon merge of #677 this should be removed, and then the aggregator/weighting pair should implement the Non differentiable mixin. @rkhosrowshahi Take a look at #677 as it specifically needs to be inherited first.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's now merged FYI.

y = y - beta_step * (y - matrix.detach())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size of matrix/y are the critical part. This is slightly costly because we have another time the full Jacobian in memory. I would not use y and use self._y instead to garbage collect the previous value and save one full Jacobian. Actualy the most memory efficient implementation would be:

Suggested change
y = y - beta_step * (y - matrix.detach())
self._y = (1-\beta_step) * self._y
self._y += beta_step * matrix.detach()

With this you have exactly two Jacobians stored.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will be complicated to adapt with the normalization wrapper as we normalize the y itself, not matrix. So I think the normalization step of equation 6 needs to be implemented here. Note that what opencode suggested above from LibMTL doesn't match what is said in the paper, so I'm not sure which one to use.

yy_t = y @ y.T
if self.rho != 0.0:
eye = torch.eye(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
yy_t = yy_t + self.rho * eye
lambd = torch.softmax(lambd - gamma_step * (yy_t @ lambd), dim=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the paper, the softmax is supposed to be a projection to the probability simplex. Why is this a softmax (it is on the probability simplex, but not a projection)?


self._y = y
self._lambd = lambd

return lambd
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is incorrect. In equation 10 of the paper, they use lambda @ y as an update, not lambda @ matrix. I think this would make moco not weighted (the row spans of matrix and y can be very different in general).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad — I just saw below that while this is a weighting, the MoCo aggregator is not weighted. I had wrongly assumed that it was.


@property
def y(self) -> Tensor:
if self._y is None:
raise RuntimeError("The moving gradient estimate is not initialized yet.")
return self._y

def _ensure_state(self, matrix: Matrix) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the role of this function to initialize the state if none? we call it differently if so, maybe something in the direction of _conditionally_initialize_state (but probably improvable).

key = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
if self._state_key == key and self._y is not None and self._lambd is not None:
return

self._y = torch.zeros_like(matrix)
self._lambd = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0])
self._state_key = key


class MoCo(Aggregator, Stateful):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation.Aggregator` implementing MoCo from `Mitigating Gradient Bias in
    Multi-objective Learning: A Provably Convergent Approach (ICLR 2023)
    <https://openreview.net/forum?id=dLAYGdKTi2>`_.

    This aggregator keeps state across calls: the moving gradient estimate :math:`Y` and the
    task weights :math:`\lambda`. Call :meth:`reset` between independent runs.

    .. warning::
        The output depends on previously seen matrices. Call :meth:`reset` between independent
        experiments.

    :param beta: Learning rate of the moving gradient estimate.
    :param beta_sigma: Decay exponent of ``beta``.
    :param gamma: Learning rate of the task weights.
    :param gamma_sigma: Decay exponent of ``gamma``.
    :param rho: Non-negative :math:`\ell_2` regularization parameter for the task-weight update.
    """

    weighting: MoCoWeighting  # Holds all of the mutable state.

    def __init__(
        self,
        beta: float = 0.5,
        beta_sigma: float = 0.5,
        gamma: float = 0.1,
        gamma_sigma: float = 0.5,
        rho: float = 0.0,
    ) -> None:
        super().__init__()
        # Parameter validation happens inside MoCoWeighting's property setters.
        self.weighting = MoCoWeighting(
            beta=beta,
            beta_sigma=beta_sigma,
            gamma=gamma,
            gamma_sigma=gamma_sigma,
            rho=rho,
        )
        # MoCo is non-differentiable: fail loudly if autograd tries to go through it.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
if matrix.shape[0] == 0:
return matrix.sum(dim=0)
Comment on lines +201 to +202
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove

Suggested change
if matrix.shape[0] == 0:
return matrix.sum(dim=0)


vector = weights @ self.weighting.y
if matrix.requires_grad:
vector = vector + 0.0 * matrix.sum(dim=0)
Comment on lines +205 to +206
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not do that. No reason to create a wrong differentiation graph for vector. Maybe we want to make this aggregator/weighting pair non-differentiable (because we don't want to keep all the Jacobians for all times). This can be done after #677 is merged.

Suggested change
if matrix.requires_grad:
vector = vector + 0.0 * matrix.sum(dim=0)

return vector

@property
def beta(self) -> float:
return self.weighting.beta

@beta.setter
def beta(self, value: float) -> None:
self.weighting.beta = value

@property
def beta_sigma(self) -> float:
return self.weighting.beta_sigma

@beta_sigma.setter
def beta_sigma(self, value: float) -> None:
self.weighting.beta_sigma = value

@property
def gamma(self) -> float:
return self.weighting.gamma

@gamma.setter
def gamma(self, value: float) -> None:
self.weighting.gamma = value

@property
def gamma_sigma(self) -> float:
return self.weighting.gamma_sigma

@gamma_sigma.setter
def gamma_sigma(self, value: float) -> None:
self.weighting.gamma_sigma = value

@property
def rho(self) -> float:
return self.weighting.rho

@rho.setter
def rho(self, value: float) -> None:
self.weighting.rho = value

def reset(self) -> None:
"""Clears the moving gradient estimate and resets the task weights."""

self.weighting.reset()

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(beta={self.beta!r}, beta_sigma={self.beta_sigma!r}, "
f"gamma={self.gamma!r}, gamma_sigma={self.gamma_sigma!r}, rho={self.rho!r})"
)
Loading
Loading