From 3d11630787dac29abd91c4d45c8dd58c8254831e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:57:52 -0500 Subject: [PATCH 01/32] fix: raise errors for invalid shears and PixelScale WCS inits --- jax_galsim/shear.py | 46 +++++++++++++++++++++++++++++++++++++-------- pyproject.toml | 2 +- tests/GalSim | 2 +- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 074a762e..55b639a5 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError @@ -45,15 +46,24 @@ def __init__(self, *args, **kwargs): # g1,g2 elif "g1" in kwargs or "g2" in kwargs: - g1 = kwargs.pop("g1", 0.0) - g2 = kwargs.pop("g2", 0.0) + g1 = jnp.array(kwargs.pop("g1", 0.0)) + g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 + self._g = equinox.error_if( + self._g, jnp.abs(self._g) > 1., + "Requested shear exceeds 1.", + ) # e1,e2 elif "e1" in kwargs or "e2" in kwargs: - e1 = kwargs.pop("e1", 0.0) - e2 = kwargs.pop("e2", 0.0) + e1 = jnp.array(kwargs.pop("e1", 0.0)) + e2 = jnp.array(kwargs.pop("e2", 0.0)) absesq = e1**2 + e2**2 + absesq = equinox.error_if( + absesq, + absesq > 1., + "Requested distortion exceeds 1.", + ) self._g = (e1 + 1j * e2) * self._e2g(absesq) # eta1,eta2 @@ -75,7 +85,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - g = kwargs.pop("g") + g = jnp.array(kwargs.pop("g")) + g = equinox.error_if( + g, + g > 1 or g < 0, + "Requested |shear| is outside [0,1].", + ) self._g = g * jnp.exp(2j * beta.rad) # e,beta @@ -89,7 +104,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - e = kwargs.pop("e") + e = jnp.array(kwargs.pop("e")) + e = equinox.error_if( + e, + (e > 1) | (e < 0), + "Requested distortion is outside [0,1].", + ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) # eta,beta @@ -103,7 +123,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - eta = kwargs.pop("eta") + eta = jnp.array(kwargs.pop("eta")) + eta = equinox.error_if( + eta, + eta < 0, + "Requested eta is below 0.", + ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) # q,beta @@ -117,7 +142,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - q = kwargs.pop("q") + q = jnp.array(kwargs.pop("q")) + q = equinox.error_if( + q, + (q <= 0) | (q > 1), + "Cannot use requested axis ratio.", + ) eta = -jnp.log(q) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) diff --git a/pyproject.toml b/pyproject.toml index 2ad2d9aa..f8ca4441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "The modular galaxy image simulation toolkit, but in JAX" dynamic = ["version"] license = { file = "LICENSE" } readme = "README.md" -dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax", "equinox"] [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] diff --git a/tests/GalSim b/tests/GalSim index a5afbf51..11c473b4 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit a5afbf510dc747f5667f61c742b9dd3630643988 +Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 From bd0e282c71dd6de8bfc3e502bf0477099c83224e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:58:40 -0500 Subject: [PATCH 02/32] please the dog --- jax_galsim/shear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 55b639a5..92d57693 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -50,7 +50,8 @@ def __init__(self, *args, **kwargs): g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 self._g = equinox.error_if( - self._g, jnp.abs(self._g) > 1., + self._g, + jnp.abs(self._g) > 1.0, "Requested shear exceeds 1.", ) @@ -61,7 +62,7 @@ def __init__(self, *args, **kwargs): absesq = e1**2 + e2**2 absesq = equinox.error_if( absesq, - absesq > 1., + absesq > 1.0, "Requested distortion exceeds 1.", ) self._g = (e1 + 1j * e2) * self._e2g(absesq) From a3b7ba4c68c4e806cfc95c5420e3af844de1642a Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:31:39 -0500 Subject: [PATCH 03/32] fix: mock up equinox --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 11c473b4..062c9ed0 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 +Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac From 5a43c922a80dd1d234e44c9e55055e1a7262991b Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:35:42 -0500 Subject: [PATCH 04/32] test: more array equals --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 062c9ed0..e5ee4016 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac +Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94 From ff189007dfc0e6bd0a4872eeb89f882d134a7714 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:49:07 -0500 Subject: [PATCH 05/32] doc: update docs for shears --- jax_galsim/shear.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 92d57693..60e5b306 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -11,9 +11,11 @@ @register_pytree_node_class @implements( _galsim.Shear, - lax_description="""\ -The jax_galsim implementation of ``Shear`` does not perform range checking of the \ -shear (e.g., ``|g| <= 1``) upon construction.""", + lax_description=( + "While the JAX-GalSim implementation of ``Shear`` will do range checking of " + "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " + "instead of ``galsim.GalSimRangeError`` exceptions." + ), ) class Shear(object): def __init__(self, *args, **kwargs): From 85569db75d31cb70ad11b1d4290b4ddabcdc91ff Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:51:39 -0500 Subject: [PATCH 06/32] fix: clarify docs --- jax_galsim/shear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 60e5b306..59ef2cca 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -12,9 +12,9 @@ @implements( _galsim.Shear, lax_description=( - "While the JAX-GalSim implementation of ``Shear`` will do range checking of " - "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " - "instead of ``galsim.GalSimRangeError`` exceptions." + "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " + "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " + "exceptions instead of ``galsim.GalSimRangeError`` exceptions." ), ) class Shear(object): From d8e24ae1f0486eaf76d243b8c9d43511aa7dbf74 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 08:49:57 -0500 Subject: [PATCH 07/32] fix: raise erorr on failed integrations --- jax_galsim/integ.py | 19 +++++++++++-------- tests/GalSim | 2 +- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 19ad5c4b..aaef6396 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax.lax import jax.numpy as jnp @@ -7,7 +8,10 @@ from jax_galsim.core.utils import implements +# @partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) + +@equinox.filter_jit @implements( _galsim.integ.int1d, lax_description=( @@ -17,12 +21,11 @@ - This implementation is different than the one in GalSim and lacks some features that greatly enhance galsim's accuracy. -- The JAX-GalSim implementation returns NaN on error/non-convergence instead of - rasing an exception. +- The JAX-GalSim implementation raises a ``equinox.EquinoxRuntimeError`` on error/non-convergence + instead of rasing a ``galsim.GalSimError`` exception. """ ), ) -@partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) def int1d( func, min, @@ -37,7 +40,7 @@ def int1d( # can be used with jax if _wrap_as_callback: - @jax.jit + @equinox.filter_jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.pure_callback(func, rdt, x, vmap_method="sequential") @@ -72,8 +75,8 @@ def _base_integration(): _base_integration, ) - return jax.lax.cond( - status == 0, - lambda: val, - lambda: jnp.nan, + return equinox.error_if( + val, + status != 0, + "`jax_galsim.int1d` failed to converge!", ) diff --git a/tests/GalSim b/tests/GalSim index e5ee4016..22013ee3 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94 +Subproject commit 22013ee3c4fe1659814e5bfc147779fac22dd8de From 81976e63362a93816b2d1f5b7a16a54904127b60 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 08:50:45 -0500 Subject: [PATCH 08/32] style: please the dog --- jax_galsim/integ.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index aaef6396..0404271d 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,5 +1,3 @@ -from functools import partial - import equinox import galsim as _galsim import jax.lax From d21e3830bf00b6dfccbb7ba09e6e5611b6393b81 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:03:56 -0500 Subject: [PATCH 09/32] fix: try code with equinox filter_jit --- jax_galsim/integ.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 0404271d..208e36d5 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,3 +1,5 @@ +from functools import partial + import equinox import galsim as _galsim import jax.lax @@ -6,8 +8,6 @@ from jax_galsim.core.utils import implements -# @partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) - @equinox.filter_jit @implements( From 8805c8c9d79f901d4c1ccbf7a0aed6f2f4143ba2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:38:19 -0500 Subject: [PATCH 10/32] fix: use standard JIT --- jax_galsim/integ.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 208e36d5..277e57f6 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -9,7 +9,6 @@ from jax_galsim.core.utils import implements -@equinox.filter_jit @implements( _galsim.integ.int1d, lax_description=( @@ -24,6 +23,7 @@ """ ), ) +@partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) def int1d( func, min, @@ -38,7 +38,7 @@ def int1d( # can be used with jax if _wrap_as_callback: - @equinox.filter_jit + @jax.jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.pure_callback(func, rdt, x, vmap_method="sequential") From 4d661a28efdff90fb5774baa9242c230e6a91ceb Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:38:46 -0500 Subject: [PATCH 11/32] doc: update doc strings --- jax_galsim/integ.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 277e57f6..de1c26ce 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -18,7 +18,7 @@ - This implementation is different than the one in GalSim and lacks some features that greatly enhance galsim's accuracy. -- The JAX-GalSim implementation raises a ``equinox.EquinoxRuntimeError`` on error/non-convergence +- The JAX-GalSim implementation raises a generic ``Exception`` on error/non-convergence instead of rasing a ``galsim.GalSimError`` exception. """ ), From 1561dc06955d5601358aca2a68164e4fd1cb633d Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 10:01:45 -0500 Subject: [PATCH 12/32] fix: only use generic Exception --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 22013ee3..95ff6fd9 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 22013ee3c4fe1659814e5bfc147779fac22dd8de +Subproject commit 95ff6fd945cec5056a33276af3333fb70f5cb879 From 8c69aac025e2dc0842b17091a38c788ce6584e8f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 10:02:31 -0500 Subject: [PATCH 13/32] doc: update docs --- jax_galsim/shear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 59ef2cca..adfb25f4 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -13,8 +13,8 @@ _galsim.Shear, lax_description=( "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " - "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " - "exceptions instead of ``galsim.GalSimRangeError`` exceptions." + "invalid shear values (e.g., |g| > 1), it raises a generic ``Exception`` " + "instead of ``galsim.GalSimRangeError`` exceptions." ), ) class Shear(object): From 6d0a66b921d924c3d6c5902867b5ff471ce25f5f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 16:23:00 -0500 Subject: [PATCH 14/32] refactor: centralize logic for int checking --- jax_galsim/bounds.py | 48 ++++++++++++---------------------- jax_galsim/core/utils.py | 31 ++++++++++++++++++++++ jax_galsim/position.py | 12 ++++++--- jax_galsim/shear.py | 2 +- tests/jax/test_position_jax.py | 21 +++++++++++++++ 5 files changed, 78 insertions(+), 36 deletions(-) create mode 100644 tests/jax/test_position_jax.py diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 10442ad4..21b6ba8a 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,34 +1,24 @@ import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + CONST_TYPES, cast_to_float, cast_to_int, + cast_to_python_float, + check_is_int_then_cast, ensure_hashable, has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) -CONST_TYPES_WITH_JAX = CONST_TYPES + ( - jax.Array, - jnp.array, - jnp.int32, - jnp.int64, - jnp.float32, - jnp.float64, -) - -# TODO: write extra docs for JAX changes BOUNDS_LAX_DESCR = """\ The JAX implementation - will not always test whether the bounds are valid -- will not always test whether BoundsI is initialized with integers Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance @@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs): f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) + self.deltax = cast_to_python_float(self.deltax) + self.deltay = cast_to_python_float(self.deltay) + if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): + raise TypeError("BoundsI must be initialized with integer values") self.deltax = int(cast_to_int(self.deltax)) self.deltay = int(cast_to_int(self.deltay)) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") + if has_tracers(self._xmin) or has_tracers(self._ymin): + self._isstatic = False + + # validate inputs are ints + self._xmin = check_is_int_then_cast( + self._xmin, "BoundsI must be initialized with integer values" + ) + self._ymin = check_is_int_then_cast( + self._ymin, "BoundsI must be initialized with integer values" + ) if self.deltax < 1 and self.deltay < 1: self._isdefined = False - # for simple inputs, we can check if the bounds are valid ints - if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): - raise TypeError("BoundsI must be initialized with integer values") - - if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): - raise TypeError("BoundsI must be initialized with integer values") - - if not has_tracers(self._xmin) and not has_tracers(self._ymin): - self._isstatic = True - self._xmin = int(np.trunc(self._xmin)) - self._ymin = int(np.trunc(self._ymin)) - else: - self._isstatic = False - self._xmin = cast_to_float(jnp.trunc(self._xmin)) - self._ymin = cast_to_float(jnp.trunc(self._ymin)) - if force_static and not self._isstatic: raise RuntimeError( "BoundsI initialized with non-static " diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 3fcbf46d..e4d51b18 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -3,11 +3,42 @@ from functools import partial from typing import NamedTuple +import equinox import jax import jax.numpy as jnp import numpy as np from jax.tree_util import tree_flatten +CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) +CONST_TYPES_WITH_JAX = CONST_TYPES + ( + jax.Array, + jnp.array, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, +) + + +def check_is_int_then_cast(val, msg): + """Cast to integer and raise if value is not int.""" + # for simple inputs, we can check if the bounds are valid ints + if isinstance(val, CONST_TYPES) and not has_tracers(val): + val = cast_to_python_float(val) + if val != int(val): + raise TypeError(msg) + val = int(val) + else: + # otherwise we use more opaque checking upon jit + val = equinox.error_if( + val, + val != jnp.trunc(val), + msg, + ) + val = val.astype(int) + + return val + def cast_numpy_array_to_native_byte_order(arr): """Cast an array to native byte order.""" diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 822797b8..b3af5844 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -5,7 +5,7 @@ from jax_galsim.core.utils import ( cast_to_float, - cast_to_int, + check_is_int_then_cast, ensure_hashable, implements, ) @@ -214,9 +214,13 @@ class PositionI(Position): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - # inputs must be ints - self.x = cast_to_int(self.x) - self.y = cast_to_int(self.y) + # validate input is int + self.x = check_is_int_then_cast( + self.x, "PositionI must be initialized with integer values" + ) + self.y = check_is_int_then_cast( + self.y, "PositionI must be initialized with integer values" + ) def _check_scalar(self, other, op): try: diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index adfb25f4..dd88f424 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -14,7 +14,7 @@ lax_description=( "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " "invalid shear values (e.g., |g| > 1), it raises a generic ``Exception`` " - "instead of ``galsim.GalSimRangeError`` exceptions." + "instead of a ``galsim.GalSimRangeError`` exception." ), ) class Shear(object): diff --git a/tests/jax/test_position_jax.py b/tests/jax/test_position_jax.py new file mode 100644 index 00000000..e73937e7 --- /dev/null +++ b/tests/jax/test_position_jax.py @@ -0,0 +1,21 @@ +import jax +import pytest + +import jax_galsim + + +def test_position_jax_int_raises_in_jit(): + + @jax.jit + def _make_pos(x, y): + return jax_galsim.PositionI(x, y) + + with pytest.raises(Exception): + _make_pos(1.2, 23) + + with pytest.raises(Exception): + _make_pos(12, 2.3) + + pos = _make_pos(1, 2) + assert pos.x == 1 + assert pos.y == 2 From 260984000f5cbb96af27c6cff1f19fdbf923139f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 16:26:09 -0500 Subject: [PATCH 15/32] doc: ensure doc string is accurate --- jax_galsim/core/utils.py | 6 +++--- tests/GalSim | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index e4d51b18..18916073 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -21,15 +21,15 @@ def check_is_int_then_cast(val, msg): - """Cast to integer and raise if value is not int.""" - # for simple inputs, we can check if the bounds are valid ints + """Check if `val` is an integer, raise if not, otherwise cast to int.""" + # for simple inputs, we can check direct in python if isinstance(val, CONST_TYPES) and not has_tracers(val): val = cast_to_python_float(val) if val != int(val): raise TypeError(msg) val = int(val) else: - # otherwise we use more opaque checking upon jit + # otherwise we use more opaque checking upon jit via equinox val = equinox.error_if( val, val != jnp.trunc(val), diff --git a/tests/GalSim b/tests/GalSim index 95ff6fd9..0fe6d90b 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 95ff6fd945cec5056a33276af3333fb70f5cb879 +Subproject commit 0fe6d90bd7df4f660c923dc03da8aa44b347afda From ed7c5083e0fe994fbc13093c0fb5792a28ccc37f Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:44:57 -0500 Subject: [PATCH 16/32] fix: enable tests for image gain, area, exptime, and max_extra_noise --- jax_galsim/gsobject.py | 23 ++++++++++++++++++++--- tests/GalSim | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 3d72dbab..5907e5b8 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -601,7 +602,7 @@ def drawImage( offset=None, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), @@ -626,6 +627,13 @@ def drawImage( if image is not None and not isinstance(image, Image): raise TypeError("image is not an Image instance", image) + # Make sure (gain, area, exptime) have valid values: + gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") + area = equinox.error_if(jnp.array(area), area <= 0.0, "Invalid area <= 0.") + exptime = equinox.error_if( + jnp.array(exptime), exptime <= 0.0, "Invalid exptime <= 0." + ) + if method == "phot" and save_photons and maxN is not None: raise GalSimIncompatibleValuesError( "Setting maxN is incompatible with save_photons=True" @@ -659,6 +667,13 @@ def drawImage( sensor=sensor, n_photons=n_photons, ) + if max_extra_noise is not None: + raise GalSimIncompatibleValuesError( + "max_extra_noise is only relevant for method='phot'", + method=method, + sensor=sensor, + max_extra_noise=max_extra_noise, + ) if poisson_flux is not None: raise GalSimIncompatibleValuesError( "poisson_flux is only relevant for method='phot'", @@ -1078,6 +1093,8 @@ def _drawKImage( @implements(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): + if max_extra_noise is None: + max_extra_noise = 0.0 n_photons, g, _rng = calculate_n_photons( self.flux, self._flux_per_photon, @@ -1106,7 +1123,7 @@ def makePhot( self, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, photon_ops=(), local_wcs=None, @@ -1178,7 +1195,7 @@ def drawPhot( add_to_image=False, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), diff --git a/tests/GalSim b/tests/GalSim index 0fe6d90b..200c2cd2 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0fe6d90bd7df4f660c923dc03da8aa44b347afda +Subproject commit 200c2cd2bad9f8f93936290cdca9d87ee10ebaa1 From 831990a64e266b8bc1ed8256996561f67431cc8d Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:47:02 -0500 Subject: [PATCH 17/32] doc: update doc strings --- jax_galsim/gsobject.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 5907e5b8..0ea14865 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -575,10 +575,11 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): lax_description="""\ The JAX-GalSim version of ``drawImage`` -- does not do extensive (any?) checking of the input settings. - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. """, From 42fb804703ddaed724850f410f15d06d4a2930d3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:48:00 -0500 Subject: [PATCH 18/32] doc: update doc string --- jax_galsim/gsobject.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 0ea14865..1a94cf1f 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -582,6 +582,7 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs """, ) def drawImage( From 2ef7195df3502c9286890f7c872db669bb4ed289 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:51:32 -0500 Subject: [PATCH 19/32] doc: add doc string for position exceptions --- jax_galsim/position.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index b3af5844..cf36dba8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -208,7 +208,14 @@ def _check_scalar(self, other, op): raise TypeError("Can only %s a PositionD by float values" % op) -@implements(_galsim.PositionI) +@implements( + _galsim.PositionI, + lax_description=( + "The ``jax_galsim.PositionI`` class will raise generic " + "``Exception``s instead of a more specific exception for invalid " + "inputs." + ), +) @register_pytree_node_class class PositionI(Position): def __init__(self, *args, **kwargs): From 3ee5f33e0776230c89339f974f28bf97f92ab5a8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 06:01:15 -0500 Subject: [PATCH 20/32] fix+doc: do more error checking and more docs --- jax_galsim/gsobject.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 1a94cf1f..ce69ec4b 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1115,10 +1115,10 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): lax_description="""\ The JAX-GalSim version of ``makePhot`` -- does little to no error checking on the inputs - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs """, ) def makePhot( @@ -1187,6 +1187,9 @@ def makePhot( - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs - requires that the ``maxN`` option must be a constant """, ) @@ -1227,6 +1230,8 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") + gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") + if n_photons is not None: # n_photons is the length of an array so it is a python int and # and thus a constant wrt to JIT From 48c2569eeaed9f177f1eea5416c57e961d28cd7c Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 06:20:03 -0500 Subject: [PATCH 21/32] doc: fix doc string formatting --- jax_galsim/gsobject.py | 6 +++--- tests/GalSim | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index ce69ec4b..c6896fec 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -582,7 +582,7 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def drawImage( @@ -1118,7 +1118,7 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` to indicate no limit on the extra noise -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def makePhot( @@ -1189,7 +1189,7 @@ def makePhot( from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` to indicate no limit on the extra noise -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs - requires that the ``maxN`` option must be a constant """, ) diff --git a/tests/GalSim b/tests/GalSim index 200c2cd2..0dabbf46 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 200c2cd2bad9f8f93936290cdca9d87ee10ebaa1 +Subproject commit 0dabbf463b4af7f689074c8373a936d511e4b836 From 6e9c4debe9d5ec037b97edbe3be126b9f48e3b44 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 07:42:15 -0500 Subject: [PATCH 22/32] Apply suggestion from @beckermr --- jax_galsim/gsobject.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index c6896fec..b1543f54 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -579,7 +579,6 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate that the number of photons should be determined from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` - to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. - raises a generic ``Exception`` instead of a more specific exception for some invalid inputs From 37310d93d7753f19437ef4be3d7f796d37507068 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:16:32 -0500 Subject: [PATCH 23/32] fix: add the rest of the types --- jax_galsim/core/utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 18916073..91007e39 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -9,14 +9,31 @@ import numpy as np from jax.tree_util import tree_flatten -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) +CONST_TYPES = ( + float, + int, + np.ndarray, + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, +) CONST_TYPES_WITH_JAX = CONST_TYPES + ( jax.Array, jnp.array, + jnp.int8, + jnp.int16, jnp.int32, jnp.int64, jnp.float32, jnp.float64, + jnp.complex64, + jnp.complex128, ) From 88c1c7b546eeb1ff2d6a4da450d08cd9c466df31 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:17:50 -0500 Subject: [PATCH 24/32] fix: use proper array ref --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 91007e39..71ff2735 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -25,7 +25,7 @@ ) CONST_TYPES_WITH_JAX = CONST_TYPES + ( jax.Array, - jnp.array, + jnp.ndarray, jnp.int8, jnp.int16, jnp.int32, From 70e4c3f2406876ca5a08201f3612f46082af7ad6 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 08:18:35 -0500 Subject: [PATCH 25/32] Apply suggestion from @beckermr --- jax_galsim/gsobject.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index b1543f54..db59db43 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1187,7 +1187,6 @@ def makePhot( to indicate that the number of photons should be determined from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` - to indicate no limit on the extra noise - raises a generic ``Exception`` instead of a more specific exception for some invalid inputs - requires that the ``maxN`` option must be a constant """, From 164c2c626c015ddd5be7e38008dc074b2ed0abf9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:21:50 -0500 Subject: [PATCH 26/32] fix: docs done wrong --- docs/conf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index edb02218..f96b604b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,9 +29,11 @@ copyright = "2026, GalSim Developers" try: - from jax_galsim._version import version as release + from jax_galsim._version import version except ImportError: - release = "0.0.1.dev0" + version = "0.0.1.dev0" + +release = version # --------------------------------------------------------------------------- # General configuration From aafb848227446f3d6bea8fed601af7c9198620e0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:37:31 -0500 Subject: [PATCH 27/32] test: update to latest submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 0dabbf46..5cd4c1ec 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0dabbf463b4af7f689074c8373a936d511e4b836 +Subproject commit 5cd4c1ecc8b856790558e39677900cc43e0ce67f From 4ed8c8edee08ba07f9f352d85bd5272690b72448 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 09:10:23 -0500 Subject: [PATCH 28/32] fix: raise for interpolated image init problems --- jax_galsim/interpolatedimage.py | 14 ++++++++++++++ tests/GalSim | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0424e9d8..bfb5c5d7 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -2,6 +2,7 @@ import math from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -118,6 +119,11 @@ def __init__( elif not isinstance(image, Image): raise TypeError("Supplied image must be an Image or file name") + if not (image.dtype == jnp.float32 or image.dtype == jnp.float64): + raise GalSimValueError( + "Interpolated images must use a float-type image.", image.dtype + ) + self._jax_children = ( image, dict( @@ -506,6 +512,14 @@ def __init__( image=self._jax_children[0], ) + if calculate_stepk or calculate_maxk or flux is not None: + image = equinox.error_if( + image, + image.array.sum() == 0.0, + "This input image has zero total flux. It does not define a " + "valid surface brightness profile.", + ) + @doc_inherit def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: diff --git a/tests/GalSim b/tests/GalSim index 5cd4c1ec..09ded8ab 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5cd4c1ecc8b856790558e39677900cc43e0ce67f +Subproject commit 09ded8abfa570f836084ef9cf8d53c210203f825 From 7963c1f66cce1ddb3d210fda921e781032339ab7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 09:15:49 -0500 Subject: [PATCH 29/32] doc: add docs for exceptions --- jax_galsim/interpolatedimage.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index bfb5c5d7..f2d0e11c 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -60,6 +60,8 @@ def __dir__(cls): - the pad_image options - depixelize - most of the bounds checks, type checks, and dtype casts done by galsim +- raises a generic ``Exception`` instead of a more specific one for some + initialization errors """ From 2e2aaca68758c24da9ac30106d4698e99f69d5f8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:00:35 -0500 Subject: [PATCH 30/32] fix: use array in transform, not image --- jax_galsim/interpolatedimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f2d0e11c..e11ba44a 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -515,8 +515,8 @@ def __init__( ) if calculate_stepk or calculate_maxk or flux is not None: - image = equinox.error_if( - image, + image.array = equinox.error_if( + image.array, image.array.sum() == 0.0, "This input image has zero total flux. It does not define a " "valid surface brightness profile.", From 58196ca0f18275c3c0c0cc51d45691beff02a72b Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:32:56 -0500 Subject: [PATCH 31/32] fix: ensure repr of image prints even with tracers --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 3e8c0e69..e964380b 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -326,7 +326,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined(): + if self.bounds.isDefined() and not has_tracers(self.array): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: From 58efd99e639cf2b1a5393bc55a5f5f7249d1c068 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:41:39 -0500 Subject: [PATCH 32/32] fix: raise for invalid beta --- jax_galsim/moffat.py | 15 ++++++++++++++- tests/jax/test_moffat_jax.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/jax/test_moffat_jax.py diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 2a9b312b..5d994fe1 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -31,7 +32,7 @@ def _Knu(nu, x): lax_description="""\ The JAX-GalSim version of the Moffat profile -- does not support truncation or beta < 1.1 +- does not support truncation or beta <= 1.1 - does not support gsparams.maxk_thresholds > 0.1 - does not support autodiff with respect to the `beta` parameter for Fourier-space evaluations @@ -67,6 +68,18 @@ def __init__( f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" ) + if isinstance(beta, (float, int)): + if beta <= self._beta_thr: + raise ValueError( + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." + ) + else: + beta = equinox.error_if( + jnp.array(beta), + beta <= self._beta_thr, + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}.", + ) + # Parse the radius options if half_light_radius is not None: if scale_radius is not None or fwhm is not None: diff --git a/tests/jax/test_moffat_jax.py b/tests/jax/test_moffat_jax.py new file mode 100644 index 00000000..4810a105 --- /dev/null +++ b/tests/jax/test_moffat_jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp +import pytest + +import jax_galsim + + +def test_moffat_jax_beta_raises(): + + @jax.jit + def make_moffat(beta): + return jax_galsim.Moffat(beta, fwhm=1.0) + + with pytest.raises(Exception): + make_moffat(jnp.array(1.1)) + + with pytest.raises(Exception): + make_moffat(0.9)