Skip to content

forward_laplacian inside shard_map fails the VMA (varying manual axes) check when the function uses integer_pow with exponent 1 #38

@inailuig

Description

@inailuig

I recently came across this somewhat obscure error, which is caused by a missing pvary.
The missing pvary is the one I mentioned at the end of #36 (comment), but there I had a different reason in mind for why it would be needed.

Reproducer:

import jax
import jax.numpy as jnp
from folx import forward_laplacian
from functools import partial

def f(w, x):
    return jax.lax.integer_pow(x @ w, 1)

@jax.smap(out_axes=0,in_axes=(None, 0), axis_name='i')
@partial(jax.vmap, in_axes=(None, 0))
def test(w, x):
    return forward_laplacian(partial(f, w))(x)

x = jnp.ones((1,16))
w = jnp.ones((16,16))

with jax.set_mesh(jax.sharding.Mesh(jax.devices(), 'i')):
    test(w,x)

Output:

ERROR:[folx](/private/tmp/bug.py:7:11 (f)) - Error in operation integer_pow.
Traceback (most recent call last):
  File "/private/tmp/bug.py", line 18, in <module>
    test(w,x)
  File "/private/tmp/bug.py", line 12, in test
    return forward_laplacian(partial(f, w))(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 309, in wrapped
    out = eval_jaxpr_with_forward_laplacian(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 226, in eval_jaxpr_with_forward_laplacian
    raise e
  File "/Users/clemens/folx/folx/interpreter.py", line 222, in eval_jaxpr_with_forward_laplacian
    outvals = eval_laplacian(eqn, invals)
^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 183, in eval_laplacian
    return fn(
           ^^^
  File "/Users/clemens/folx/folx/wrapper.py", line 124, in new_fn
    lapl_y, lapl_fns.jac_hessian_jac_trace(laplace_args, sparsity_threshold)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 394, in hessian_transform
    return vmapped_jac_hessian_jac(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 367, in vmapped_jac_hessian_jac
    result = hess_transform(lapl_args, extra_args, out_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 348, in hess_transform
    result = general_jac_hessian_jac(merged_fn, args, out_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 99, in general_jac_hessian_jac
    flat_out = JHJ_via_hessian(flat_fn, flat_x, grad_2d)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 41, in JHJ_via_hessian
    flat_hessian = hessian(flat_fn)(flat_x)
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 97, in jacfun
    J = jax.vmap(jvp_fun, out_axes=-1)(eye)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 94, in jvp_fun
    return jax.jvp(f, primals, unravel(s))[1]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 77, in jacfun
    result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 45, in vjp
    out, vjp = jax.vjp(fun, *primals)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 70, in flat_f
    return jfu.ravel_pytree(f(*unravel(x)))[0]
                            ^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/utils.py", line 99, in new_fn
    return jfu.ravel_pytree(fn(*x))[0]  # type: ignore
                            ^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 334, in merged_fn
    return fwd(*merge(x, extra_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Primitive mul requires varying manual axes to match, but got [frozenset({'i'}), frozenset()]. Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_vma=False argument to `jax.shard_map`

PR with fix is on its way.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions