Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 3 additions & 33 deletions src/torchjd/autojac/_backward.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from collections.abc import Iterable, Sequence
from typing import cast

from torch import Tensor
from torch.overrides import is_tensor_like

from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
from ._transform import AccumulateJac, Jac, OrderedSet, Transform
from ._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_jac_shapes,
check_matching_length,
check_optional_positive_chunk_size,
create_jac_dict,
get_leaf_tensors,
)

Expand Down Expand Up @@ -120,37 +116,11 @@ def backward(
else:
inputs_ = OrderedSet(inputs)

jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors)
jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "tensors", "jac_tensors")
transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
transform(jac_tensors_dict)


def _create_jac_tensors_dict(
tensors: OrderedSet[Tensor],
opt_jac_tensors: Sequence[Tensor] | Tensor | None,
) -> dict[Tensor, Tensor]:
"""
Creates a dictionary mapping tensors to their corresponding Jacobians.

:param tensors: The tensors to differentiate.
:param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to
identity.
"""
if opt_jac_tensors is None:
# Transform that creates gradient outputs containing only ones.
init = Init(tensors)
# Transform that turns the gradients into Jacobians.
diag = Diagonalize(tensors)
return (diag << init)({})
jac_tensors = cast(
Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors
)
check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
check_consistent_first_dimension(jac_tensors, "jac_tensors")
return dict(zip(tensors, jac_tensors, strict=True))


def _create_transform(
tensors: OrderedSet[Tensor],
inputs: OrderedSet[Tensor],
Expand Down
34 changes: 2 additions & 32 deletions src/torchjd/autojac/_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@
from torch.overrides import is_tensor_like

from torchjd.autojac._transform._base import Transform
from torchjd.autojac._transform._diagonalize import Diagonalize
from torchjd.autojac._transform._init import Init
from torchjd.autojac._transform._jac import Jac
from torchjd.autojac._transform._ordered_set import OrderedSet
from torchjd.autojac._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_jac_shapes,
check_matching_length,
check_optional_positive_chunk_size,
create_jac_dict,
)


Expand Down Expand Up @@ -159,38 +155,12 @@ def jac(
inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
inputs_ = OrderedSet(inputs_with_repetition)

jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "outputs", "jac_outputs")
transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
result = transform(jac_outputs_dict)
return tuple(result[input] for input in inputs_with_repetition)


def _create_jac_outputs_dict(
outputs: OrderedSet[Tensor],
opt_jac_outputs: Sequence[Tensor] | Tensor | None,
) -> dict[Tensor, Tensor]:
"""
Creates a dictionary mapping outputs to their corresponding Jacobians.

:param outputs: The tensors to differentiate.
:param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to
identity.
"""
if opt_jac_outputs is None:
# Transform that creates gradient outputs containing only ones.
init = Init(outputs)
# Transform that turns the gradients into Jacobians.
diag = Diagonalize(outputs)
return (diag << init)({})
jac_outputs = cast(
Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs
)
check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
check_consistent_first_dimension(jac_outputs, "jac_outputs")
return dict(zip(outputs, jac_outputs, strict=True))


def _create_transform(
outputs: OrderedSet[Tensor],
inputs: OrderedSet[Tensor],
Expand Down
35 changes: 35 additions & 0 deletions src/torchjd/autojac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,41 @@ def check_matching_grad_shapes(
)


def create_jac_dict(
tensors: OrderedSet[Tensor],
opt_jacobians: Sequence[Tensor] | Tensor | None,
tensor_param_name: str,
jacobian_param_name: str,
) -> dict[Tensor, Tensor]:
"""
Creates a dictionary mapping tensors to their corresponding Jacobians.

If ``opt_jacobians`` is ``None``, creates identity Jacobians using Init and Diagonalize
transforms. Otherwise, validates the provided Jacobians and returns them as a dict.

:param tensors: The tensors to differentiate.
:param opt_jacobians: The initial Jacobians to backpropagate. If ``None``, defaults to
identity.
:param tensor_param_name: The name of the tensor parameter for error messages.
:param jacobian_param_name: The name of the jacobian parameter for error messages.
"""
from torchjd.autojac._transform._diagonalize import Diagonalize
from torchjd.autojac._transform._init import Init

if opt_jacobians is None:
init = Init(tensors)
diag = Diagonalize(tensors)
return (diag << init)({})

jacobians = cast(
Sequence[Tensor], (opt_jacobians,) if is_tensor_like(opt_jacobians) else opt_jacobians
)
check_matching_length(jacobians, tensors, jacobian_param_name, tensor_param_name)
check_matching_jac_shapes(jacobians, tensors, jacobian_param_name, tensor_param_name)
check_consistent_first_dimension(jacobians, jacobian_param_name)
return dict(zip(tensors, jacobians, strict=True))


def check_consistent_first_dimension(
jacobians: Sequence[Tensor],
variable_name: str,
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/autojac/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from utils.tensors import eye_, randn_, tensor_

from torchjd.autojac import backward
from torchjd.autojac._backward import _create_jac_tensors_dict, _create_transform
from torchjd.autojac._backward import _create_transform
from torchjd.autojac._transform import OrderedSet
from torchjd.autojac._utils import create_jac_dict


@mark.parametrize("default_jac_tensors", [True, False])
Expand All @@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_tensors: bool) -> None:
None if default_jac_tensors else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])]
)

jac_tensors = _create_jac_tensors_dict(
jac_tensors = create_jac_dict(
tensors=OrderedSet([y1, y2]),
opt_jac_tensors=optional_jac_tensors,
opt_jacobians=optional_jac_tensors,
tensor_param_name="tensors",
jacobian_param_name="jac_tensors",
)
transform = _create_transform(
tensors=OrderedSet([y1, y2]),
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/autojac/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from utils.tensors import eye_, randn_, tensor_

from torchjd.autojac import jac
from torchjd.autojac._jac import _create_jac_outputs_dict, _create_transform
from torchjd.autojac._jac import _create_transform
from torchjd.autojac._transform import OrderedSet
from torchjd.autojac._utils import create_jac_dict


@mark.parametrize("default_jac_outputs", [True, False])
Expand All @@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_outputs: bool) -> None:
None if default_jac_outputs else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])]
)

jac_outputs = _create_jac_outputs_dict(
outputs=OrderedSet([y1, y2]),
opt_jac_outputs=optional_jac_outputs,
jac_outputs = create_jac_dict(
tensors=OrderedSet([y1, y2]),
opt_jacobians=optional_jac_outputs,
tensor_param_name="outputs",
jacobian_param_name="jac_outputs",
)
transform = _create_transform(
outputs=OrderedSet([y1, y2]),
Expand Down
Loading