From 93046e9d716468d72845660be3433a1ce510867a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 17:19:32 +0200 Subject: [PATCH 1/8] feat(linalg): Make Matrix and PSDMatrix public Co-Authored-By: Claude Sonnet 4.6 --- docs/source/docs/linalg/index.rst | 12 +++++++++++ docs/source/docs/linalg/matrix.rst | 6 ++++++ docs/source/docs/linalg/psd_matrix.rst | 6 ++++++ docs/source/index.rst | 1 + src/torchjd/_linalg/_matrix.py | 21 +++++++++++++++++-- src/torchjd/linalg/__init__.py | 29 ++++++++++++++++++++++++++ 6 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 docs/source/docs/linalg/index.rst create mode 100644 docs/source/docs/linalg/matrix.rst create mode 100644 docs/source/docs/linalg/psd_matrix.rst create mode 100644 src/torchjd/linalg/__init__.py diff --git a/docs/source/docs/linalg/index.rst b/docs/source/docs/linalg/index.rst new file mode 100644 index 000000000..4446ccea7 --- /dev/null +++ b/docs/source/docs/linalg/index.rst @@ -0,0 +1,12 @@ +linalg +====== + +.. automodule:: torchjd.linalg + :no-members: + +.. toctree:: + :hidden: + :maxdepth: 1 + + matrix.rst + psd_matrix.rst diff --git a/docs/source/docs/linalg/matrix.rst b/docs/source/docs/linalg/matrix.rst new file mode 100644 index 000000000..165f3718a --- /dev/null +++ b/docs/source/docs/linalg/matrix.rst @@ -0,0 +1,6 @@ +:hide-toc: + +Matrix +====== + +.. autoclass:: torchjd.linalg.Matrix diff --git a/docs/source/docs/linalg/psd_matrix.rst b/docs/source/docs/linalg/psd_matrix.rst new file mode 100644 index 000000000..8ee262a23 --- /dev/null +++ b/docs/source/docs/linalg/psd_matrix.rst @@ -0,0 +1,6 @@ +:hide-toc: + +PSDMatrix +========= + +.. autoclass:: torchjd.linalg.PSDMatrix diff --git a/docs/source/index.rst b/docs/source/index.rst index 1b2bd970f..d8b14f830 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -70,3 +70,4 @@ TorchJD is open-source, under MIT License. 
The source code is available on docs/autogram/index.rst docs/autojac/index.rst docs/aggregation/index.rst + docs/linalg/index.rst diff --git a/src/torchjd/_linalg/_matrix.py b/src/torchjd/_linalg/_matrix.py index 7a4960ca7..574d63d73 100644 --- a/src/torchjd/_linalg/_matrix.py +++ b/src/torchjd/_linalg/_matrix.py @@ -8,7 +8,16 @@ class Matrix(Tensor): - """Tensor with exactly 2 dimensions.""" + """ + Tensor with exactly 2 dimensions. + + Common examples include the Jacobian matrix J of shape ``[m, n]``, where m is the number of + objectives and n is the number of model parameters, and the Gramian of the Jacobian + G = J J^T of shape ``[m, m]``. + + .. note:: + This class should never be instantiated. It is only used for static type checking. + """ class PSDTensor(Tensor): @@ -20,7 +29,15 @@ class PSDTensor(Tensor): class PSDMatrix(PSDTensor, Matrix): - """Positive semi-definite matrix.""" + """ + Positive semi-definite matrix. + + A common example is the Gramian of the Jacobian G = J J^T of shape ``[m, m]``, where J is a + Jacobian matrix of shape ``[m, n]``. + + .. note:: + This class should never be instantiated. It is only used for static type checking. + """ def is_matrix(t: Tensor) -> TypeGuard[Matrix]: diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py new file mode 100644 index 000000000..5fa613877 --- /dev/null +++ b/src/torchjd/linalg/__init__.py @@ -0,0 +1,29 @@ +""" +This module provides type annotation classes representing tensors with specific structural +properties. + +:class:`Matrix` represents any 2D tensor. A common example in the context of Jacobian descent +is the Jacobian matrix J of shape ``[m, n]``, where m is the number of objectives and n is the +number of model parameters. + +:class:`PSDMatrix` represents a symmetric positive semi-definite square matrix. A common +example is the Gramian of the Jacobian G = J J^T of shape ``[m, m]``. + +.. 
note:: + :class:`Matrix` and :class:`PSDMatrix` extend :class:`~torch.Tensor` for type-checking + purposes only and should never be directly instantiated. + +>>> import torch +>>> # Jacobian matrix of shape [m, n] = [2, 3] +>>> J = torch.tensor([[-4., 1., 1.], [6., 1., 1.]]) +>>> J.ndim +2 +>>> # Gramian of the Jacobian, of shape [m, m] = [2, 2] +>>> G = J @ J.T +>>> G.shape +torch.Size([2, 2]) +""" + +from torchjd._linalg._matrix import Matrix, PSDMatrix + +__all__ = ["Matrix", "PSDMatrix"] From d689560c13732ecaa22466560dece0282c0a7713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 17:29:33 +0200 Subject: [PATCH 2/8] refactor(aggregation): Replace MatrixWeighting/GramianWeighting with Weighting[Matrix]/Weighting[PSDMatrix] Co-Authored-By: Claude Sonnet 4.6 --- docs/source/docs/aggregation/index.rst | 6 ---- src/torchjd/aggregation/__init__.py | 6 ++-- src/torchjd/aggregation/_aggregator_bases.py | 14 ++++----- src/torchjd/aggregation/_aligned_mtl.py | 6 ++-- src/torchjd/aggregation/_cagrad.py | 6 ++-- src/torchjd/aggregation/_constant.py | 8 +++-- src/torchjd/aggregation/_dualproj.py | 6 ++-- src/torchjd/aggregation/_gradvac.py | 6 ++-- src/torchjd/aggregation/_imtl_g.py | 6 ++-- src/torchjd/aggregation/_krum.py | 6 ++-- src/torchjd/aggregation/_mean.py | 8 +++-- src/torchjd/aggregation/_mgda.py | 6 ++-- src/torchjd/aggregation/_nash_mtl.py | 7 +++-- src/torchjd/aggregation/_pcgrad.py | 6 ++-- src/torchjd/aggregation/_random.py | 8 +++-- src/torchjd/aggregation/_sum.py | 8 +++-- src/torchjd/aggregation/_upgrad.py | 6 ++-- src/torchjd/aggregation/_weighting_bases.py | 32 +------------------- 18 files changed, 61 insertions(+), 90 deletions(-) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 04a8de666..ff6e18112 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -19,12 +19,6 @@ Abstract base classes .. 
autoclass:: torchjd.aggregation.Weighting :members: __call__ -.. autoclass:: torchjd.aggregation.MatrixWeighting - :members: __call__ - -.. autoclass:: torchjd.aggregation.GramianWeighting - :members: __call__ - .. autoclass:: torchjd.aggregation.GeneralizedWeighting :members: __call__ diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index ec871e899..bff6ebe92 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -9,7 +9,7 @@ .. note:: Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights - from this Gramian using a :class:`~torchjd.aggregation.GramianWeighting`, and then combining the + from this Gramian using a :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]``, and then combining the rows of the Jacobian using these weights. For all of them, we provide both the :class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the :class:`~torchjd.aggregation.Weighting` interface (to be used in autogram). 
@@ -80,7 +80,7 @@ from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from ._weighting_bases import GeneralizedWeighting, GramianWeighting, MatrixWeighting, Weighting +from ._weighting_bases import GeneralizedWeighting, Weighting __all__ = [ "Aggregator", @@ -97,12 +97,10 @@ "GradVac", "GradVacWeighting", "GramianWeightedAggregator", - "GramianWeighting", "IMTLG", "IMTLGWeighting", "Krum", "KrumWeighting", - "MatrixWeighting", "Mean", "MeanWeighting", "MGDA", diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 7372b0027..25ac002f4 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -3,9 +3,9 @@ from torch import Tensor, nn -from torchjd._linalg import Matrix, compute_gramian, is_matrix +from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix -from ._weighting_bases import GramianWeighting, MatrixWeighting +from ._weighting_bases import Weighting class Aggregator(nn.Module, ABC): @@ -48,12 +48,12 @@ def __str__(self) -> str: class WeightedAggregator(Aggregator): """ Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a - :class:`~torchjd.aggregation.MatrixWeighting` to it. + :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` to it. :param weighting: The object responsible for extracting the vector of weights from the matrix. """ - def __init__(self, weighting: MatrixWeighting) -> None: + def __init__(self, weighting: Weighting[Matrix]) -> None: super().__init__() self.weighting = weighting @@ -76,12 +76,12 @@ def forward(self, matrix: Matrix, /) -> Tensor: class GramianWeightedAggregator(WeightedAggregator): """ :class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input - Jacobian matrix before applying a :class:`~torchjd.aggregation.GramianWeighting` to it. 
+ Jacobian matrix before applying a :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` to it. :param gramian_weighting: The object responsible for extracting the vector of weights from the gramian. """ - def __init__(self, gramian_weighting: GramianWeighting) -> None: - super().__init__(cast(MatrixWeighting, gramian_weighting << compute_gramian)) + def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: + super().__init__(cast(Weighting[Matrix], gramian_weighting << compute_gramian)) self.gramian_weighting = gramian_weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 5ff9e5ed5..e10ad6d18 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -12,14 +12,14 @@ from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] -class AlignedMTLWeighting(GramianWeighting): +class AlignedMTLWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. :param pref_vector: The preference vector to use. 
If not provided, defaults to diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index a0b09cb1c..357b55c8d 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -3,7 +3,7 @@ from torchjd._linalg import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "clarabel"]) @@ -18,9 +18,9 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGradWeighting(GramianWeighting): +class CAGradWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.CAGrad`. :param c: The scale of the radius of the ball constraint. diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 91a639ba6..cc26e339c 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,13 +1,15 @@ from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import MatrixWeighting +from ._weighting_bases import Weighting -class ConstantWeighting(MatrixWeighting): +class ConstantWeighting(Weighting[Matrix]): """ - :class:`~torchjd.aggregation.MatrixWeighting` that returns constant, pre-determined + :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that returns constant, pre-determined weights. :param weights: The weights to return at each call. 
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index a9167648c..2ef28d9d4 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -7,12 +7,12 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class DualProjWeighting(GramianWeighting): +class DualProjWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.DualProj`. :param pref_vector: The preference vector to use. If not provided, defaults to diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index b791f6808..9d8772f7f 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -10,13 +10,13 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class GradVacWeighting(GramianWeighting, Stateful): +class GradVacWeighting(Weighting[PSDMatrix], Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
All required quantities (gradient norms, cosine similarities, and their updates after the diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index a53085be7..5935eaa3b 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -5,12 +5,12 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class IMTLGWeighting(GramianWeighting): +class IMTLGWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.IMTLG`. """ diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index db268e0c7..cdea91e07 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -5,12 +5,12 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class KrumWeighting(GramianWeighting): +class KrumWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.Krum`. 
:param n_byzantine: The number of rows of the input matrix that can come from an adversarial diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index e8b75e7cc..ed3851d89 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,13 +1,15 @@ import torch from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import MatrixWeighting +from ._weighting_bases import Weighting -class MeanWeighting(MatrixWeighting): +class MeanWeighting(Weighting[Matrix]): r""" - :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights + :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. """ diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index ec9d0afc3..4828e050d 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -4,12 +4,12 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class MGDAWeighting(GramianWeighting): +class MGDAWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.MGDA`. :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 5be55afd5..fcac07c55 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,10 +1,11 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. 
+from torchjd._linalg import Matrix from torchjd.aggregation._mixins import Stateful from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import MatrixWeighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) @@ -18,9 +19,9 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class _NashMTLWeighting(MatrixWeighting, Stateful): +class _NashMTLWeighting(Weighting[Matrix], Stateful): """ - :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.MatrixWeighting` that + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 25e244522..a2af06662 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -7,12 +7,12 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class PCGradWeighting(GramianWeighting): +class PCGradWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.PCGrad`. 
""" diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index d20d54db1..b0fbcb5f5 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,13 +2,15 @@ from torch import Tensor from torch.nn import functional as F +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import MatrixWeighting +from ._weighting_bases import Weighting -class RandomWeighting(MatrixWeighting): +class RandomWeighting(Weighting[Matrix]): """ - :class:`~torchjd.aggregation.MatrixWeighting` that generates positive random weights + :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that generates positive random weights at each call. """ diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 1a48ef41f..d22f50ff7 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,13 +1,15 @@ import torch from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import MatrixWeighting +from ._weighting_bases import Weighting -class SumWeighting(MatrixWeighting): +class SumWeighting(Weighting[Matrix]): r""" - :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights + :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that gives the weights :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. 
""" diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 731986843..a731164db 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -8,12 +8,12 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import GramianWeighting +from ._weighting_bases import Weighting -class UPGradWeighting(GramianWeighting): +class UPGradWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of + :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of :class:`~torchjd.aggregation.UPGrad`. :param pref_vector: The preference vector to use. If not provided, defaults to diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 2655aa28a..6ca5dddc5 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,7 +6,7 @@ from torch import Tensor, nn -from torchjd._linalg import Matrix, PSDMatrix, PSDTensor, is_psd_tensor +from torchjd._linalg import PSDTensor, is_psd_tensor _T = TypeVar("_T", contravariant=True, bound=Tensor) _FnInputT = TypeVar("_FnInputT", bound=Tensor) @@ -34,8 +34,6 @@ def __call__(self, stat: Tensor, /) -> Tensor: :param stat: The stat from which the weights must be extracted. """ - # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of - # stat to be Tensor. 
return super().__call__(stat) def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: @@ -84,31 +82,3 @@ def __call__(self, generalized_gramian: Tensor, /) -> Tensor: assert is_psd_tensor(generalized_gramian) return super().__call__(generalized_gramian) - - -# Subclasses used only to redefine the __call__ method with more specific parameter names and -# docstrings. Note that MatrixWeighting <: Weighting[Matrix] <: Weighting[PSDMatrix], because -# PSDMatrix <: Matrix and Weighting[_T] is contravariant with _T. -# Also note that we don't have: MatrixWeighting <: GramianWeighting. GramianWeighting is not -# just an alias of Weighting[PSDMatrix], it's a subtype of it. So the type Weighting[PSDMatrix] -# should still be used when we expect a Weighting that works at least on PSD matrices. - - -class MatrixWeighting(Weighting[Matrix]): - def __call__(self, matrix: Tensor, /) -> Tensor: - """ - Computes the vector of weights from the input matrix and applies all registered hooks. - - :param matrix: The matrix from which the weights must be extracted. - """ - return super().__call__(matrix) - - -class GramianWeighting(Weighting[PSDMatrix]): - def __call__(self, gramian: Tensor, /) -> Tensor: - """ - Computes the vector of weights from the input gramian and applies all registered hooks. - - :param gramian: The gramian from which the weights must be extracted. 
- """ - return super().__call__(gramian) From 67d15b7d99300bd5fa6a117e95b805c15fe3c281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 17:34:51 +0200 Subject: [PATCH 3/8] refactor(aggregation): Remove cast in GramianWeightedAggregator.__init__ Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/_aggregator_bases.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 25ac002f4..09e94d97c 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import cast from torch import Tensor, nn @@ -83,5 +82,5 @@ class GramianWeightedAggregator(WeightedAggregator): """ def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: - super().__init__(cast(Weighting[Matrix], gramian_weighting << compute_gramian)) + super().__init__(gramian_weighting << compute_gramian) self.gramian_weighting = gramian_weighting From 7e43a9827e6760bfd1e318f3f4baf0bf04ac4ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 17:38:04 +0200 Subject: [PATCH 4/8] Remove ai slop from linalg __init__.py --- src/torchjd/linalg/__init__.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py index 5fa613877..f8238104e 100644 --- a/src/torchjd/linalg/__init__.py +++ b/src/torchjd/linalg/__init__.py @@ -1,27 +1,6 @@ """ This module provides type annotation classes representing tensors with specific structural properties. - -:class:`Matrix` represents any 2D tensor. A common example in the context of Jacobian descent -is the Jacobian matrix J of shape ``[m, n]``, where m is the number of objectives and n is the -number of model parameters. 
- -:class:`PSDMatrix` represents a symmetric positive semi-definite square matrix. A common -example is the Gramian of the Jacobian G = J J^T of shape ``[m, m]``. - -.. note:: - :class:`Matrix` and :class:`PSDMatrix` extend :class:`~torch.Tensor` for type-checking - purposes only and should never be directly instantiated. - ->>> import torch ->>> # Jacobian matrix of shape [m, n] = [2, 3] ->>> J = torch.tensor([[-4., 1., 1.], [6., 1., 1.]]) ->>> J.ndim -2 ->>> # Gramian of the Jacobian, of shape [m, m] = [2, 2] ->>> G = J @ J.T ->>> G.shape -torch.Size([2, 2]) """ from torchjd._linalg._matrix import Matrix, PSDMatrix From e448e22e5ee1043feb11e49cef69cf643f1c961e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 17:47:49 +0200 Subject: [PATCH 5/8] docs(aggregation): Make Matrix and PSDMatrix clickable in docstrings Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/__init__.py | 5 +++-- src/torchjd/aggregation/_aggregator_bases.py | 5 +++-- src/torchjd/aggregation/_aligned_mtl.py | 4 ++-- src/torchjd/aggregation/_cagrad.py | 4 ++-- src/torchjd/aggregation/_constant.py | 4 ++-- src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_gradvac.py | 4 ++-- src/torchjd/aggregation/_imtl_g.py | 4 ++-- src/torchjd/aggregation/_krum.py | 4 ++-- src/torchjd/aggregation/_mean.py | 2 +- src/torchjd/aggregation/_mgda.py | 4 ++-- src/torchjd/aggregation/_nash_mtl.py | 3 ++- src/torchjd/aggregation/_pcgrad.py | 4 ++-- src/torchjd/aggregation/_random.py | 4 ++-- src/torchjd/aggregation/_sum.py | 2 +- src/torchjd/aggregation/_upgrad.py | 4 ++-- 16 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index bff6ebe92..bb8892ff5 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -9,8 +9,9 @@ .. 
note:: Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights - from this Gramian using a :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]``, and then combining the - rows of the Jacobian using these weights. For all of them, we provide both the + from this Gramian using a :class:`~torchjd.aggregation.Weighting` + [:class:`~torchjd.linalg.PSDMatrix`], and then combining the rows of the Jacobian using these + weights. For all of them, we provide both the :class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the :class:`~torchjd.aggregation.Weighting` interface (to be used in autogram). For the rest, we only provide the :class:`~torchjd.aggregation.Aggregator` diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 09e94d97c..8f2be2d19 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -47,7 +47,7 @@ def __str__(self) -> str: class WeightedAggregator(Aggregator): """ Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a - :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` to it. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] to it. :param weighting: The object responsible for extracting the vector of weights from the matrix. """ @@ -75,7 +75,8 @@ def forward(self, matrix: Matrix, /) -> Tensor: class GramianWeightedAggregator(WeightedAggregator): """ :class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input - Jacobian matrix before applying a :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` to it. + Jacobian matrix before applying a :class:`~torchjd.aggregation.Weighting` + [:class:`~torchjd.linalg.PSDMatrix`] to it. :param gramian_weighting: The object responsible for extracting the vector of weights from the gramian. 
diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index e10ad6d18..ed015897e 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -19,8 +19,8 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.AlignedMTL`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 357b55c8d..f0df44b7b 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -20,8 +20,8 @@ class CAGradWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.CAGrad`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.CAGrad`. :param c: The scale of the radius of the ball constraint. :param norm_eps: A small value to avoid division by zero when normalizing. diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index cc26e339c..db1b56a06 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -9,8 +9,8 @@ class ConstantWeighting(Weighting[Matrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that returns constant, pre-determined - weights. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] + that returns constant, pre-determined weights. :param weights: The weights to return at each call. 
""" diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 2ef28d9d4..22172f61c 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -12,8 +12,8 @@ class DualProjWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.DualProj`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.DualProj`. :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 9d8772f7f..d44f955b6 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -16,8 +16,8 @@ class GradVacWeighting(Weighting[PSDMatrix], Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.GradVac`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.GradVac`. All required quantities (gradient norms, cosine similarities, and their updates after the vaccine correction) are derived purely from the Gramian, without needing the full Jacobian. diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 5935eaa3b..677a2ac9d 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -10,8 +10,8 @@ class IMTLGWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.IMTLG`. 
+ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.IMTLG`. """ def forward(self, gramian: PSDMatrix, /) -> Tensor: diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index cdea91e07..512ce88d8 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -10,8 +10,8 @@ class KrumWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.Krum`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.Krum`. :param n_byzantine: The number of rows of the input matrix that can come from an adversarial source. diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index ed3851d89..50a19e8f0 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -9,7 +9,7 @@ class MeanWeighting(Weighting[Matrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that gives the weights + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. """ diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 4828e050d..87e9771e9 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -9,8 +9,8 @@ class MGDAWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.MGDA`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.MGDA`. :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. 
:param max_iters: The maximum number of iterations of the optimization loop. diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index fcac07c55..fef03834c 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -21,7 +21,8 @@ class _NashMTLWeighting(Weighting[Matrix], Stateful): """ - :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_. diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index a2af06662..32515c9f6 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -12,8 +12,8 @@ class PCGradWeighting(Weighting[PSDMatrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.PCGrad`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.PCGrad`. """ def forward(self, gramian: PSDMatrix, /) -> Tensor: diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index b0fbcb5f5..593357172 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -10,8 +10,8 @@ class RandomWeighting(Weighting[Matrix]): """ - :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that generates positive random weights - at each call. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] + that generates positive random weights at each call.
""" def forward(self, matrix: Tensor, /) -> Tensor: diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index d22f50ff7..a005c52f3 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -9,7 +9,7 @@ class SumWeighting(Weighting[Matrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[Matrix]`` that gives the weights + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. """ diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index a731164db..4bf99c545 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -13,8 +13,8 @@ class UPGradWeighting(Weighting[PSDMatrix]): r""" - :class:`~torchjd.aggregation.Weighting` ``[PSDMatrix]`` giving the weights of - :class:`~torchjd.aggregation.UPGrad`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.UPGrad`. :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. From 402e5ea777eafb7171eb9e1ff5c9207a3f0017ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 19:20:32 +0200 Subject: [PATCH 6/8] refactor(aggregation): Re-add _MatrixWeighting and _GramianWeighting as private bases These private intermediate classes give the `__call__` method a properly named parameter (`matrix` or `gramian`) instead of the generic `stat`, improving documentation for all concrete weighting subclasses. 
Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/_aligned_mtl.py | 4 ++-- src/torchjd/aggregation/_cagrad.py | 4 ++-- src/torchjd/aggregation/_constant.py | 6 ++---- src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_gradvac.py | 4 ++-- src/torchjd/aggregation/_imtl_g.py | 4 ++-- src/torchjd/aggregation/_krum.py | 4 ++-- src/torchjd/aggregation/_mean.py | 6 ++---- src/torchjd/aggregation/_mgda.py | 4 ++-- src/torchjd/aggregation/_nash_mtl.py | 5 ++--- src/torchjd/aggregation/_pcgrad.py | 4 ++-- src/torchjd/aggregation/_random.py | 6 ++---- src/torchjd/aggregation/_sum.py | 6 ++---- src/torchjd/aggregation/_upgrad.py | 4 ++-- src/torchjd/aggregation/_weighting_bases.py | 22 ++++++++++++++++++++- 15 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index ed015897e..61c6d7a9c 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -12,12 +12,12 @@ from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] -class AlignedMTLWeighting(Weighting[PSDMatrix]): +class AlignedMTLWeighting(_GramianWeighting): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. 
diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index f0df44b7b..35a4d1e83 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -3,7 +3,7 @@ from torchjd._linalg import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting check_dependencies_are_installed(["cvxpy", "clarabel"]) @@ -18,7 +18,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGradWeighting(Weighting[PSDMatrix]): +class CAGradWeighting(_GramianWeighting): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.CAGrad`. diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index db1b56a06..54f973a22 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,13 +1,11 @@ from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import Weighting +from ._weighting_bases import _MatrixWeighting -class ConstantWeighting(Weighting[Matrix]): +class ConstantWeighting(_MatrixWeighting): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that returns constant, pre-determined weights. 
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 22172f61c..d0ab8d8c1 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -7,10 +7,10 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class DualProjWeighting(Weighting[PSDMatrix]): +class DualProjWeighting(_GramianWeighting): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.DualProj`. diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index d44f955b6..674d07878 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -10,10 +10,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class GradVacWeighting(Weighting[PSDMatrix], Stateful): +class GradVacWeighting(_GramianWeighting, Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 677a2ac9d..0539c726d 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -5,10 +5,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class IMTLGWeighting(Weighting[PSDMatrix]): +class IMTLGWeighting(_GramianWeighting): """ 
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.IMTLG`. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 512ce88d8..4bdc6f3a2 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -5,10 +5,10 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class KrumWeighting(Weighting[PSDMatrix]): +class KrumWeighting(_GramianWeighting): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.Krum`. diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 50a19e8f0..8f9f54570 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,13 +1,11 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import _MatrixWeighting -class MeanWeighting(Weighting[Matrix]): +class MeanWeighting(_MatrixWeighting): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 87e9771e9..da27a4b47 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -4,10 +4,10 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class MGDAWeighting(Weighting[PSDMatrix]): +class MGDAWeighting(_GramianWeighting): r""" :class:`~torchjd.aggregation.Weighting` 
[:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.MGDA`. diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index fef03834c..63271d631 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,11 +1,10 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. -from torchjd._linalg import Matrix from torchjd.aggregation._mixins import Stateful from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Weighting +from ._weighting_bases import _MatrixWeighting check_dependencies_are_installed(["cvxpy", "ecos"]) @@ -19,7 +18,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class _NashMTLWeighting(Weighting[Matrix], Stateful): +class _NashMTLWeighting(_MatrixWeighting, Stateful): """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 32515c9f6..b2d8a6feb 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -7,10 +7,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class PCGradWeighting(Weighting[PSDMatrix]): +class PCGradWeighting(_GramianWeighting): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.PCGrad`. 
diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 593357172..cb2e651f3 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,13 +2,11 @@ from torch import Tensor from torch.nn import functional as F -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import _MatrixWeighting -class RandomWeighting(Weighting[Matrix]): +class RandomWeighting(_MatrixWeighting): """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that generates positive random weights at each call. diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index a005c52f3..001960c25 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,13 +1,11 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import _MatrixWeighting -class SumWeighting(Weighting[Matrix]): +class SumWeighting(_MatrixWeighting): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. 
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 4bf99c545..2d9e54dcf 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -8,10 +8,10 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import _GramianWeighting -class UPGradWeighting(Weighting[PSDMatrix]): +class UPGradWeighting(_GramianWeighting): r""" :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.UPGrad`. diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 6ca5dddc5..8e62008d2 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,7 +6,7 @@ from torch import Tensor, nn -from torchjd._linalg import PSDTensor, is_psd_tensor +from torchjd._linalg import Matrix, PSDMatrix, PSDTensor, is_psd_tensor _T = TypeVar("_T", contravariant=True, bound=Tensor) _FnInputT = TypeVar("_FnInputT", bound=Tensor) @@ -57,6 +57,26 @@ def forward(self, stat: _T, /) -> Tensor: return self.weighting(self.fn(stat)) +class _MatrixWeighting(Weighting[Matrix]): + def __call__(self, matrix: Tensor, /) -> Tensor: + """ + Computes the vector of weights from the input matrix and applies all registered hooks. + + :param matrix: The matrix from which the weights must be extracted. + """ + return super().__call__(matrix) + + +class _GramianWeighting(Weighting[PSDMatrix]): + def __call__(self, gramian: Tensor, /) -> Tensor: + """ + Computes the vector of weights from the input Gramian and applies all registered hooks. + + :param gramian: The Gramian from which the weights must be extracted. 
+ """ + return super().__call__(gramian) + + class GeneralizedWeighting(nn.Module, ABC): r""" Abstract base class for all weightings that operate on generalized Gramians. It has the role of From b5de89dd4f36192080f02269fa08867a7a346be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 19:30:07 +0200 Subject: [PATCH 7/8] docs(changelog): Adapt entry for MatrixWeighting/GramianWeighting removal and Matrix/PSDMatrix addition Co-Authored-By: Claude Sonnet 4.6 --- CHANGELOG.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6cec74a4..f93e0d2fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,12 @@ changelog does not include internal changes that do not affect the user. ### Added -- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting` - public. These abstract base classes are now importable from `torchjd.aggregation` and documented. - They can be extended to easily implement custom `Weighting`s and `Aggregator`s. +- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are + now importable from `torchjd.aggregation` and documented. They can be extended to easily implement + custom `Aggregator`s. +- Made `Matrix` and `PSDMatrix` public. These type annotation classes are now importable from + `torchjd.linalg` and documented. Users can now subclass `Weighting[Matrix]` or + `Weighting[PSDMatrix]` to implement custom `Weighting`s. - Added getters and setters for the constructor parameters of all aggregators and weightings, so that they can be changed after initialization. 
This includes: `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`; From 72338f6cecc016e6071c3b6e081a743bc216ec96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 8 May 2026 19:35:57 +0200 Subject: [PATCH 8/8] refactor: Import Matrix and PSDMatrix from torchjd.linalg throughout the codebase Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/_aggregator_bases.py | 3 ++- src/torchjd/aggregation/_aligned_mtl.py | 2 +- src/torchjd/aggregation/_cagrad.py | 2 +- src/torchjd/aggregation/_config.py | 2 +- src/torchjd/aggregation/_dualproj.py | 3 ++- src/torchjd/aggregation/_graddrop.py | 2 +- src/torchjd/aggregation/_gradvac.py | 2 +- src/torchjd/aggregation/_imtl_g.py | 2 +- src/torchjd/aggregation/_krum.py | 2 +- src/torchjd/aggregation/_mgda.py | 2 +- src/torchjd/aggregation/_pcgrad.py | 2 +- src/torchjd/aggregation/_upgrad.py | 3 ++- src/torchjd/aggregation/_utils/pref_vector.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 3 ++- src/torchjd/autogram/_engine.py | 3 ++- src/torchjd/autogram/_gramian_accumulator.py | 2 +- src/torchjd/autogram/_gramian_computer.py | 3 ++- src/torchjd/autogram/_jacobian_computer.py | 2 +- src/torchjd/autojac/_jac_to_grad.py | 3 ++- 19 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 8f2be2d19..9dd014ca2 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -2,7 +2,8 @@ from torch import Tensor, nn -from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix +from torchjd._linalg import compute_gramian, is_matrix +from torchjd.linalg import Matrix, PSDMatrix from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 61c6d7a9c..961102631 100644 --- 
a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -7,7 +7,7 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 35a4d1e83..82eb84ca9 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,6 +1,6 @@ from typing import cast -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import _GramianWeighting diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 10fbd9986..98ac08572 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -5,7 +5,7 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix +from torchjd.linalg import Matrix from ._aggregator_bases import Aggregator from ._sum import SumWeighting diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index d0ab8d8c1..5ba3645c4 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,6 +1,7 @@ from torch import Tensor -from torchjd._linalg import PSDMatrix, normalize, regularize +from torchjd._linalg import normalize, regularize +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 81ebf8176..2b4941e21 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix +from torchjd.linalg import Matrix from ._aggregator_bases 
import Aggregator from ._utils.non_differentiable import raise_non_differentiable_error diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 674d07878..e5fcdd4ba 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -5,8 +5,8 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix from torchjd.aggregation._mixins import Stateful +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 0539c726d..88a8a8090 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 4bdc6f3a2..5527ab500 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,7 +2,7 @@ from torch import Tensor from torch.nn import functional as F -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import _GramianWeighting diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index da27a4b47..a013ca839 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import _GramianWeighting diff --git a/src/torchjd/aggregation/_pcgrad.py 
b/src/torchjd/aggregation/_pcgrad.py index b2d8a6feb..3d6e4a6f8 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 2d9e54dcf..e039a4ea6 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix, normalize, regularize +from torchjd._linalg import normalize, regularize +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index caffabd9b..be87c3530 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -1,8 +1,8 @@ from torch import Tensor -from torchjd._linalg import Matrix from torchjd.aggregation._constant import ConstantWeighting from torchjd.aggregation._weighting_bases import Weighting +from torchjd.linalg import Matrix from .str import vector_to_str diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 8e62008d2..00eea54ab 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,7 +6,8 @@ from torch import Tensor, nn -from torchjd._linalg import Matrix, PSDMatrix, PSDTensor, is_psd_tensor +from torchjd._linalg import PSDTensor, is_psd_tensor +from torchjd.linalg import Matrix, PSDMatrix _T = TypeVar("_T", contravariant=True, bound=Tensor) _FnInputT = TypeVar("_FnInputT", bound=Tensor) diff --git 
a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 0a93d2aa1..72112e802 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,7 +4,8 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge -from torchjd._linalg import PSDMatrix, movedim, reshape +from torchjd._linalg import movedim, reshape +from torchjd.linalg import PSDMatrix from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index e9fe81f8d..432b73f0d 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -1,4 +1,4 @@ -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix class GramianAccumulator: diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index cdc7ce939..e44884971 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -4,8 +4,9 @@ from torch import Tensor from torch.utils._pytree import PyTree -from torchjd._linalg import Matrix, PSDMatrix, compute_gramian +from torchjd._linalg import compute_gramian from torchjd.autogram._jacobian_computer import JacobianComputer +from torchjd.linalg import Matrix, PSDMatrix class GramianComputer(ABC): diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index adc90b06a..b77c1ea98 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -8,7 +8,7 @@ from torch.overrides import is_tensor_like from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only -from torchjd._linalg import Matrix +from torchjd.linalg import Matrix # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see diff --git 
a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index bcebcb39d..1b259bd47 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -5,13 +5,14 @@ import torch from torch import Tensor, nn -from torchjd._linalg import Matrix, PSDMatrix, compute_gramian +from torchjd._linalg import compute_gramian from torchjd.aggregation import ( Aggregator, GramianWeightedAggregator, WeightedAggregator, Weighting, ) +from torchjd.linalg import Matrix, PSDMatrix from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac from ._utils import check_consistent_first_dimension