diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index aa855c3f..84cf6a0a 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -1,16 +1,12 @@ from collections.abc import Iterable, Sequence -from typing import cast from torch import Tensor -from torch.overrides import is_tensor_like -from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform +from ._transform import AccumulateJac, Jac, OrderedSet, Transform from ._utils import ( as_checked_ordered_set, - check_consistent_first_dimension, - check_matching_jac_shapes, - check_matching_length, check_optional_positive_chunk_size, + create_jac_dict, get_leaf_tensors, ) @@ -120,37 +116,11 @@ def backward( else: inputs_ = OrderedSet(inputs) - jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors) + jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "tensors", "jac_tensors") transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph) transform(jac_tensors_dict) -def _create_jac_tensors_dict( - tensors: OrderedSet[Tensor], - opt_jac_tensors: Sequence[Tensor] | Tensor | None, -) -> dict[Tensor, Tensor]: - """ - Creates a dictionary mapping tensors to their corresponding Jacobians. - - :param tensors: The tensors to differentiate. - :param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to - identity. - """ - if opt_jac_tensors is None: - # Transform that creates gradient outputs containing only ones. - init = Init(tensors) - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(tensors) - return (diag << init)({}) - jac_tensors = cast( - Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors - ) - check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors") - check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors") - check_consistent_first_dimension(jac_tensors, "jac_tensors") - return dict(zip(tensors, jac_tensors, strict=True)) - - def _create_transform( tensors: OrderedSet[Tensor], inputs: OrderedSet[Tensor], diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 00f65c29..911d5b9a 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -5,16 +5,12 @@ from torch.overrides import is_tensor_like from torchjd.autojac._transform._base import Transform -from torchjd.autojac._transform._diagonalize import Diagonalize -from torchjd.autojac._transform._init import Init from torchjd.autojac._transform._jac import Jac from torchjd.autojac._transform._ordered_set import OrderedSet from torchjd.autojac._utils import ( as_checked_ordered_set, - check_consistent_first_dimension, - check_matching_jac_shapes, - check_matching_length, check_optional_positive_chunk_size, + create_jac_dict, ) @@ -159,38 +155,12 @@ def jac( inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs) inputs_ = OrderedSet(inputs_with_repetition) - jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) + jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "outputs", "jac_outputs") transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) result = transform(jac_outputs_dict) return tuple(result[input] for input in inputs_with_repetition) -def _create_jac_outputs_dict( - outputs: OrderedSet[Tensor], - opt_jac_outputs: Sequence[Tensor] | Tensor | None, -) -> dict[Tensor, Tensor]: - """ - Creates a dictionary mapping outputs to their corresponding Jacobians. - - :param outputs: The tensors to differentiate. - :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to - identity. - """ - if opt_jac_outputs is None: - # Transform that creates gradient outputs containing only ones. - init = Init(outputs) - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(outputs) - return (diag << init)({}) - jac_outputs = cast( - Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs - ) - check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs") - check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs") - check_consistent_first_dimension(jac_outputs, "jac_outputs") - return dict(zip(outputs, jac_outputs, strict=True)) - - def _create_transform( outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index fdcc7ce1..0a0c9dfe 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -103,6 +103,41 @@ def check_matching_grad_shapes( ) +def create_jac_dict( + tensors: OrderedSet[Tensor], + opt_jacobians: Sequence[Tensor] | Tensor | None, + tensor_param_name: str, + jacobian_param_name: str, +) -> dict[Tensor, Tensor]: + """ + Creates a dictionary mapping tensors to their corresponding Jacobians. + + If ``opt_jacobians`` is ``None``, creates identity Jacobians using Init and Diagonalize + transforms. Otherwise, validates the provided Jacobians and returns them as a dict. + + :param tensors: The tensors to differentiate. + :param opt_jacobians: The initial Jacobians to backpropagate. If ``None``, defaults to + identity. + :param tensor_param_name: The name of the tensor parameter for error messages. + :param jacobian_param_name: The name of the jacobian parameter for error messages. + """ + from torchjd.autojac._transform._diagonalize import Diagonalize + from torchjd.autojac._transform._init import Init + + if opt_jacobians is None: + init = Init(tensors) + diag = Diagonalize(tensors) + return (diag << init)({}) + + jacobians = cast( + Sequence[Tensor], (opt_jacobians,) if is_tensor_like(opt_jacobians) else opt_jacobians + ) + check_matching_length(jacobians, tensors, jacobian_param_name, tensor_param_name) + check_matching_jac_shapes(jacobians, tensors, jacobian_param_name, tensor_param_name) + check_consistent_first_dimension(jacobians, jacobian_param_name) + return dict(zip(tensors, jacobians, strict=True)) + + def check_consistent_first_dimension( jacobians: Sequence[Tensor], variable_name: str, diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 806eb545..be42ebec 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -4,8 +4,9 @@ from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import backward -from torchjd.autojac._backward import _create_jac_tensors_dict, _create_transform +from torchjd.autojac._backward import _create_transform from torchjd.autojac._transform import OrderedSet +from torchjd.autojac._utils import create_jac_dict @mark.parametrize("default_jac_tensors", [True, False]) @@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_tensors: bool) -> None: None if default_jac_tensors else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] ) - jac_tensors = _create_jac_tensors_dict( + jac_tensors = create_jac_dict( tensors=OrderedSet([y1, y2]), - opt_jac_tensors=optional_jac_tensors, + opt_jacobians=optional_jac_tensors, + tensor_param_name="tensors", + jacobian_param_name="jac_tensors", ) transform = _create_transform( tensors=OrderedSet([y1, y2]), diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 75c68cb9..a13164fb 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -4,8 +4,9 @@ from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import jac -from torchjd.autojac._jac import _create_jac_outputs_dict, _create_transform +from torchjd.autojac._jac import _create_transform from torchjd.autojac._transform import OrderedSet +from torchjd.autojac._utils import create_jac_dict @mark.parametrize("default_jac_outputs", [True, False]) @@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_outputs: bool) -> None: None if default_jac_outputs else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] ) - jac_outputs = _create_jac_outputs_dict( - outputs=OrderedSet([y1, y2]), - opt_jac_outputs=optional_jac_outputs, + jac_outputs = create_jac_dict( + tensors=OrderedSet([y1, y2]), + opt_jacobians=optional_jac_outputs, + tensor_param_name="outputs", + jacobian_param_name="jac_outputs", ) transform = _create_transform( outputs=OrderedSet([y1, y2]),