4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user.

### Changed

- `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`,
  even when their optional dependencies are not installed. Instantiating them without the required
  dependencies now raises an `ImportError` with installation instructions; previously, the import
  itself raised the error (illustrated below).
- Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG,
  GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors
  that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so attempting to
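A minimal sketch of the new behavior described above, assuming the optional dependencies are *not* installed (the constructor argument is illustrative, not taken from this diff):

```python
from torchjd.aggregation import CAGrad  # now always succeeds

try:
    aggregator = CAGrad(c=0.5)  # hypothetical argument; the dependency check runs here
except ImportError as e:
    print(e)  # hint includes: pip install "torchjd[cagrad]"
```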
22 changes: 5 additions & 17 deletions src/torchjd/aggregation/__init__.py
@@ -62,6 +62,7 @@

from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator
from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
from ._cagrad import CAGrad, CAGradWeighting
from ._config import ConFIG
from ._constant import Constant, ConstantWeighting
from ._dualproj import DualProj, DualProjWeighting
@@ -73,20 +74,20 @@
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
from ._trimmed_mean import TrimmedMean
from ._upgrad import UPGrad, UPGradWeighting
from ._utils.check_dependencies import (
    OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
)
from ._weighting_bases import GeneralizedWeighting, Weighting

__all__ = [
"Aggregator",
"AlignedMTL",
"AlignedMTLWeighting",
"CAGrad",
"CAGradWeighting",
"ConFIG",
"Constant",
"ConstantWeighting",
@@ -106,6 +107,7 @@
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"NashMTL",
"PCGrad",
"PCGradWeighting",
"Random",
@@ -119,17 +121,3 @@
"WeightedAggregator",
"Weighting",
]

try:
    from ._cagrad import CAGrad, CAGradWeighting

    __all__ += ["CAGrad", "CAGradWeighting"]
except _OptionalDepsNotInstalledError:  # The required dependencies are not installed
    pass

try:
    from ._nash_mtl import NashMTL

    __all__ += ["NashMTL"]
except _OptionalDepsNotInstalledError:  # The required dependencies are not installed
    pass
27 changes: 13 additions & 14 deletions src/torchjd/aggregation/_cagrad.py
@@ -1,25 +1,25 @@
import contextlib
from typing import cast

from torchjd.linalg import PSDMatrix

from ._mixins import _NonDifferentiable
from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _GramianWeighting

check_dependencies_are_installed(["cvxpy", "clarabel"])

import cvxpy as cp
import numpy as np
import torch
from torch import Tensor

from torchjd._linalg import normalize
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mixins import _NonDifferentiable, _WithOptionalDeps
from ._weighting_bases import _GramianWeighting

with contextlib.suppress(ImportError):
    import cvxpy as cp


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class CAGradWeighting(_NonDifferentiable, _GramianWeighting):
class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
_REQUIRED_DEPS = ["cvxpy", "clarabel"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.CAGrad`.
@@ -103,10 +103,9 @@ class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
    :param norm_eps: A small value to avoid division by zero when normalizing.

    .. note::
        This aggregator is not installed by default. When not installed, trying to import it should
        result in the following error:
        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
        To install it, use ``pip install "torchjd[cagrad]"``.
        This aggregator requires optional dependencies. When they are not installed, instantiating
        it raises an :class:`ImportError` with installation instructions.
        To install them, use ``pip install "torchjd[cagrad]"``.
    """

    gramian_weighting: CAGradWeighting
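The guarded import above is what keeps `_cagrad.py` importable without `cvxpy`; a minimal standalone sketch of the pattern:

```python
import contextlib

# The module imports cleanly even when cvxpy is absent; in that case `cp` is
# simply never bound, and any code touching it must run behind the
# _WithOptionalDeps check.
with contextlib.suppress(ImportError):
    import cvxpy as cp
```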
27 changes: 27 additions & 0 deletions src/torchjd/aggregation/_mixins.py
@@ -1,10 +1,37 @@
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any

import torch
from torch import nn


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
are not installed.

Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).

.. warning::
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
runs before any base class that uses the optional dependencies.
"""

_REQUIRED_DEPS: list[str]
_INSTALL_HINT: str

def __init__(self, *args: Any, **kwargs: Any) -> None:
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
if missing:
raise ImportError(
f"{self.__class__.__name__} requires {missing} to be installed. "
f"{self._INSTALL_HINT}"
)
super().__init__(*args, **kwargs)


class Stateful(ABC):
"""Mixin adding a reset method."""

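A hypothetical subclass showing how `_WithOptionalDeps` is meant to be used (the class name and dependency are made up for illustration):

```python
from torchjd.aggregation._mixins import _WithOptionalDeps  # private API


class _NeedsScipy(_WithOptionalDeps):
    _REQUIRED_DEPS = ["scipy"]
    _INSTALL_HINT = "Install it with: pip install scipy"


_NeedsScipy()  # raises ImportError with the hint if scipy is not installed
```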
27 changes: 15 additions & 12 deletions src/torchjd/aggregation/_nash_mtl.py
@@ -1,24 +1,28 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from __future__ import annotations

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _MatrixWeighting

check_dependencies_are_installed(["cvxpy", "ecos"])
import contextlib

import cvxpy as cp
import numpy as np
import torch
from cvxpy import Expression, SolverError
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import _MatrixWeighting

with contextlib.suppress(ImportError):
    import cvxpy as cp
    from cvxpy import Expression, SolverError


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting):
class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
_REQUIRED_DEPS = ["cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
@@ -215,10 +219,9 @@ class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
    :param optim_niter: The number of iterations of the underlying optimization process.

    .. note::
        This aggregator is not installed by default. When not installed, trying to import it should
        result in the following error:
        ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
        To install it, use ``pip install "torchjd[nash_mtl]"``.
        This aggregator requires optional dependencies. When they are not installed, instantiating
        it raises an :class:`ImportError` with installation instructions.
        To install them, use ``pip install "torchjd[nash_mtl]"``.

    .. warning::
        This implementation was adapted from the `official implementation
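The base-class order in `_NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting)` follows the mixin's warning: Python calls cooperative `__init__` methods left to right along the MRO, so the dependency check fires before any other initializer runs. A toy sketch (names hypothetical):

```python
class _UsesCvxpy:
    def __init__(self) -> None:
        # Reached only after _WithOptionalDeps has confirmed cvxpy is installed.
        pass


class _Guarded(_WithOptionalDeps, _UsesCvxpy):
    _REQUIRED_DEPS = ["cvxpy"]
    _INSTALL_HINT = 'Install it with: pip install "torchjd[nash_mtl]"'
```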
14 changes: 7 additions & 7 deletions tests/unit/aggregation/test_cagrad.py
@@ -1,17 +1,17 @@
import pytest

pytest.importorskip("cvxpy")
pytest.importorskip("clarabel")

from contextlib import nullcontext as does_not_raise

from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

try:
    from torchjd.aggregation import CAGrad
    from torchjd.aggregation._cagrad import CAGradWeighting
except ImportError:
    import pytest

    pytest.skip("CAGrad dependencies not installed", allow_module_level=True)
from torchjd.aggregation import CAGrad
from torchjd.aggregation._cagrad import CAGradWeighting

from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
from ._inputs import scaled_matrices, typical_matrices
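As a side note (not part of this PR), `pytest.importorskip` also returns the imported module, so a test file that needs the dependency directly could bind it in the same guard:

```python
import pytest

cp = pytest.importorskip("cvxpy")  # skips the whole test module when cvxpy is absent
```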
14 changes: 7 additions & 7 deletions tests/unit/aggregation/test_nash_mtl.py
@@ -1,15 +1,15 @@
import pytest

pytest.importorskip("cvxpy")
pytest.importorskip("ecos")

from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import ones_, randn_, tensor_

try:
    from torchjd.aggregation import NashMTL
    from torchjd.aggregation._nash_mtl import _NashMTLWeighting
except ImportError:
    import pytest

    pytest.skip("NashMTL dependencies not installed", allow_module_level=True)
from torchjd.aggregation import NashMTL
from torchjd.aggregation._nash_mtl import _NashMTLWeighting

from ._asserts import assert_expected_structure, assert_non_differentiable
from ._inputs import nash_mtl_matrices