Skip to content
Merged
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ changelog does not include internal changes that do not affect the user.

### Added

- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting`
public. These abstract base classes are now importable from `torchjd.aggregation` and documented.
They can be extended to easily implement custom `Weighting`s and `Aggregator`s.
- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
now importable from `torchjd.aggregation` and documented. They can be extended to easily implement
custom `Aggregator`s.
- Made `Matrix` and `PSDMatrix` public. These type annotation classes are now importable from
`torchjd.linalg` and documented. Users can now subclass `Weighting[Matrix]` or
`Weighting[PSDMatrix]` to implement custom `Weighting`s.
- Added getters and setters for the constructor parameters of all aggregators and weightings, so
that they can be changed after initialization. This includes: `pref_vector`,
`norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
Expand Down
6 changes: 0 additions & 6 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Weighting
:members: __call__

.. autoclass:: torchjd.aggregation.MatrixWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.GramianWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.GeneralizedWeighting
:members: __call__

Expand Down
12 changes: 12 additions & 0 deletions docs/source/docs/linalg/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
linalg
======

.. automodule:: torchjd.linalg
:no-members:

.. toctree::
:hidden:
:maxdepth: 1

matrix.rst
psd_matrix.rst
6 changes: 6 additions & 0 deletions docs/source/docs/linalg/matrix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
:hide-toc:

Matrix
======

.. autoclass:: torchjd.linalg.Matrix
6 changes: 6 additions & 0 deletions docs/source/docs/linalg/psd_matrix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
:hide-toc:

PSDMatrix
=========

.. autoclass:: torchjd.linalg.PSDMatrix
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,4 @@ TorchJD is open-source, under MIT License. The source code is available on
docs/autogram/index.rst
docs/autojac/index.rst
docs/aggregation/index.rst
docs/linalg/index.rst
21 changes: 19 additions & 2 deletions src/torchjd/_linalg/_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@


class Matrix(Tensor):
"""Tensor with exactly 2 dimensions."""
"""
Tensor with exactly 2 dimensions.

Common examples include the Jacobian matrix J of shape ``[m, n]``, where m is the number of
objectives and n is the number of model parameters, and the Gramian of the Jacobian
G = J J^T of shape ``[m, m]``.

.. note::
This class should never be instantiated. It is only used for static type checking.
"""


class PSDTensor(Tensor):
Expand All @@ -20,7 +29,15 @@ class PSDTensor(Tensor):


class PSDMatrix(PSDTensor, Matrix):
"""Positive semi-definite matrix."""
"""
Positive semi-definite matrix.

A common example is the Gramian of the Jacobian G = J J^T of shape ``[m, m]``, where J is a
Jacobian matrix of shape ``[m, n]``.

.. note::
This class should never be instantiated. It is only used for static type checking.
"""


def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
Expand Down
9 changes: 4 additions & 5 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

.. note::
Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights
from this Gramian using a :class:`~torchjd.aggregation.GramianWeighting`, and then combining the
rows of the Jacobian using these weights. For all of them, we provide both the
from this Gramian using a :class:`~torchjd.aggregation.Weighting`
[:class:`~torchjd.linalg.PSDMatrix`], and then combining the rows of the Jacobian using these
weights. For all of them, we provide both the
:class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the
:class:`~torchjd.aggregation.Weighting` interface (to be used in autogram).
For the rest, we only provide the :class:`~torchjd.aggregation.Aggregator`
Expand Down Expand Up @@ -80,7 +81,7 @@
from ._utils.check_dependencies import (
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
)
from ._weighting_bases import GeneralizedWeighting, GramianWeighting, MatrixWeighting, Weighting
from ._weighting_bases import GeneralizedWeighting, Weighting

__all__ = [
"Aggregator",
Expand All @@ -97,12 +98,10 @@
"GradVac",
"GradVacWeighting",
"GramianWeightedAggregator",
"GramianWeighting",
"IMTLG",
"IMTLGWeighting",
"Krum",
"KrumWeighting",
"MatrixWeighting",
"Mean",
"MeanWeighting",
"MGDA",
Expand Down
17 changes: 9 additions & 8 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from typing import cast

from torch import Tensor, nn

from torchjd._linalg import Matrix, compute_gramian, is_matrix
from torchjd._linalg import compute_gramian, is_matrix
from torchjd.linalg import Matrix, PSDMatrix

from ._weighting_bases import GramianWeighting, MatrixWeighting
from ._weighting_bases import Weighting


class Aggregator(nn.Module, ABC):
Expand Down Expand Up @@ -48,12 +48,12 @@ def __str__(self) -> str:
class WeightedAggregator(Aggregator):
"""
Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a
:class:`~torchjd.aggregation.MatrixWeighting` to it.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] to it.

:param weighting: The object responsible for extracting the vector of weights from the matrix.
"""

def __init__(self, weighting: MatrixWeighting) -> None:
def __init__(self, weighting: Weighting[Matrix]) -> None:
super().__init__()
self.weighting = weighting

Expand All @@ -76,12 +76,13 @@ def forward(self, matrix: Matrix, /) -> Tensor:
class GramianWeightedAggregator(WeightedAggregator):
"""
:class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input
Jacobian matrix before applying a :class:`~torchjd.aggregation.GramianWeighting` to it.
Jacobian matrix before applying a :class:`~torchjd.aggregation.Weighting`
[:class:`~torchjd.linalg.PSDMatrix`] to it.

:param gramian_weighting: The object responsible for extracting the vector of weights from the
gramian.
"""

def __init__(self, gramian_weighting: GramianWeighting) -> None:
super().__init__(cast(MatrixWeighting, gramian_weighting << compute_gramian))
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
super().__init__(gramian_weighting << compute_gramian)
self.gramian_weighting = gramian_weighting
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting

SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]


class AlignedMTLWeighting(GramianWeighting):
class AlignedMTLWeighting(_GramianWeighting):
r"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.AlignedMTL`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.AlignedMTL`.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import cast

from torchjd._linalg import PSDMatrix
from torchjd.linalg import PSDMatrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting

check_dependencies_are_installed(["cvxpy", "clarabel"])

Expand All @@ -18,10 +18,10 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGradWeighting(GramianWeighting):
class CAGradWeighting(_GramianWeighting):
"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.CAGrad`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.CAGrad`.

:param c: The scale of the radius of the ball constraint.
:param norm_eps: A small value to avoid division by zero when normalizing.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._sum import SumWeighting
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import MatrixWeighting
from ._weighting_bases import _MatrixWeighting


class ConstantWeighting(MatrixWeighting):
class ConstantWeighting(_MatrixWeighting):
"""
:class:`~torchjd.aggregation.MatrixWeighting` that returns constant, pre-determined
weights.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`]
that returns constant, pre-determined weights.

:param weights: The weights to return at each call.
"""
Expand Down
11 changes: 6 additions & 5 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix, normalize, regularize
from torchjd._linalg import normalize, regularize
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting


class DualProjWeighting(GramianWeighting):
class DualProjWeighting(_GramianWeighting):
r"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.DualProj`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.DualProj`.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error
Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation._mixins import Stateful
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting


class GradVacWeighting(GramianWeighting, Stateful):
class GradVacWeighting(_GramianWeighting, Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.GradVac`.

All required quantities (gradient norms, cosine similarities, and their updates after the
vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting


class IMTLGWeighting(GramianWeighting):
class IMTLGWeighting(_GramianWeighting):
"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.IMTLG`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.IMTLG`.
"""

def forward(self, gramian: PSDMatrix, /) -> Tensor:
Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._linalg import PSDMatrix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting


class KrumWeighting(GramianWeighting):
class KrumWeighting(_GramianWeighting):
"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.Krum`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.Krum`.

:param n_byzantine: The number of rows of the input matrix that can come from an adversarial
source.
Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from torch import Tensor

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import MatrixWeighting
from ._weighting_bases import _MatrixWeighting


class MeanWeighting(MatrixWeighting):
class MeanWeighting(_MatrixWeighting):
r"""
:class:`~torchjd.aggregation.MatrixWeighting` that gives the weights
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
\mathbb{R}^m`.
"""
Expand Down
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import GramianWeighting
from ._weighting_bases import _GramianWeighting


class MGDAWeighting(GramianWeighting):
class MGDAWeighting(_GramianWeighting):
r"""
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.MGDA`.
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.MGDA`.

:param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
:param max_iters: The maximum number of iterations of the optimization loop.
Expand Down
Loading
Loading