From 57df9b9d852ea3f94e790eb61c842ef4d2ed223a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 11 May 2026 17:34:21 +0200 Subject: [PATCH 1/2] refactor(aggregation): Make NashMTL and CAGrad always importable (#678) Add _WithOptionalDeps mixin that raises ImportError at instantiation time when optional dependencies are missing, replacing the module-level guard that previously prevented import altogether. Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/__init__.py | 22 +++++--------------- src/torchjd/aggregation/_cagrad.py | 27 ++++++++++++------------- src/torchjd/aggregation/_mixins.py | 27 +++++++++++++++++++++++++ src/torchjd/aggregation/_nash_mtl.py | 27 ++++++++++++++----------- tests/unit/aggregation/test_cagrad.py | 14 ++++++------- tests/unit/aggregation/test_nash_mtl.py | 14 ++++++------- 6 files changed, 74 insertions(+), 57 deletions(-) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index bb8892ff..0299bfc3 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -62,6 +62,7 @@ from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting +from ._cagrad import CAGrad, CAGradWeighting from ._config import ConFIG from ._constant import Constant, ConstantWeighting from ._dualproj import DualProj, DualProjWeighting @@ -73,20 +74,20 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad, UPGradWeighting -from ._utils.check_dependencies import ( - OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, -) from ._weighting_bases import GeneralizedWeighting, Weighting __all__ = [ "Aggregator", "AlignedMTL", "AlignedMTLWeighting", + "CAGrad", + "CAGradWeighting", "ConFIG", "Constant", "ConstantWeighting", @@ -106,6 +107,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "NashMTL", "PCGrad", "PCGradWeighting", "Random", @@ -119,17 +121,3 @@ "WeightedAggregator", "Weighting", ] - -try: - from ._cagrad import CAGrad, CAGradWeighting - - __all__ += ["CAGrad", "CAGradWeighting"] -except _OptionalDepsNotInstalledError: # The required dependencies are not installed - pass - -try: - from ._nash_mtl import NashMTL - - __all__ += ["NashMTL"] -except _OptionalDepsNotInstalledError: # The required dependencies are not installed - pass diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 05818399..f30cf691 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,25 +1,25 @@ +import contextlib from typing import cast -from torchjd.linalg import PSDMatrix - -from ._mixins import _NonDifferentiable -from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import _GramianWeighting - -check_dependencies_are_installed(["cvxpy", "clarabel"]) - -import cvxpy as cp import numpy as np import torch from torch import Tensor from torchjd._linalg import normalize +from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator +from ._mixins import _NonDifferentiable, _WithOptionalDeps +from ._weighting_bases import _GramianWeighting + +with contextlib.suppress(ImportError): + import cvxpy as cp # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class CAGradWeighting(_NonDifferentiable, _GramianWeighting): +class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting): + _REQUIRED_DEPS = ["cvxpy", "clarabel"] + _INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"' """ :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.CAGrad`. @@ -103,10 +103,9 @@ class CAGrad(_NonDifferentiable, GramianWeightedAggregator): :param norm_eps: A small value to avoid division by zero when normalizing. .. note:: - This aggregator is not installed by default. When not installed, trying to import it should - result in the following error: - ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``. - To install it, use ``pip install "torchjd[cagrad]"``. + This aggregator requires optional dependencies. When they are not installed, instantiating + it raises an :class:`ImportError` with installation instructions. + To install them, use ``pip install "torchjd[cagrad]"``. """ gramian_weighting: CAGradWeighting diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 29bf5592..6856b963 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,10 +1,37 @@ from abc import ABC, abstractmethod +from importlib.util import find_spec from typing import Any import torch from torch import nn +class _WithOptionalDeps: + """ + Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies + are not installed. + + Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via + :func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message). + + .. warning:: + This mixin must appear **first** in the inheritance list so that its :meth:`__init__` + runs before any base class that uses the optional dependencies. + """ + + _REQUIRED_DEPS: list[str] + _INSTALL_HINT: str + + def __init__(self, *args: Any, **kwargs: Any) -> None: + missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None] + if missing: + raise ImportError( + f"{self.__class__.__name__} requires {missing} to be installed. " + f"{self._INSTALL_HINT}" + ) + super().__init__(*args, **kwargs) + + class Stateful(ABC): """Mixin adding a reset method.""" diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 99356fc9..47c87cd4 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,24 +1,28 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. -from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from __future__ import annotations -from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import _MatrixWeighting - -check_dependencies_are_installed(["cvxpy", "ecos"]) +import contextlib -import cvxpy as cp import numpy as np import torch -from cvxpy import Expression, SolverError from torch import Tensor +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps + from ._aggregator_bases import WeightedAggregator +from ._weighting_bases import _MatrixWeighting + +with contextlib.suppress(ImportError): + import cvxpy as cp + from cvxpy import Expression, SolverError # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. -class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting): +class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting): + _REQUIRED_DEPS = ["cvxpy", "ecos"] + _INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"' """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that @@ -215,10 +219,9 @@ class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator): :param optim_niter: The number of iterations of the underlying optimization process. .. note:: - This aggregator is not installed by default. When not installed, trying to import it should - result in the following error: - ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``. - To install it, use ``pip install "torchjd[nash_mtl]"``. + This aggregator requires optional dependencies. When they are not installed, instantiating + it raises an :class:`ImportError` with installation instructions. + To install them, use ``pip install "torchjd[nash_mtl]"``. .. warning:: This implementation was adapted from the `official implementation diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index 9128899f..56f96f33 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -1,3 +1,8 @@ +import pytest + +pytest.importorskip("cvxpy") +pytest.importorskip("clarabel") + from contextlib import nullcontext as does_not_raise from pytest import mark, raises @@ -5,13 +10,8 @@ from utils.contexts import ExceptionContext from utils.tensors import ones_ -try: - from torchjd.aggregation import CAGrad - from torchjd.aggregation._cagrad import CAGradWeighting -except ImportError: - import pytest - - pytest.skip("CAGrad dependencies not installed", allow_module_level=True) +from torchjd.aggregation import CAGrad +from torchjd.aggregation._cagrad import CAGradWeighting from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 6dac1f0e..79e5b3a6 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,15 +1,15 @@ +import pytest + +pytest.importorskip("cvxpy") +pytest.importorskip("ecos") + from pytest import mark, raises from torch import Tensor from torch.testing import assert_close from utils.tensors import ones_, randn_, tensor_ -try: - from torchjd.aggregation import NashMTL - from torchjd.aggregation._nash_mtl import _NashMTLWeighting -except ImportError: - import pytest - - pytest.skip("NashMTL dependencies not installed", allow_module_level=True) +from torchjd.aggregation import NashMTL +from torchjd.aggregation._nash_mtl import _NashMTLWeighting from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import nash_mtl_matrices From fb9d6c82a77adfa88e69370e3cb6fd46df255fa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 11 May 2026 17:36:04 +0200 Subject: [PATCH 2/2] chore: Update changelog for optional deps discoverability change Co-Authored-By: Claude Sonnet 4.6 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a8aedb5..ebeeb67d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Changed +- `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`, + even when their optional dependencies are not installed. Attempting to instantiate them without the + required dependencies now raises an `ImportError` with installation instructions, instead of + raising an `ImportError` at import time. - Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG, GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so attempting to