diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 2e18986d..1b489398 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -1498,12 +1498,16 @@ def _xval_noraise(self, x): # it gets recompiled as needed for combinations of n, conserve_dc, du, and krange @functools.partial(jax.jit, static_argnames=("n", "conserve_dc", "du", "krange")) def _interp_kval(k, n, conserve_dc, du, krange): - _idata = _lanczos_kval_interp_table( - n, - du, - krange, - conserve_dc, - ) + with jax.ensure_compile_time_eval(): + _idata = _lanczos_kval_interp_table( + n, + # jax-galsim uses a slightly less accurate interpolation + # function (akima vs cubic spline) and so needs a smaller spacing + # 2.3x appears to be ok + du / 2.3, + krange, + conserve_dc, + ) return akima_interp(jnp.abs(k), *_idata, fixed_spacing=True) def _kval_noraise(self, k): diff --git a/tests/GalSim b/tests/GalSim index 89723bcd..bcaad957 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 89723bcd581b5fc5d00acaab0b23fd38df171fdc +Subproject commit bcaad9579dfc67d0542a2f6cbc6e900c8c7a549b