diff --git a/CHANGELOG.md b/CHANGELOG.md
index 43ccaeb2b..00a906c60 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,9 +10,12 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
-- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting`
-  public. These abstract base classes are now importable from `torchjd.aggregation` and documented.
-  They can be extended to easily implement custom `Weighting`s and `Aggregator`s.
+- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
+  now importable from `torchjd.aggregation` and documented. They can be extended to easily implement
+  custom `Aggregator`s.
+- Made `Matrix` and `PSDMatrix` public. These type annotation classes are now importable from
+  `torchjd.linalg` and documented. Users can now subclass `Weighting[Matrix]` or
+  `Weighting[PSDMatrix]` to implement custom `Weighting`s.
 - Added getters and setters for the constructor parameters of all aggregators and weightings, so
   that they can be changed after initialization. This includes: `pref_vector`, `norm_eps` and
   `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst
index 04a8de666..ff6e18112 100644
--- a/docs/source/docs/aggregation/index.rst
+++ b/docs/source/docs/aggregation/index.rst
@@ -19,12 +19,6 @@ Abstract base classes
 
 .. autoclass:: torchjd.aggregation.Weighting
     :members: __call__
 
-.. autoclass:: torchjd.aggregation.MatrixWeighting
-    :members: __call__
-
-.. autoclass:: torchjd.aggregation.GramianWeighting
-    :members: __call__
-
 .. autoclass:: torchjd.aggregation.GeneralizedWeighting
     :members: __call__
diff --git a/docs/source/docs/linalg/index.rst b/docs/source/docs/linalg/index.rst
new file mode 100644
index 000000000..4446ccea7
--- /dev/null
+++ b/docs/source/docs/linalg/index.rst
@@ -0,0 +1,12 @@
+linalg
+======
+
+.. automodule:: torchjd.linalg
+    :no-members:
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+
+    matrix.rst
+    psd_matrix.rst
diff --git a/docs/source/docs/linalg/matrix.rst b/docs/source/docs/linalg/matrix.rst
new file mode 100644
index 000000000..165f3718a
--- /dev/null
+++ b/docs/source/docs/linalg/matrix.rst
@@ -0,0 +1,6 @@
+:hide-toc:
+
+Matrix
+======
+
+.. autoclass:: torchjd.linalg.Matrix
diff --git a/docs/source/docs/linalg/psd_matrix.rst b/docs/source/docs/linalg/psd_matrix.rst
new file mode 100644
index 000000000..8ee262a23
--- /dev/null
+++ b/docs/source/docs/linalg/psd_matrix.rst
@@ -0,0 +1,6 @@
+:hide-toc:
+
+PSDMatrix
+=========
+
+.. autoclass:: torchjd.linalg.PSDMatrix
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1b2bd970f..d8b14f830 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -70,3 +70,4 @@ TorchJD is open-source, under MIT License. The source code is available on
     docs/autogram/index.rst
     docs/autojac/index.rst
     docs/aggregation/index.rst
+    docs/linalg/index.rst
diff --git a/src/torchjd/_linalg/_matrix.py b/src/torchjd/_linalg/_matrix.py
index a7b5ce614..815a54f87 100644
--- a/src/torchjd/_linalg/_matrix.py
+++ b/src/torchjd/_linalg/_matrix.py
@@ -8,7 +8,16 @@
 
 
 class Matrix(Tensor):
-    """Tensor with exactly 2 dimensions."""
+    """
+    Tensor with exactly 2 dimensions.
+
+    Common examples include the Jacobian matrix ``J`` of shape ``[m, n]``, where ``m`` is the
+    number of objectives and ``n`` is the number of model parameters, and the Gramian of the
+    Jacobian ``G = J J^T`` of shape ``[m, m]``.
+
+    .. note::
+        This class should never be instantiated. It is only used for static type checking.
+    """
 
 
 class PSDTensor(Tensor):
@@ -20,7 +29,15 @@
 
 
 class PSDMatrix(PSDTensor, Matrix):
-    """Positive semi-definite matrix."""
+    """
+    Positive semi-definite matrix.
+
+    A common example is the Gramian of the Jacobian ``G = J J^T`` of shape ``[m, m]``, where ``J``
+    is a Jacobian matrix of shape ``[m, n]``.
+
+    .. note::
+        This class should never be instantiated. It is only used for static type checking.
+    """
 
 
 def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py
index ec871e899..bb8892ff5 100644
--- a/src/torchjd/aggregation/__init__.py
+++ b/src/torchjd/aggregation/__init__.py
@@ -9,8 +9,9 @@
 
 .. note::
     Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights
-    from this Gramian using a :class:`~torchjd.aggregation.GramianWeighting`, and then combining the
-    rows of the Jacobian using these weights. For all of them, we provide both the
+    from this Gramian using a :class:`~torchjd.aggregation.Weighting`
+    [:class:`~torchjd.linalg.PSDMatrix`], and then combining the rows of the Jacobian using these
+    weights. For all of them, we provide both the
     :class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the
     :class:`~torchjd.aggregation.Weighting` interface (to be used in autogram).
     For the rest, we only provide the :class:`~torchjd.aggregation.Aggregator`
@@ -80,7 +81,7 @@
 from ._utils.check_dependencies import (
     OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
 )
-from ._weighting_bases import GeneralizedWeighting, GramianWeighting, MatrixWeighting, Weighting
+from ._weighting_bases import GeneralizedWeighting, Weighting
 
 __all__ = [
     "Aggregator",
@@ -97,12 +98,10 @@
     "GradVac",
     "GradVacWeighting",
     "GramianWeightedAggregator",
-    "GramianWeighting",
     "IMTLG",
     "IMTLGWeighting",
     "Krum",
     "KrumWeighting",
-    "MatrixWeighting",
     "Mean",
     "MeanWeighting",
     "MGDA",
diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py
index 7372b0027..9dd014ca2 100644
--- a/src/torchjd/aggregation/_aggregator_bases.py
+++ b/src/torchjd/aggregation/_aggregator_bases.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import cast
 
 from torch import Tensor, nn
 
-from torchjd._linalg import Matrix, compute_gramian, is_matrix
+from torchjd._linalg import compute_gramian, is_matrix
+from torchjd.linalg import Matrix, PSDMatrix
 
-from ._weighting_bases import GramianWeighting, MatrixWeighting
+from ._weighting_bases import Weighting
 
 
 class Aggregator(nn.Module, ABC):
@@ -48,12 +48,12 @@ def __str__(self) -> str:
 class WeightedAggregator(Aggregator):
     """
     Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a
-    :class:`~torchjd.aggregation.MatrixWeighting` to it.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] to it.
 
     :param weighting: The object responsible for extracting the vector of weights from the matrix.
""" - def __init__(self, weighting: MatrixWeighting) -> None: + def __init__(self, weighting: Weighting[Matrix]) -> None: super().__init__() self.weighting = weighting @@ -76,12 +76,13 @@ def forward(self, matrix: Matrix, /) -> Tensor: class GramianWeightedAggregator(WeightedAggregator): """ :class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input - Jacobian matrix before applying a :class:`~torchjd.aggregation.GramianWeighting` to it. + Jacobian matrix before applying a :class:`~torchjd.aggregation.Weighting` + [:class:`~torchjd.linalg.PSDMatrix`] to it. :param gramian_weighting: The object responsible for extracting the vector of weights from the gramian. """ - def __init__(self, gramian_weighting: GramianWeighting) -> None: - super().__init__(cast(MatrixWeighting, gramian_weighting << compute_gramian)) + def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: + super().__init__(gramian_weighting << compute_gramian) self.gramian_weighting = gramian_weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 5ff9e5ed5..961102631 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -7,20 +7,20 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import GramianWeighting +from ._weighting_bases import _GramianWeighting SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] -class AlignedMTLWeighting(GramianWeighting): +class AlignedMTLWeighting(_GramianWeighting): r""" - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of - :class:`~torchjd.aggregation.AlignedMTL`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index a0b09cb1c..82eb84ca9 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,9 +1,9 @@ from typing import cast -from torchjd._linalg import PSDMatrix +from torchjd.linalg import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import GramianWeighting +from ._weighting_bases import _GramianWeighting check_dependencies_are_installed(["cvxpy", "clarabel"]) @@ -18,10 +18,10 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGradWeighting(GramianWeighting): +class CAGradWeighting(_GramianWeighting): """ - :class:`~torchjd.aggregation.GramianWeighting` giving the weights of - :class:`~torchjd.aggregation.CAGrad`. + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] + giving the weights of :class:`~torchjd.aggregation.CAGrad`. :param c: The scale of the radius of the ball constraint. :param norm_eps: A small value to avoid division by zero when normalizing. 
diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py
index 10fbd9986..98ac08572 100644
--- a/src/torchjd/aggregation/_config.py
+++ b/src/torchjd/aggregation/_config.py
@@ -5,7 +5,7 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd.linalg import Matrix
 
 from ._aggregator_bases import Aggregator
 from ._sum import SumWeighting
diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py
index 91a639b41..54f973a22 100644
--- a/src/torchjd/aggregation/_constant.py
+++ b/src/torchjd/aggregation/_constant.py
@@ -2,13 +2,13 @@
 
 from ._aggregator_bases import WeightedAggregator
 from ._utils.str import vector_to_str
-from ._weighting_bases import MatrixWeighting
+from ._weighting_bases import _MatrixWeighting
 
 
-class ConstantWeighting(MatrixWeighting):
+class ConstantWeighting(_MatrixWeighting):
     """
-    :class:`~torchjd.aggregation.MatrixWeighting` that returns constant, pre-determined
-    weights.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`]
+    that returns constant, pre-determined weights.
 
     :param weights: The weights to return at each call.
     """
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index a9167648c..5ba3645c4 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -1,19 +1,20 @@
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix, normalize, regularize
+from torchjd._linalg import normalize, regularize
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._mean import MeanWeighting
 from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
 from ._utils.non_differentiable import raise_non_differentiable_error
 from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class DualProjWeighting(GramianWeighting):
+class DualProjWeighting(_GramianWeighting):
     r"""
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.DualProj`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.DualProj`.
 
     :param pref_vector: The preference vector to use. If not provided, defaults to
         :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py
index c693b8041..c3354f578 100644
--- a/src/torchjd/aggregation/_graddrop.py
+++ b/src/torchjd/aggregation/_graddrop.py
@@ -3,7 +3,7 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd.linalg import Matrix
 
 from ._aggregator_bases import Aggregator
 from ._utils.non_differentiable import raise_non_differentiable_error
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index b791f6808..e5fcdd4ba 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -5,19 +5,19 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix
 from torchjd.aggregation._mixins import Stateful
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._utils.non_differentiable import raise_non_differentiable_error
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class GradVacWeighting(GramianWeighting, Stateful):
+class GradVacWeighting(_GramianWeighting, Stateful):
     r"""
     :class:`~torchjd.aggregation._mixins.Stateful`
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.GradVac`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.GradVac`.
 
     All required quantities (gradient norms, cosine similarities, and their updates after the
     vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py
index b1dd0cbff..21a7975f5 100644
--- a/src/torchjd/aggregation/_imtl_g.py
+++ b/src/torchjd/aggregation/_imtl_g.py
@@ -1,17 +1,17 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._utils.non_differentiable import raise_non_differentiable_error
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class IMTLGWeighting(GramianWeighting):
+class IMTLGWeighting(_GramianWeighting):
     """
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.IMTLG`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.IMTLG`.
     """
 
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py
index db268e0c7..5527ab500 100644
--- a/src/torchjd/aggregation/_krum.py
+++ b/src/torchjd/aggregation/_krum.py
@@ -2,16 +2,16 @@
 from torch import Tensor
 from torch.nn import functional as F
 
-from torchjd._linalg import PSDMatrix
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class KrumWeighting(GramianWeighting):
+class KrumWeighting(_GramianWeighting):
     """
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.Krum`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.Krum`.
 
     :param n_byzantine: The number of rows of the input matrix that can come from an adversarial
         source.
diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py
index e8b75e7cc..8f9f54570 100644
--- a/src/torchjd/aggregation/_mean.py
+++ b/src/torchjd/aggregation/_mean.py
@@ -2,12 +2,12 @@
 from torch import Tensor
 
 from ._aggregator_bases import WeightedAggregator
-from ._weighting_bases import MatrixWeighting
+from ._weighting_bases import _MatrixWeighting
 
 
-class MeanWeighting(MatrixWeighting):
+class MeanWeighting(_MatrixWeighting):
     r"""
-    :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights
     :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
     """
 
diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py
index ec9d0afc3..a013ca839 100644
--- a/src/torchjd/aggregation/_mgda.py
+++ b/src/torchjd/aggregation/_mgda.py
@@ -1,16 +1,16 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class MGDAWeighting(GramianWeighting):
+class MGDAWeighting(_GramianWeighting):
     r"""
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.MGDA`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.MGDA`.
 
     :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
     :param max_iters: The maximum number of iterations of the optimization loop.
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py
index 5be55afd5..63271d631 100644
--- a/src/torchjd/aggregation/_nash_mtl.py
+++ b/src/torchjd/aggregation/_nash_mtl.py
@@ -4,7 +4,7 @@
 from torchjd.aggregation._mixins import Stateful
 
 from ._utils.check_dependencies import check_dependencies_are_installed
-from ._weighting_bases import MatrixWeighting
+from ._weighting_bases import _MatrixWeighting
 
 check_dependencies_are_installed(["cvxpy", "ecos"])
 
@@ -18,9 +18,10 @@
 from ._utils.non_differentiable import raise_non_differentiable_error
 
 
-class _NashMTLWeighting(MatrixWeighting, Stateful):
+class _NashMTLWeighting(_MatrixWeighting, Stateful):
     """
-    :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.MatrixWeighting` that
+    :class:`~torchjd.aggregation._mixins.Stateful`
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
     extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
     Game <https://arxiv.org/abs/2202.01017>`_.
diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py
index fce7af24d..a796179b7 100644
--- a/src/torchjd/aggregation/_pcgrad.py
+++ b/src/torchjd/aggregation/_pcgrad.py
@@ -3,17 +3,17 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._utils.non_differentiable import raise_non_differentiable_error
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class PCGradWeighting(GramianWeighting):
+class PCGradWeighting(_GramianWeighting):
     """
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.PCGrad`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.PCGrad`.
     """
 
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py
index f1e4010c1..aae7c8c8f 100644
--- a/src/torchjd/aggregation/_random.py
+++ b/src/torchjd/aggregation/_random.py
@@ -3,13 +3,13 @@
 from torch.nn import functional as F
 
 from ._aggregator_bases import WeightedAggregator
-from ._weighting_bases import MatrixWeighting
+from ._weighting_bases import _MatrixWeighting
 
 
-class RandomWeighting(MatrixWeighting):
+class RandomWeighting(_MatrixWeighting):
     """
-    :class:`~torchjd.aggregation.MatrixWeighting` that generates positive random weights
-    at each call.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`]
+    that generates positive random weights at each call.
     """
 
     def forward(self, matrix: Tensor, /) -> Tensor:
diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py
index 1a48ef41f..001960c25 100644
--- a/src/torchjd/aggregation/_sum.py
+++ b/src/torchjd/aggregation/_sum.py
@@ -2,12 +2,12 @@
 from torch import Tensor
 
 from ._aggregator_bases import WeightedAggregator
-from ._weighting_bases import MatrixWeighting
+from ._weighting_bases import _MatrixWeighting
 
 
-class SumWeighting(MatrixWeighting):
+class SumWeighting(_MatrixWeighting):
     r"""
-    :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that gives the weights
     :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
     """
 
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py
index 731986843..e039a4ea6 100644
--- a/src/torchjd/aggregation/_upgrad.py
+++ b/src/torchjd/aggregation/_upgrad.py
@@ -1,20 +1,21 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import PSDMatrix, normalize, regularize
+from torchjd._linalg import normalize, regularize
+from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._mean import MeanWeighting
 from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
 from ._utils.non_differentiable import raise_non_differentiable_error
 from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
-from ._weighting_bases import GramianWeighting
+from ._weighting_bases import _GramianWeighting
 
 
-class UPGradWeighting(GramianWeighting):
+class UPGradWeighting(_GramianWeighting):
     r"""
-    :class:`~torchjd.aggregation.GramianWeighting` giving the weights of
-    :class:`~torchjd.aggregation.UPGrad`.
+    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
+    giving the weights of :class:`~torchjd.aggregation.UPGrad`.
 
     :param pref_vector: The preference vector to use. If not provided, defaults to
         :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py
index caffabd9b..be87c3530 100644
--- a/src/torchjd/aggregation/_utils/pref_vector.py
+++ b/src/torchjd/aggregation/_utils/pref_vector.py
@@ -1,8 +1,8 @@
 from torch import Tensor
 
-from torchjd._linalg import Matrix
 from torchjd.aggregation._constant import ConstantWeighting
 from torchjd.aggregation._weighting_bases import Weighting
+from torchjd.linalg import Matrix
 
 from .str import vector_to_str
 
diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py
index 2655aa28a..00eea54ab 100644
--- a/src/torchjd/aggregation/_weighting_bases.py
+++ b/src/torchjd/aggregation/_weighting_bases.py
@@ -6,7 +6,8 @@
 
 from torch import Tensor, nn
 
-from torchjd._linalg import Matrix, PSDMatrix, PSDTensor, is_psd_tensor
+from torchjd._linalg import PSDTensor, is_psd_tensor
+from torchjd.linalg import Matrix, PSDMatrix
 
 _T = TypeVar("_T", contravariant=True, bound=Tensor)
 _FnInputT = TypeVar("_FnInputT", bound=Tensor)
@@ -34,8 +35,6 @@ def __call__(self, stat: Tensor, /) -> Tensor:
 
         :param stat: The stat from which the weights must be extracted.
         """
-        # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
-        # stat to be Tensor.
         return super().__call__(stat)
 
     def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
@@ -59,6 +58,26 @@ def forward(self, stat: _T, /) -> Tensor:
         return self.weighting(self.fn(stat))
 
 
+class _MatrixWeighting(Weighting[Matrix]):
+    def __call__(self, matrix: Tensor, /) -> Tensor:
+        """
+        Computes the vector of weights from the input matrix and applies all registered hooks.
+
+        :param matrix: The matrix from which the weights must be extracted.
+        """
+        return super().__call__(matrix)
+
+
+class _GramianWeighting(Weighting[PSDMatrix]):
+    def __call__(self, gramian: Tensor, /) -> Tensor:
+        """
+        Computes the vector of weights from the input Gramian and applies all registered hooks.
+
+        :param gramian: The Gramian from which the weights must be extracted.
+        """
+        return super().__call__(gramian)
+
+
 class GeneralizedWeighting(nn.Module, ABC):
     r"""
     Abstract base class for all weightings that operate on generalized Gramians. It has the role of
@@ -84,31 +103,3 @@ def __call__(self, generalized_gramian: Tensor, /) -> Tensor:
         assert is_psd_tensor(generalized_gramian)
 
         return super().__call__(generalized_gramian)
-
-
-# Subclasses used only to redefine the __call__ method with more specific parameter names and
-# docstrings. Note that MatrixWeighting <: Weighting[Matrix] <: Weighting[PSDMatrix], because
-# PSDMatrix <: Matrix and Weighting[_T] is contravariant with _T.
-# Also note that we don't have: MatrixWeighting <: GramianWeighting. GramianWeighting is not
-# just an alias of Weighting[PSDMatrix], it's a subtype of it. So the type Weighting[PSDMatrix]
-# should still be used when we expect a Weighting that works at least on PSD matrices.
-
-
-class MatrixWeighting(Weighting[Matrix]):
-    def __call__(self, matrix: Tensor, /) -> Tensor:
-        """
-        Computes the vector of weights from the input matrix and applies all registered hooks.
-
-        :param matrix: The matrix from which the weights must be extracted.
-        """
-        return super().__call__(matrix)
-
-
-class GramianWeighting(Weighting[PSDMatrix]):
-    def __call__(self, gramian: Tensor, /) -> Tensor:
-        """
-        Computes the vector of weights from the input gramian and applies all registered hooks.
-
-        :param gramian: The gramian from which the weights must be extracted.
-        """
-        return super().__call__(gramian)
diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py
index 0a93d2aa1..72112e802 100644
--- a/src/torchjd/autogram/_engine.py
+++ b/src/torchjd/autogram/_engine.py
@@ -4,7 +4,8 @@
 from torch import Tensor, nn, vmap
 from torch.autograd.graph import get_gradient_edge
 
-from torchjd._linalg import PSDMatrix, movedim, reshape
+from torchjd._linalg import movedim, reshape
+from torchjd.linalg import PSDMatrix
 
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py
index e9fe81f8d..432b73f0d 100644
--- a/src/torchjd/autogram/_gramian_accumulator.py
+++ b/src/torchjd/autogram/_gramian_accumulator.py
@@ -1,4 +1,4 @@
-from torchjd._linalg import PSDMatrix
+from torchjd.linalg import PSDMatrix
 
 
 class GramianAccumulator:
diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py
index cdc7ce939..e44884971 100644
--- a/src/torchjd/autogram/_gramian_computer.py
+++ b/src/torchjd/autogram/_gramian_computer.py
@@ -4,8 +4,9 @@
 from torch import Tensor
 from torch.utils._pytree import PyTree
 
-from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
+from torchjd._linalg import compute_gramian
 from torchjd.autogram._jacobian_computer import JacobianComputer
+from torchjd.linalg import Matrix, PSDMatrix
 
 
 class GramianComputer(ABC):
diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py
index c5d7ad4c6..4caa4455c 100644
--- a/src/torchjd/autogram/_jacobian_computer.py
+++ b/src/torchjd/autogram/_jacobian_computer.py
@@ -8,7 +8,7 @@
 from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
 
-from torchjd._linalg import Matrix
+from torchjd.linalg import Matrix
 
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index bcebcb39d..1b259bd47 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -5,13 +5,14 @@
 import torch
 from torch import Tensor, nn
 
-from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
+from torchjd._linalg import compute_gramian
 from torchjd.aggregation import (
     Aggregator,
     GramianWeightedAggregator,
     WeightedAggregator,
     Weighting,
 )
+from torchjd.linalg import Matrix, PSDMatrix
 
 from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 from ._utils import check_consistent_first_dimension
diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py
new file mode 100644
index 000000000..f8238104e
--- /dev/null
+++ b/src/torchjd/linalg/__init__.py
@@ -0,0 +1,8 @@
+"""
+This module provides type annotation classes representing tensors with specific structural
+properties.
+"""
+
+from torchjd._linalg._matrix import Matrix, PSDMatrix
+
+__all__ = ["Matrix", "PSDMatrix"]
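The changelog entry above says that users can now subclass `Weighting[Matrix]` or `Weighting[PSDMatrix]` to implement custom `Weighting`s, and wrap them in the public `WeightedAggregator` / `GramianWeightedAggregator` base classes. The following is a minimal sketch of that extension path using only the names this patch makes public; the class `InverseNormWeighting` and its weighting rule are hypothetical illustrations, not part of the patch.

```python
import torch
from torch import Tensor

from torchjd.aggregation import GramianWeightedAggregator, Weighting
from torchjd.linalg import PSDMatrix


class InverseNormWeighting(Weighting[PSDMatrix]):
    """Hypothetical custom weighting: weighs each objective by the inverse of its gradient norm."""

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        # The diagonal of the Gramian G = J J^T contains the squared gradient norms of the
        # m objectives, so the full Jacobian is not needed here.
        norms = gramian.diagonal().sqrt().clamp_min(1e-12)  # guard against division by zero
        weights = 1.0 / norms
        return weights / weights.sum()  # normalize the weights to sum to 1


# Wrapping the Weighting[PSDMatrix] in a GramianWeightedAggregator yields an Aggregator:
aggregator = GramianWeightedAggregator(InverseNormWeighting())
jacobian = torch.randn(3, 10)  # m = 3 objectives, n = 10 parameters
aggregation = aggregator(jacobian)  # weighted combination of the rows, of shape [10]
```

As the comment removed from `_weighting_bases.py` explains, `Weighting[_T]` is contravariant in `_T` and `PSDMatrix <: Matrix`, so `Weighting[Matrix] <: Weighting[PSDMatrix]`: a custom `Weighting[Matrix]` can be passed to `WeightedAggregator` directly, and it is also accepted anywhere a `Weighting[PSDMatrix]` is expected.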