-
Notifications
You must be signed in to change notification settings - Fork 16
refactor!: Add DualConeProjector
#678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
336b47c
f1076a7
3343ebb
25d0c91
c63571d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Literal, TypeAlias | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from qpsolvers import solve_qp | ||
| from torch import Tensor | ||
|
|
||
| from ._matrix import PSDMatrix | ||
|
|
||
|
|
||
| class DualConeProjector(ABC): | ||
| @abstractmethod | ||
| def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor: | ||
| r""" | ||
| Computes the weights `w` of the projection of `J^T u` onto the dual cone of | ||
| the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that | ||
| satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. | ||
|
|
||
| By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic | ||
| program: | ||
| minimize v^T G v | ||
| subject to u \preceq v | ||
|
|
||
| Reference: | ||
| [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_. | ||
|
|
||
| :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. | ||
| :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. | ||
| :return: A tensor of projection weights with the same shape as `U`. | ||
| """ | ||
|
|
||
|
|
||
| def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: | ||
| if projector is None: | ||
| return QPSolverBased("quadprog") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think quadprog should be a subclass of QPSolverBased. If we don't do that, we'll be unable to use solver-specific extra parameters. |
||
| return projector | ||
|
|
||
|
|
||
| class QPSolverBased(DualConeProjector): | ||
| SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] | ||
|
|
||
| def __init__(self, solver: SUPPORTED_SOLVER) -> None: | ||
| self.solver = solver | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"QPSolverBased({repr(self.solver)})" | ||
|
|
||
| def project_weights(self, U: Tensor, G: Tensor) -> Tensor: | ||
|
|
||
| G_ = _to_array(G) | ||
| U_ = _to_array(U) | ||
|
|
||
| W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) | ||
|
|
||
| return torch.as_tensor(W, device=G.device, dtype=G.dtype) | ||
|
|
||
| def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: | ||
|
|
||
| m = G.shape[0] | ||
| w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver) | ||
|
|
||
| if w is None: # This may happen when G has large values. | ||
| raise ValueError("Failed to solve the quadratic programming problem.") | ||
|
|
||
| return w | ||
|
|
||
|
|
||
| def _to_array(tensor: Tensor) -> np.ndarray: | ||
| """Transforms a tensor into a numpy array with float64 dtype.""" | ||
|
|
||
| return tensor.cpu().detach().numpy().astype(np.float64) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,11 @@ | ||
| from torch import Tensor | ||
|
|
||
| from torchjd._linalg import normalize, regularize | ||
| from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize | ||
| from torchjd.linalg import PSDMatrix | ||
|
|
||
| from ._aggregator_bases import GramianWeightedAggregator | ||
| from ._mean import MeanWeighting | ||
| from ._mixins import _NonDifferentiable | ||
| from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights | ||
| from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting | ||
| from ._weighting_bases import _GramianWeighting | ||
|
|
||
|
|
@@ -32,18 +31,18 @@ | |
| pref_vector: Tensor | None = None, | ||
| norm_eps: float = 0.0001, | ||
| reg_eps: float = 0.0001, | ||
| solver: SUPPORTED_SOLVER = "quadprog", | ||
| projector: DualConeProjector | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.pref_vector = pref_vector | ||
| self.norm_eps = norm_eps | ||
| self.reg_eps = reg_eps | ||
| self.solver: SUPPORTED_SOLVER = solver | ||
| self.projector = projector_or_default(projector) | ||
|
|
||
| def forward(self, gramian: PSDMatrix, /) -> Tensor: | ||
| u = self.weighting(gramian) | ||
| G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the regularization and normalization should become part of the projector, because the requiered amount of regularization or projection may vary per solver. Norm_eps and reg_eps should thus also be given to the projector directly I think. |
||
| w = project_weights(u, G, self.solver) | ||
| w = self.projector.project_weights(u, G) | ||
| return w | ||
|
|
||
| @property | ||
|
|
@@ -77,6 +76,14 @@ | |
|
|
||
| self._reg_eps = value | ||
|
|
||
| @property | ||
| def projector(self) -> DualConeProjector: | ||
| return self._projector | ||
|
|
||
| @projector.setter | ||
| def projector(self, value: DualConeProjector | None) -> None: | ||
| self._projector = projector_or_default(value) | ||
|
|
||
|
|
||
| class DualProj(_NonDifferentiable, GramianWeightedAggregator): | ||
| r""" | ||
|
|
@@ -102,12 +109,10 @@ | |
| pref_vector: Tensor | None = None, | ||
| norm_eps: float = 0.0001, | ||
| reg_eps: float = 0.0001, | ||
| solver: SUPPORTED_SOLVER = "quadprog", | ||
| projector: DualConeProjector | None = None, | ||
| ) -> None: | ||
| self._solver: SUPPORTED_SOLVER = solver | ||
|
|
||
| super().__init__( | ||
| DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), | ||
| DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), | ||
| ) | ||
|
|
||
| @property | ||
|
|
@@ -134,10 +139,18 @@ | |
| def reg_eps(self, value: float) -> None: | ||
| self.gramian_weighting.reg_eps = value | ||
|
|
||
| @property | ||
| def projector(self) -> DualConeProjector: | ||
| return self.gramian_weighting.projector | ||
|
|
||
| @projector.setter | ||
| def projector(self, value: DualConeProjector | None) -> None: | ||
| self.gramian_weighting.projector = value | ||
|
|
||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" | ||
| f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" | ||
| f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" | ||
| ) | ||
|
|
||
| def __str__(self) -> str: | ||
|
|
||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to
__call__?