TypeError when using jnp.linalg.slogdet with forward_laplacian in a distributed/sharded JAX environment #37

@zhangylch

Description

Thank you for implementing the forward Laplacian; it significantly accelerates Laplacian calculations. I encountered a TypeError when using folx to compute the Laplacian of a function involving jnp.linalg.slogdet. The error only occurs in a distributed/sharded environment (using jax.sharding.Mesh).

The code works perfectly fine in a single-device setting. It also works fine in a distributed setting if I replace jnp.linalg.slogdet with other operations like jnp.sum.

Minimal Reproduction Script
Here is a minimal script to reproduce the issue. It sets up a JAX Mesh and shards the input data.

import os

import jax
import jax.numpy as jnp
from folx import forward_laplacian
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

#jax.config.update("jax_enable_x64", True)

def test_parallel():
    print("=== Running Distributed (Mesh/Sharding) Test ===")

    # Multi-process initialization is only attempted when launched under SLURM.
    if 'SLURM_PROCID' in os.environ:
        try:
            jax.distributed.initialize()
        except Exception:
            pass

    # 1D mesh over all devices; the batch axis is sharded along 'data'.
    mesh = Mesh(jax.devices(), axis_names=('data',))

    N = 10
    Batch = 8

    # Batch of scaled identity matrices, shape (Batch, N, N).
    x_host = jnp.stack([jnp.eye(N) * (2.0 + i*0.1) for i in range(Batch)])

    sharding = NamedSharding(mesh, P('data'))
    x_sharded = jax.device_put(x_host, sharding)

    def single_sample_f(x):
        sign, logdet = jnp.linalg.slogdet(x)
        return logdet

    fwd_op = forward_laplacian(single_sample_f)
    fwd_f_vmap = jax.vmap(fwd_op)

    # This call fails with the TypeError reported in this issue.
    result = fwd_f_vmap(x_sharded)


if __name__ == "__main__":
    test_parallel()

Error Traceback
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal float32[] with tangent float32[], expecting tangent float32[]
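The message points at a custom JVP rule, and as far as I can tell jnp.linalg.slogdet is implemented with one inside JAX. One possible way to narrow this down, offered as a diagnostic sketch and not a confirmed diagnosis, is to push a tangent through slogdet with plain jax.jvp on the same sharded input, without folx. If this runs cleanly on the same mesh, the failure is more likely in the interaction between forward_laplacian and the sharded input than in slogdet itself. The helper name plain_jvp and the all-ones tangent are my own choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P


def single_sample_f(x):
    sign, logdet = jnp.linalg.slogdet(x)
    return logdet


def plain_jvp(x, t):
    # Forward-mode derivative of logdet at x in direction t, no folx involved.
    return jax.jvp(single_sample_f, (x,), (t,))


# Same mesh and sharding as in the reproduction script.
mesh = Mesh(jax.devices(), axis_names=('data',))
sharding = NamedSharding(mesh, P('data'))

# batch must be divisible by the number of devices in the mesh.
N, batch = 10, 8
x_host = jnp.stack([jnp.eye(N) * (2.0 + i * 0.1) for i in range(batch)])
x_sharded = jax.device_put(x_host, sharding)
t_sharded = jax.device_put(jnp.ones_like(x_host), sharding)

# Confirm the input really carries the intended sharding.
print(x_sharded.sharding)

primal_out, tangent_out = jax.vmap(plain_jvp)(x_sharded, t_sharded)
print(primal_out.shape, tangent_out.shape)
```

For a scaled identity c*I and an all-ones tangent, the tangent output should equal trace(X^-1 T) = N / c, which gives a quick sanity check on the values.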

Environment
JAX 0.7.0 + latest folx
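For cross-checking values once the sharded path works, here is a small reference that does not depend on folx. It is only a sketch: it computes the Laplacian as the trace of the dense Hessian via jax.hessian, which is far slower than a forward Laplacian and is meant for small test inputs only. The name reference_laplacian is hypothetical and not part of folx or JAX.

```python
import jax
import jax.numpy as jnp


def single_sample_f(x):
    sign, logdet = jnp.linalg.slogdet(x)
    return logdet


def reference_laplacian(f):
    """Laplacian of a scalar function as the trace of its dense Hessian.

    Slow reference for checking results; not a replacement for folx.
    """
    def lap(x):
        # Flatten the input so the Hessian is a square (n, n) matrix.
        flat_f = lambda v: f(v.reshape(x.shape))
        hess = jax.hessian(flat_f)(x.reshape(-1))
        return jnp.trace(hess)
    return lap


# Small batch of scaled identity matrices, same construction as the script.
N, batch = 4, 3
scales = [2.0 + i * 0.1 for i in range(batch)]
x = jnp.stack([jnp.eye(N) * c for c in scales])

ref = jax.vmap(reference_laplacian(single_sample_f))(x)

# Closed form for X = c * I: the Laplacian of log|det X| over all entries
# is -||X^{-1}||_F^2 = -N / c**2.
expected = jnp.array([-N / c**2 for c in scales])
print(ref, expected)
```

The closed-form comparison works because the inputs in the reproduction script are scaled identities; for general matrices only the Hessian-trace value is available from this sketch.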
