
Could not perform index operation scatter #28

@zhangylch


Hi,

Thank you for implementing the forward Laplacian; it significantly speeds up the Laplacian calculation. However, I get an error when I try to use sparsity, and it looks like sparsity cannot be applied successfully to my function. Here are the details of the issue:
[Image: error details (not preserved)]
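When checking folx results on small inputs, a dense reference Laplacian (the trace of the full Hessian) can be a useful baseline. The sketch below uses plain JAX only; `dense_laplacian` and `example_f` are hypothetical names chosen for illustration and are not part of folx. It materialises the full Hessian, so it is only meant for sanity checks, not as a replacement for the forward Laplacian.

```python
import jax
import jax.numpy as jnp


def dense_laplacian(f):
    """Reference Laplacian of a scalar function f: R^n -> R.

    Computes the trace of the full n x n Hessian, so it is slow and
    memory-hungry for large n; intended only as a baseline for checks.
    """
    def lapl(x):
        h = jax.hessian(f)(x)  # shape (n, n)
        return jnp.trace(h)
    return lapl


def example_f(x):
    # Hypothetical test function: sum of squared norms of 3-D points.
    # Its Hessian is 2 * I, so its Laplacian is 2 * x.size.
    pts = x.reshape(-1, 3)
    return jnp.sum(pts ** 2)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 6))  # batch of 4 inputs, 6 coords each
batched = jax.jit(jax.vmap(dense_laplacian(example_f)))
print(batched(x))  # each entry should be 12.0
```

The same `jax.vmap` over the batch dimension is used in the reproducer, so outputs of the two approaches can be compared element by element on small inputs.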

Here is a minimal example that reproduces the error:

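```python
import time

import folx
import jax
import jax.numpy as jnp


def fwd(x):
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    # Indexed updates via .at[...].set(...)
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    return jnp.sum(x)  # sph does not contribute to the output


key = jax.random.PRNGKey(12)
x = jax.random.normal(key, (100, 300))  # batch of 100 inputs, each 100 points x 3


# Forward Laplacian with sparsity threshold 6, vmapped over the batch and jitted
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(6)(fwd)))
jax.block_until_ready(lapl(x))  # first call triggers compilation
start_time = time.time()
jax.block_until_ready(lapl(x))
end_time = time.time()
print(end_time - start_time)
```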
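If the indexed `.at[...].set(...)` updates are what the sparse path cannot handle (an assumption based only on the "scatter" in the error message), one thing to try is building `sph` without indexed updates. The sketch below compares the original construction with a `jnp.stack` version; `sph_scatter` and `sph_stack` are hypothetical helper names, and this has not been checked against folx itself.

```python
import jax
import jax.numpy as jnp


def sph_scatter(x):
    # Same construction as in the reproducer: two .at[...].set(...) updates.
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    sph = jnp.zeros((2, distances.shape[0]))
    sph = sph.at[0].set(distances)
    sph = sph.at[1].set(distances * 5.0)
    return sph


def sph_stack(x):
    # Same values built with jnp.stack, so no indexed update is traced.
    x = x.reshape(-1, 3)
    distances = jnp.sqrt(jnp.sum(jnp.square(x), axis=1))
    return jnp.stack([distances, distances * 5.0], axis=0)


x = jax.random.normal(jax.random.PRNGKey(12), (300,))  # one input: 100 points x 3
print(jnp.allclose(sph_scatter(x), sph_stack(x)))  # True
print(jax.make_jaxpr(sph_stack)(x))  # inspect the traced primitives
```

Printing the jaxpr of each variant with `jax.make_jaxpr` shows which primitives folx will actually see when it traces the function.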
