Skip to content

Fix slow JIT compilation in e2tomogram_refine using jax.lax.scan#583

Open
amineuron wants to merge 1 commit intocryoem:masterfrom
amineuron:master
Open

Fix slow JIT compilation in e2tomogram_refine using jax.lax.scan#583
amineuron wants to merge 1 commit intocryoem:masterfrom
amineuron:master

Conversation

@amineuron
Copy link

GPU compilation of e2tomogram_refine took 5+ minutes due to Python for-loops inside @jit being unrolled by XLA (9 tiles x 38 tilts = 342 inlined copies). Replaced with jax.lax.scan in e2tomogram_refine_jax.py, reducing compilation from 5+ minutes to a few seconds. Added persistent JIT cache to both versions

GPU compilation of e2tomogram_refine took 5+ minutes due to Python for-loops
inside @jit being unrolled by XLA (9 tiles x 38 tilts = 342 inlined copies).
Replaced with jax.lax.scan in e2tomogram_refine_jax.py, reducing compilation
from 5+ minutes to a few seconds. Added persistent JIT cache to both versions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant