From 336b47cc11d0383ae2986a7a7bece20673d6eeeb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 10:05:51 +0200 Subject: [PATCH 1/5] Add DualConeProjector --- src/torchjd/_linalg/_dual_cone.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/torchjd/_linalg/_dual_cone.py diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py new file mode 100644 index 00000000..4ff1167e --- /dev/null +++ b/src/torchjd/_linalg/_dual_cone.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod + +from torch import Tensor + +from ._matrix import PSDMatrix + + +class DualConeProjector(ABC): + @abstractmethod + def project_weights(U: Tensor, G: PSDMatrix) -> Tensor: + r""" + Computes the weights `w` of the projection of `J^T u` onto the dual cone of + the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that + satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. + + By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic + program: + minimize v^T G v + subject to u \preceq v + + Reference: + [1] `Jacobian Descent For Multi-Objective Optimization `_. + + :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. + :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. + :return: A tensor of projection weights with the same shape as `U`. + """ From f1076a73f22318b89bbe71fe811a0dc5ea245e03 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 10:28:41 +0200 Subject: [PATCH 2/5] Implement and use QPSolverBased. Set as default. --- src/torchjd/_linalg/__init__.py | 4 ++ src/torchjd/_linalg/_dual_cone.py | 44 ++++++++++++- src/torchjd/aggregation/_dualproj.py | 17 +++-- src/torchjd/aggregation/_upgrad.py | 17 +++-- src/torchjd/aggregation/_utils/dual_cone.py | 62 ------------------- tests/unit/aggregation/test_dualproj.py | 12 ++-- tests/unit/aggregation/test_pcgrad.py | 2 +- tests/unit/aggregation/test_upgrad.py | 14 +++-- .../_utils => linalg}/test_dual_cone.py | 34 ++++++---- 9 files changed, 102 insertions(+), 104 deletions(-) delete mode 100644 src/torchjd/aggregation/_utils/dual_cone.py rename tests/unit/{aggregation/_utils => linalg}/test_dual_cone.py (72%) diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index 29b8cd0b..2035b7b2 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -1,3 +1,4 @@ +from ._dual_cone import DualConeProjector, QPSolverBased, projector_or_default from ._generalized_gramian import flatten, movedim, reshape from ._gramian import compute_gramian, normalize, regularize from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor @@ -15,4 +16,7 @@ "flatten", "reshape", "movedim", + "DualConeProjector", + "QPSolverBased", + "projector_or_default", ] diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 4ff1167e..62aabbf3 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod +from typing import Literal, TypeAlias +import numpy as np +import torch +from qpsolvers import solve_qp from torch import Tensor from ._matrix import PSDMatrix @@ -7,7 +11,7 @@ class DualConeProjector(ABC): @abstractmethod - def project_weights(U: Tensor, G: PSDMatrix) -> Tensor: + def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that @@ -25,3 +29,41 @@ def project_weights(U: Tensor, G: PSDMatrix) -> Tensor: :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. :return: A tensor of projection weights with the same shape as `U`. """ + + +def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: + if projector is None: + return QPSolverBased("quadprog") + return projector + + +class QPSolverBased(DualConeProjector): + SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] + + def __init__(self, solver: SUPPORTED_SOLVER) -> None: + self.solver = solver + + def project_weights(self, U: Tensor, G: Tensor) -> Tensor: + + G_ = _to_array(G) + U_ = _to_array(U) + + W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) + + return torch.as_tensor(W, device=G.device, dtype=G.dtype) + + def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: + + m = G.shape[0] + w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver) + + if w is None: # This may happen when G has large values. + raise ValueError("Failed to solve the quadratic programming problem.") + + return w + + +def _to_array(tensor: Tensor) -> np.ndarray: + """Transforms a tensor into a numpy array with float64 dtype.""" + + return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e379f127..3dc33d05 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,11 @@ from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -32,18 +31,18 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) + w = self.projector.project_weights(u, G) return w @property @@ -102,12 +101,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -137,7 +134,7 @@ def reg_eps(self, value: float) -> None: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c1e4807e..29a4e654 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,13 +1,12 @@ import torch from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -33,18 +32,18 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + W = self.projector.project_weights(U, G) return torch.sum(W, dim=0) @property @@ -105,12 +104,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -140,7 +137,7 @@ def reg_eps(self, value: float) -> None: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py deleted file mode 100644 index b076366b..00000000 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Literal, TypeAlias - -import numpy as np -import torch -from qpsolvers import solve_qp -from torch import Tensor - -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - - -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: - """ - Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the - rows of a matrix whose Gramian is provided. - - :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. - :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. - :return: A tensor of projection weights with the same shape as `U`. - """ - - G_ = _to_array(G) - U_ = _to_array(U) - - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) - - return torch.as_tensor(W, device=G.device, dtype=G.dtype) - - -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: - r""" - Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, - given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies - `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. - - By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: - minimize v^T G v - subject to u \preceq v - - Reference: - [1] `Jacobian Descent For Multi-Objective Optimization `_. - - :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to - project. - :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be - symmetric and positive definite. - :param solver: The quadratic programming solver to use. - """ - - m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) - - if w is None: # This may happen when G has large values. - raise ValueError("Failed to solve the quadratic programming problem.") - - return w - - -def _to_array(tensor: Tensor) -> np.ndarray: - """Transforms a tensor into a numpy array with float64 dtype.""" - - return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 34fe8d46..7852fa59 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, DualProj from torchjd.aggregation._dualproj import DualProjWeighting @@ -47,9 +48,12 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") + A = DualProj( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=" + "QPSolverBased('quadprog'))" ) assert str(A) == "DualProj" @@ -57,11 +61,11 @@ def test_representations() -> None: pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d..f7961e8c 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -55,7 +55,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - solver="quadprog", + projector="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 075680a0..c04d32f1 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, UPGrad from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,19 +54,24 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + A = UPGrad( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) + assert ( + repr(A) + == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased('quadprog'))" + ) assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "UPGrad([1., 2., 3.])" diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py similarity index 72% rename from tests/unit/aggregation/_utils/test_dual_cone.py rename to tests/unit/linalg/test_dual_cone.py index 68a8a75d..c39029fa 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -4,11 +4,12 @@ from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd._linalg import DualConeProjector, QPSolverBased +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]) -> None: +def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -34,7 +35,7 @@ def test_solution_weights(shape: tuple[int, int]) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = projector.project_weights(u, G) dual_gap = w - u # Dual feasibility @@ -52,9 +53,12 @@ def test_solution_weights(shape: tuple[int, int]) -> None: assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: +def test_scale_invariant( + projector: DualConeProjector, shape: tuple[int, int], scaling: float +) -> None: """ Tests that `_project_weights` is invariant under scaling. """ @@ -63,14 +67,15 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = projector.project_weights(u, G) + w_scaled = projector.project_weights(u, scaling * G) assert_close(w_scaled, w) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]) -> None: +def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. @@ -82,16 +87,21 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: G = matrix @ matrix.T - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = projector.project_weights(U_tensor, G) + W_matrix = projector.project_weights(U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure() -> None: - """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" +def test_qp_solver_based_failure() -> None: + """ + Tests that `QPSolverBased._project_weight_vector` raises an error when the input G has too large + values. + """ + + projector = QPSolverBased("quadprog") large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") + projector._project_weight_vector(np.ones(10), large_G) From 3343ebb8dec5386b9cf5190dd7dce56fa9c3531d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:16:23 +0200 Subject: [PATCH 3/5] add getters and setters for projectors in UPGrad and DualProj --- src/torchjd/_linalg/_dual_cone.py | 3 +++ src/torchjd/aggregation/_dualproj.py | 18 +++++++++++++++++- src/torchjd/aggregation/_upgrad.py | 18 +++++++++++++++++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 62aabbf3..7cbb673a 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -43,6 +43,9 @@ class QPSolverBased(DualConeProjector): def __init__(self, solver: SUPPORTED_SOLVER) -> None: self.solver = solver + def __repr__(self) -> str: + return f"QPSolverBased({repr(self.solver)})" + def project_weights(self, U: Tensor, G: Tensor) -> Tensor: G_ = _to_array(G) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 3dc33d05..e6117a77 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -76,6 +76,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class DualProj(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -131,10 +139,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 29a4e654..a2d28515 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -79,6 +79,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class UPGrad(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -134,10 +142,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: From 25d0c916e468539b28e38bc06c3943745c2be4cc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:43:45 +0200 Subject: [PATCH 4/5] fix typing --- tests/unit/aggregation/test_pcgrad.py | 4 ++-- tests/unit/linalg/test_dual_cone.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index f7961e8c..819f2be0 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import compute_gramian +from torchjd._linalg import QPSolverBased, compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -55,7 +55,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - projector="quadprog", + projector=QPSolverBased("quadprog"), ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/linalg/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py index c39029fa..6faa25af 100644 --- a/tests/unit/linalg/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -1,10 +1,12 @@ +from typing import cast + import numpy as np import torch from pytest import mark, raises from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd._linalg import DualConeProjector, QPSolverBased +from torchjd._linalg import DualConeProjector, PSDMatrix, QPSolverBased, compute_gramian @mark.parametrize("projector", [QPSolverBased("quadprog")]) @@ -32,7 +34,7 @@ def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) u = rand_(shape[0]) w = projector.project_weights(u, G) @@ -64,11 +66,12 @@ def test_scale_invariant( """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) + scaled_G = cast(PSDMatrix, scaling * G) u = rand_(shape[0]) w = projector.project_weights(u, G) - w_scaled = projector.project_weights(u, scaling * G) + w_scaled = projector.project_weights(u, scaled_G) assert_close(w_scaled, w) @@ -85,7 +88,7 @@ def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ... U_tensor = randn_(shape) U_matrix = U_tensor.reshape([-1, shape[-1]]) - G = matrix @ matrix.T + G = compute_gramian(matrix) W_tensor = projector.project_weights(U_tensor, G) W_matrix = projector.project_weights(U_matrix, G) From c63571d82abe3dd6c963495a1899b9ea0221b3d0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:46:58 +0200 Subject: [PATCH 5/5] Make PCGrad test use default Projector --- tests/unit/aggregation/test_pcgrad.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 819f2be0..ca939116 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import QPSolverBased, compute_gramian +from torchjd._linalg import compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -55,7 +55,6 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - projector=QPSolverBased("quadprog"), ) result = pc_grad_weighting(gramian)