From 220348320f2313e458be25ec02d4c8f352ed1e32 Mon Sep 17 00:00:00 2001 From: David Linteau Date: Fri, 1 May 2026 22:37:39 +0200 Subject: [PATCH 1/3] Fix varying basis construction for JAX 0.10 --- folx/ad.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/folx/ad.py b/folx/ad.py index 6e06fc0..0b39b2c 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -9,6 +9,30 @@ def is_tree_complex(tree): return any(jnp.iscomplexobj(leaf) for leaf in leaves) +def _varying_axes(x: jax.Array) -> tuple: + if not hasattr(jax, 'typeof'): + return () + + typ = jax.typeof(x) + manual_axis_type = getattr(typ, 'manual_axis_type', None) + if manual_axis_type is not None: + return tuple(getattr(manual_axis_type, 'varying', ())) + + return tuple(getattr(typ, 'vma', ())) + + +def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array: + axes = _varying_axes(like) + if not axes: + return x + + if hasattr(jax.lax, 'pcast'): + return jax.lax.pcast(x, axes, to='varying') + if hasattr(jax.lax, 'pvary'): + return jax.lax.pvary(x, axes) + return x + + def vjp_rc(fun, *primals: jax.Array): def real_fun(*primals): return jnp.real(fun(*primals)) @@ -72,8 +96,7 @@ def flat_f(x): out = flat_f(flat_primals) eye = jnp.eye(out.size, dtype=out.dtype) - if hasattr(jax.lax, 'pvary'): - eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma)) + eye = _mark_varying_like(eye, out) result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0] result = jax.vmap(unravel, out_axes=0)(result) if len(primals) == 1: @@ -94,8 +117,7 @@ def jvp_fun(s): return jax.jvp(f, primals, unravel(s))[1] eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype) - if hasattr(jax.lax, 'pvary'): - eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) + eye = _mark_varying_like(eye, flat_primals) J = jax.vmap(jvp_fun, out_axes=-1)(eye) return J From e4860792ef51be1161068a79b66ecd0e23302cbc Mon Sep 17 00:00:00 2001 From: David Linteau Date: Wed, 6 May 2026 07:58:54 +0200 Subject: [PATCH 2/3] Move varying marker helper to utils and reuse it in slogdet --- folx/ad.py | 26 ++------------------------ folx/utils.py | 24 ++++++++++++++++++++++++ folx/wrapped_functions.py | 9 ++++----- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/folx/ad.py b/folx/ad.py index 0b39b2c..2536202 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -3,36 +3,14 @@ import jax.numpy as jnp import jax.tree_util as jtu +from .utils import _mark_varying_like + def is_tree_complex(tree): leaves = jtu.tree_leaves(tree) return any(jnp.iscomplexobj(leaf) for leaf in leaves) -def _varying_axes(x: jax.Array) -> tuple: - if not hasattr(jax, 'typeof'): - return () - - typ = jax.typeof(x) - manual_axis_type = getattr(typ, 'manual_axis_type', None) - if manual_axis_type is not None: - return tuple(getattr(manual_axis_type, 'varying', ())) - - return tuple(getattr(typ, 'vma', ())) - - -def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array: - axes = _varying_axes(like) - if not axes: - return x - - if hasattr(jax.lax, 'pcast'): - return jax.lax.pcast(x, axes, to='varying') - if hasattr(jax.lax, 'pvary'): - return jax.lax.pvary(x, axes) - return x - - def vjp_rc(fun, *primals: jax.Array): def real_fun(*primals): return jnp.real(fun(*primals)) diff --git a/folx/utils.py b/folx/utils.py index 7ed08eb..d446f56 100644 --- a/folx/utils.py +++ b/folx/utils.py @@ -44,6 +44,30 @@ def tree_shapes(tree: PyTree[Array]) -> list[tuple[int, ...]]: return [l.shape for l in leaves] +def _varying_axes(x: jax.Array) -> tuple[int | str, ...]: + if not hasattr(jax, 'typeof'): + return () + + typ = jax.typeof(x) + manual_axis_type = getattr(typ, 'manual_axis_type', None) + if manual_axis_type is not None: + return tuple(getattr(manual_axis_type, 'varying', ())) + + return tuple(getattr(typ, 'vma', ())) + + +def _mark_varying_like(x: jax.Array, like: jax.Array) -> jax.Array: + axes = _varying_axes(like) + if not axes: + return x + + if hasattr(jax.lax, 'pcast'): + return jax.lax.pcast(x, axes, to='varying') + if hasattr(jax.lax, 'pvary'): + return jax.lax.pvary(x, axes) + return x + + def trace_of_product(mat1: Array, mat2: Array): """ Computes the trace of the product of the given matrices. diff --git a/folx/wrapped_functions.py b/folx/wrapped_functions.py index bf6eca9..ff76961 100644 --- a/folx/wrapped_functions.py +++ b/folx/wrapped_functions.py @@ -28,6 +28,7 @@ div_jac_hessian_jac, slogdet_jac_hessian_jac, ) +from .utils import _mark_varying_like from .wrapper import ( warp_without_fwd_laplacian, wrap_elementwise, @@ -226,11 +227,9 @@ def custom_jvp(jacobian, tangent, sign): sign_jvp = jac_dot_tangent log_det_jvp = jac_dot_tangent.real else: - sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype) - if hasattr(jax.lax, 'pvary'): - sign_jvp = jax.lax.pvary( - sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma) - ) + sign_jvp = _mark_varying_like( + jnp.zeros((), dtype=jac_dot_tangent.dtype), jac_dot_tangent + ) log_det_jvp = jac_dot_tangent return (sign_jvp, log_det_jvp) From bd945482809c31d0bcf35d809dfaadd782e27f05 Mon Sep 17 00:00:00 2001 From: David Linteau Date: Wed, 6 May 2026 08:19:22 +0200 Subject: [PATCH 3/3] Fix checks of scalar pytree leaves in tree utils --- folx/tree_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/folx/tree_utils.py b/folx/tree_utils.py index 56367bf..2c45e10 100644 --- a/folx/tree_utils.py +++ b/folx/tree_utils.py @@ -8,12 +8,16 @@ T = TypeVar('T', bound=PyTree[ArrayLike]) +def _is_leaf_pytree(x) -> bool: + return jtu.treedef_is_leaf(jtu.tree_structure(x)) + + def tree_scale(tree: T, x: ArrayLike) -> T: return jtu.tree_map(lambda a: a * x, tree) def tree_mul(tree: T, x: T | ArrayLike) -> T: - if isinstance(x, ArrayLike): + if _is_leaf_pytree(x): return tree_scale(tree, x) return jtu.tree_map(lambda a, b: a * b, tree, x) @@ -23,7 +27,7 @@ def tree_shift(tree1: T, x: ArrayLike) -> T: def tree_add(tree1: T, tree2: T | ArrayLike) -> T: - if isinstance(tree2, ArrayLike): + if _is_leaf_pytree(tree2): return tree_shift(tree1, tree2) return jtu.tree_map(lambda a, b: a + b, tree1, tree2)