# Patch summary (comment preamble only; the patch text below is unchanged).
#
# NOTE(review): in this copy the patch's line breaks appear to have been
# collapsed into spaces, so indentation and blank context lines cannot be read
# off the text and the hunk line counts cannot be checked here. Presumably the
# original line structure must be restored before applying - TODO confirm.
#
# folx/ad.py
#   - Imports `_mark_varying_like` from `.utils`.
#   - In two hunks, the `if hasattr(jax.lax, 'pvary'):` guard and the
#     `jax.lax.pvary(eye, tuple(jax.typeof(...).vma))` call are removed and
#     replaced by `eye = _mark_varying_like(eye, out)` and
#     `eye = _mark_varying_like(eye, flat_primals)` respectively.
#
# folx/tree_utils.py
#   - Adds `_is_leaf_pytree(x)`, which returns
#     `jtu.treedef_is_leaf(jtu.tree_structure(x))`.
#   - `tree_mul` and `tree_add` now test `_is_leaf_pytree(...)` instead of
#     `isinstance(..., ArrayLike)` before delegating to `tree_scale` /
#     `tree_shift`; otherwise they still fall through to `jtu.tree_map`.
#
# folx/utils.py
#   - Adds `_varying_axes(x)`: returns `()` when `jax` has no `typeof`
#     attribute. Otherwise it reads `jax.typeof(x)` and returns, as a tuple,
#     the `varying` attribute of its `manual_axis_type` when that attribute is
#     present and not None, else its `vma` attribute (both default to empty).
#   - Adds `_mark_varying_like(x, like)`: returns `x` unchanged when
#     `_varying_axes(like)` is empty. Otherwise it calls
#     `jax.lax.pcast(x, axes, to='varying')` if `jax.lax` has `pcast`, else
#     `jax.lax.pvary(x, axes)` if `jax.lax` has `pvary`, else returns `x`.
#     The removed call sites only checked for `pvary`.
#
# folx/wrapped_functions.py
#   - Imports `_mark_varying_like` from `.utils`.
#   - In the hunk headed `def custom_jvp(jacobian, tangent, sign):`, the
#     `else:` branch now builds `sign_jvp` as
#     `_mark_varying_like(jnp.zeros((), dtype=jac_dot_tangent.dtype),
#     jac_dot_tangent)` in place of the guarded `jax.lax.pvary` call.
diff --git a/folx/ad.py b/folx/ad.py index 6e06fc0..2536202 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -3,6 +3,8 @@ import jax.numpy as jnp import jax.tree_util as jtu +from .utils import _mark_varying_like + def is_tree_complex(tree): leaves = jtu.tree_leaves(tree) @@ -72,8 +74,7 @@ def flat_f(x): out = flat_f(flat_primals) eye = jnp.eye(out.size, dtype=out.dtype) - if hasattr(jax.lax, 'pvary'): - eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma)) + eye = _mark_varying_like(eye, out) result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0] result = jax.vmap(unravel, out_axes=0)(result) if len(primals) == 1: @@ -94,8 +95,7 @@ def jvp_fun(s): return jax.jvp(f, primals, unravel(s))[1] eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype) - if hasattr(jax.lax, 'pvary'): - eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) + eye = _mark_varying_like(eye, flat_primals) J = jax.vmap(jvp_fun, out_axes=-1)(eye) return J diff --git a/folx/tree_utils.py b/folx/tree_utils.py index 56367bf..2c45e10 100644 --- a/folx/tree_utils.py +++ b/folx/tree_utils.py @@ -8,12 +8,16 @@ T = TypeVar('T', bound=PyTree[ArrayLike]) +def _is_leaf_pytree(x) -> bool: + return jtu.treedef_is_leaf(jtu.tree_structure(x)) + + def tree_scale(tree: T, x: ArrayLike) -> T: return jtu.tree_map(lambda a: a * x, tree) def tree_mul(tree: T, x: T | ArrayLike) -> T: - if isinstance(x, ArrayLike): + if _is_leaf_pytree(x): return tree_scale(tree, x) return jtu.tree_map(lambda a, b: a * b, tree, x) @@ -23,7 +27,7 @@ def tree_shift(tree1: T, x: ArrayLike) -> T: def tree_add(tree1: T, tree2: T | ArrayLike) -> T: - if isinstance(tree2, ArrayLike): + if _is_leaf_pytree(tree2): return tree_shift(tree1, tree2) return jtu.tree_map(lambda a, b: a + b, tree1, tree2) diff --git a/folx/utils.py b/folx/utils.py index 7ed08eb..d446f56 100644 --- a/folx/utils.py +++ b/folx/utils.py @@ -44,6 +44,30 @@ def tree_shapes(tree: PyTree[Array]) -> list[tuple[int, ...]]: return [l.shape for l in leaves] +def 
_varying_axes(x: jax.Array) -> tuple[int | str, ...]: + if not hasattr(jax, 'typeof'): + return () + + typ = jax.typeof(x) + manual_axis_type = getattr(typ, 'manual_axis_type', None) + if manual_axis_type is not None: + return tuple(getattr(manual_axis_type, 'varying', ())) + + return tuple(getattr(typ, 'vma', ())) + + +def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array: + axes = _varying_axes(like) + if not axes: + return x + + if hasattr(jax.lax, 'pcast'): + return jax.lax.pcast(x, axes, to='varying') + if hasattr(jax.lax, 'pvary'): + return jax.lax.pvary(x, axes) + return x + + def trace_of_product(mat1: Array, mat2: Array): """ Computes the trace of the product of the given matrices. diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py index bf6eca9..ff76961 100644 --- a/folx/wrapped_functions.py +++ b/folx/wrapped_functions.py @@ -28,6 +28,7 @@ div_jac_hessian_jac, slogdet_jac_hessian_jac, ) +from .utils import _mark_varying_like from .wrapper import ( warp_without_fwd_laplacian, wrap_elementwise, @@ -226,11 +227,9 @@ def custom_jvp(jacobian, tangent, sign): sign_jvp = jac_dot_tangent log_det_jvp = jac_dot_tangent.real else: - sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype) - if hasattr(jax.lax, 'pvary'): - sign_jvp = jax.lax.pvary( - sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma) - ) + sign_jvp = _mark_varying_like( + jnp.zeros((), dtype=jac_dot_tangent.dtype), jac_dot_tangent + ) log_det_jvp = jac_dot_tangent return (sign_jvp, log_det_jvp)