diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index 29b8cd0b3..2035b7b23 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -1,3 +1,4 @@ +from ._dual_cone import DualConeProjector, QPSolverBased, projector_or_default from ._generalized_gramian import flatten, movedim, reshape from ._gramian import compute_gramian, normalize, regularize from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor @@ -15,4 +16,7 @@ "flatten", "reshape", "movedim", + "DualConeProjector", + "QPSolverBased", + "projector_or_default", ] diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py new file mode 100644 index 000000000..7cbb673a3 --- /dev/null +++ b/src/torchjd/_linalg/_dual_cone.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from typing import Literal, TypeAlias + +import numpy as np +import torch +from qpsolvers import solve_qp +from torch import Tensor + +from ._matrix import PSDMatrix + + +class DualConeProjector(ABC): + @abstractmethod + def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor: + r""" + For each vector `u` along the last dimension of `U`, computes the weights `w` of the + projection of `J^T u` onto the dual cone of the rows of `J`, given `G = J J^T`. In other + words, `w` satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. + + By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic + program: + minimize v^T G v + subject to u \preceq v + + Reference: + [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_. + + :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. + :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. + :return: A tensor of projection weights with the same shape as `U`. 
+ """ + + +def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: + if projector is None: + return QPSolverBased("quadprog") + return projector + + +class QPSolverBased(DualConeProjector): + SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] + + def __init__(self, solver: SUPPORTED_SOLVER) -> None: + self.solver = solver + + def __repr__(self) -> str: + return f"QPSolverBased({repr(self.solver)})" + + def project_weights(self, U: Tensor, G: Tensor) -> Tensor: + + G_ = _to_array(G) + U_ = _to_array(U) + + W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) + + return torch.as_tensor(W, device=G.device, dtype=G.dtype) + + def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: + + m = G.shape[0] + w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver) + + if w is None: # This may happen when G has large values. + raise ValueError("Failed to solve the quadratic programming problem.") + + return w + + +def _to_array(tensor: Tensor) -> np.ndarray: + """Transforms a tensor into a numpy array with float64 dtype.""" + + return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e379f1276..e6117a77e 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,11 @@ from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -32,18 +31,18 @@ def __init__( pref_vector: Tensor | None 
= None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) + w = self.projector.project_weights(u, G) return w @property @@ -77,6 +76,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class DualProj(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -102,12 +109,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -134,10 +139,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, 
reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c1e4807e3..a2d285150 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,13 +1,12 @@ import torch from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -33,18 +32,18 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + W = self.projector.project_weights(U, G) return torch.sum(W, dim=0) @property @@ -80,6 +79,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class UPGrad(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -105,12 +112,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, 
reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -137,10 +142,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py deleted file mode 100644 index b076366be..000000000 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Literal, TypeAlias - -import numpy as np -import torch -from qpsolvers import solve_qp -from torch import Tensor - -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - - -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: - """ - Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the - rows of a matrix whose Gramian is provided. - - :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. - :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. - :return: A tensor of projection weights with the same shape as `U`. 
- """ - - G_ = _to_array(G) - U_ = _to_array(U) - - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) - - return torch.as_tensor(W, device=G.device, dtype=G.dtype) - - -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: - r""" - Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, - given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies - `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. - - By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: - minimize v^T G v - subject to u \preceq v - - Reference: - [1] `Jacobian Descent For Multi-Objective Optimization `_. - - :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to - project. - :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be - symmetric and positive definite. - :param solver: The quadratic programming solver to use. - """ - - m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) - - if w is None: # This may happen when G has large values. 
- raise ValueError("Failed to solve the quadratic programming problem.") - - return w - - -def _to_array(tensor: Tensor) -> np.ndarray: - """Transforms a tensor into a numpy array with float64 dtype.""" - - return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 34fe8d462..7852fa593 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, DualProj from torchjd.aggregation._dualproj import DualProjWeighting @@ -47,9 +48,12 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") + A = DualProj( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=" + "QPSolverBased('quadprog'))" ) assert str(A) == "DualProj" @@ -57,11 +61,11 @@ def test_representations() -> None: pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d3..ca9391165 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -55,7 +55,6 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, 
int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - solver="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 075680a02..c04d32f15 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, UPGrad from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,19 +54,24 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + A = UPGrad( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) + assert ( + repr(A) + == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased('quadprog'))" + ) assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "UPGrad([1., 2., 3.])" diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py similarity index 68% rename from tests/unit/aggregation/_utils/test_dual_cone.py rename to tests/unit/linalg/test_dual_cone.py index 68a8a75d7..6faa25af1 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -1,14 +1,17 @@ +from typing import cast + import numpy as np import torch from pytest import mark, raises from torch.testing import assert_close 
from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd._linalg import DualConeProjector, PSDMatrix, QPSolverBased, compute_gramian +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]) -> None: +def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -31,10 +34,10 @@ def test_solution_weights(shape: tuple[int, int]) -> None: """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = projector.project_weights(u, G) dual_gap = w - u # Dual feasibility @@ -52,25 +55,30 @@ def test_solution_weights(shape: tuple[int, int]) -> None: assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: +def test_scale_invariant( + projector: DualConeProjector, shape: tuple[int, int], scaling: float +) -> None: """ Tests that `_project_weights` is invariant under scaling. 
""" J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) + scaled_G = cast(PSDMatrix, scaling * G) u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = projector.project_weights(u, G) + w_scaled = projector.project_weights(u, scaled_G) assert_close(w_scaled, w) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]) -> None: +def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. @@ -80,18 +88,23 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: U_tensor = randn_(shape) U_matrix = U_tensor.reshape([-1, shape[-1]]) - G = matrix @ matrix.T + G = compute_gramian(matrix) - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = projector.project_weights(U_tensor, G) + W_matrix = projector.project_weights(U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure() -> None: - """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" +def test_qp_solver_based_failure() -> None: + """ + Tests that `QPSolverBased._project_weight_vector` raises an error when the input G has too large + values. + """ + + projector = QPSolverBased("quadprog") large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") + projector._project_weight_vector(np.ones(10), large_G)