diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e379f1276..45a619dab 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -43,7 +43,7 @@ def __init__( def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) + w = project_weights[self.solver](u, G) return w @property diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c1e4807e3..febed53c5 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -44,7 +44,7 @@ def __init__( def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + W = project_weights[self.solver](U, G) return torch.sum(W, dim=0) @property diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py index b076366be..7be551c63 100644 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ b/src/torchjd/aggregation/_utils/dual_cone.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from typing import Literal, TypeAlias import numpy as np @@ -8,7 +9,12 @@ SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: +project_weights: dict[str, Callable[[Tensor, Tensor], Tensor]] = { + "quadprog": lambda U, G: project_weights_qp_solvers(U, G, "quadprog") +} + + +def project_weights_qp_solvers(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: """ Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the rows of a matrix whose Gramian is provided. @@ -22,12 +28,16 @@ def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: G_ = _to_array(G) U_ = _to_array(U) - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) + W = np.apply_along_axis( + lambda u: _project_weight_vector_qp_solvers(u, G_, solver), axis=-1, arr=U_ + ) return torch.as_tensor(W, device=G.device, dtype=G.dtype) -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: +def _project_weight_vector_qp_solvers( + u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER +) -> np.ndarray: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/aggregation/_utils/test_dual_cone.py index 68a8a75d7..16febff0b 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/aggregation/_utils/test_dual_cone.py @@ -4,7 +4,7 @@ from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd.aggregation._utils.dual_cone import _project_weight_vector_qp_solvers, project_weights @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) @@ -34,7 +34,7 @@ def test_solution_weights(shape: tuple[int, int]) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = project_weights["quadprog"](u, G) dual_gap = w - u # Dual feasibility @@ -63,8 +63,8 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = project_weights["quadprog"](u, G) + w_scaled = project_weights["quadprog"](u, scaling * G) assert_close(w_scaled, w) @@ -82,8 +82,8 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: G = matrix @ matrix.T - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = project_weights["quadprog"](U_tensor, G) + W_matrix = project_weights["quadprog"](U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) @@ -94,4 +94,4 @@ def test_project_weight_vector_failure() -> None: large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") + _project_weight_vector_qp_solvers(np.ones(10), large_G, "quadprog")