Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3d11630
fix: raise errors for invalid shears and PixelScale WCS inits
beckermr May 14, 2026
bd0e282
please the dog
beckermr May 14, 2026
a3b7ba4
fix: mock up equinox
beckermr May 14, 2026
5a43c92
test: more array equals
beckermr May 14, 2026
ff18900
doc: update docs for shears
beckermr May 14, 2026
85569db
fix: clarify docs
beckermr May 14, 2026
d8e24ae
fix: raise erorr on failed integrations
beckermr May 14, 2026
81976e6
style: please the dog
beckermr May 14, 2026
d21e383
fix: try code with equinox filter_jit
beckermr May 14, 2026
8805c8c
fix: use standard JIT
beckermr May 14, 2026
4d661a2
doc: update doc strings
beckermr May 14, 2026
1561dc0
fix: only use generic Exception
beckermr May 14, 2026
8c69aac
doc: update docs
beckermr May 14, 2026
6d0a66b
refactor:
beckermr May 14, 2026
2609840
doc: ensure doc string is accurate
beckermr May 14, 2026
ed7c508
fix: enable tests for image gain, area, exptime, and max_extra_noise
beckermr May 15, 2026
831990a
doc: update doc strings
beckermr May 15, 2026
42fb804
doc: update doc string
beckermr May 15, 2026
2ef7195
doc: add doc string for position exceptions
beckermr May 15, 2026
3ee5f33
fix+doc: do more error checking and more docs
beckermr May 15, 2026
48c2569
doc: fix doc string formatting
beckermr May 15, 2026
6e9c4de
Apply suggestion from @beckermr
beckermr May 15, 2026
37310d9
fix: add the rest of the types
beckermr May 15, 2026
9aa63de
Merge branch 'equinox-err-2' of https://github.com/GalSim-developers/…
beckermr May 15, 2026
88c1c7b
fix: use proper array ref
beckermr May 15, 2026
70e4c3f
Apply suggestion from @beckermr
beckermr May 15, 2026
d80f878
Merge branch 'main' into equinox-err-2
beckermr May 15, 2026
164c2c6
fix: docs done wrong
beckermr May 15, 2026
0ec07d8
Merge branch 'main' into equinox-err-2
beckermr May 15, 2026
aafb848
test: update to latest submodule
beckermr May 15, 2026
4ed8c8e
fix: raise for interpolated image init problems
beckermr May 15, 2026
7963c1f
doc: add docs for exceptions
beckermr May 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 17 additions & 31 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,24 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
CONST_TYPES,
cast_to_float,
cast_to_int,
cast_to_python_float,
check_is_int_then_cast,
ensure_hashable,
has_tracers,
implements,
)
from jax_galsim.position import Position, PositionD, PositionI

CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64)
CONST_TYPES_WITH_JAX = CONST_TYPES + (
jax.Array,
jnp.array,
jnp.int32,
jnp.int64,
jnp.float32,
jnp.float64,
)

# TODO: write extra docs for JAX changes
BOUNDS_LAX_DESCR = """\
The JAX implementation

- will not always test whether the bounds are valid
- will not always test whether BoundsI is initialized with integers

Further, the JAX implementation adds a new method, ``isStatic`` to the
``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance
Expand Down Expand Up @@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs):
f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}."
)

self.deltax = cast_to_python_float(self.deltax)
self.deltay = cast_to_python_float(self.deltay)
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
self.deltax = int(cast_to_int(self.deltax))
self.deltay = int(cast_to_int(self.deltay))

if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
if has_tracers(self._xmin) or has_tracers(self._ymin):
self._isstatic = False

# validate inputs are ints
self._xmin = check_is_int_then_cast(
self._xmin, "BoundsI must be initialized with integer values"
)
self._ymin = check_is_int_then_cast(
self._ymin, "BoundsI must be initialized with integer values"
)

if self.deltax < 1 and self.deltay < 1:
self._isdefined = False

# for simple inputs, we can check if the bounds are valid ints
if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin):
raise TypeError("BoundsI must be initialized with integer values")

if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin):
raise TypeError("BoundsI must be initialized with integer values")

if not has_tracers(self._xmin) and not has_tracers(self._ymin):
self._isstatic = True
self._xmin = int(np.trunc(self._xmin))
self._ymin = int(np.trunc(self._ymin))
else:
self._isstatic = False
self._xmin = cast_to_float(jnp.trunc(self._xmin))
self._ymin = cast_to_float(jnp.trunc(self._ymin))

if force_static and not self._isstatic:
raise RuntimeError(
"BoundsI initialized with non-static "
Expand Down
48 changes: 48 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,59 @@
from functools import partial
from typing import NamedTuple

import equinox
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten

CONST_TYPES = (
float,
int,
np.ndarray,
np.int8,
np.int16,
np.int32,
np.int64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
)
CONST_TYPES_WITH_JAX = CONST_TYPES + (
jax.Array,
jnp.ndarray,
jnp.int8,
jnp.int16,
jnp.int32,
jnp.int64,
jnp.float32,
jnp.float64,
jnp.complex64,
jnp.complex128,
)


def check_is_int_then_cast(val, msg):
"""Check if `val` is an integer, raise if not, otherwise cast to int."""
# for simple inputs, we can check direct in python
if isinstance(val, CONST_TYPES) and not has_tracers(val):
val = cast_to_python_float(val)
if val != int(val):
raise TypeError(msg)
val = int(val)
else:
# otherwise we use more opaque checking upon jit via equinox
val = equinox.error_if(
val,
val != jnp.trunc(val),
msg,
)
val = val.astype(int)

return val


def cast_numpy_array_to_native_byte_order(arr):
"""Cast an array to native byte order."""
Expand Down
36 changes: 29 additions & 7 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import namedtuple
from functools import partial

import equinox
import galsim as _galsim
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -574,12 +575,13 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None):
lax_description="""\
The JAX-GalSim version of ``drawImage``

- does not do extensive (any?) checking of the input settings.
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0``
- requires that the ``maxN`` option be a constant since PhotonArrays are allocated
with ``maxN`` photons when this option is used and arrays in JAX must have static sizes.
- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs
""",
)
def drawImage(
Expand All @@ -601,7 +603,7 @@ def drawImage(
offset=None,
n_photons=None,
rng=None,
max_extra_noise=0.0,
max_extra_noise=None,
poisson_flux=None,
sensor=None,
photon_ops=(),
Expand All @@ -626,6 +628,13 @@ def drawImage(
if image is not None and not isinstance(image, Image):
raise TypeError("image is not an Image instance", image)

# Make sure (gain, area, exptime) have valid values:
gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.")
area = equinox.error_if(jnp.array(area), area <= 0.0, "Invalid area <= 0.")
exptime = equinox.error_if(
jnp.array(exptime), exptime <= 0.0, "Invalid exptime <= 0."
)

if method == "phot" and save_photons and maxN is not None:
raise GalSimIncompatibleValuesError(
"Setting maxN is incompatible with save_photons=True"
Expand Down Expand Up @@ -659,6 +668,13 @@ def drawImage(
sensor=sensor,
n_photons=n_photons,
)
if max_extra_noise is not None:
raise GalSimIncompatibleValuesError(
"max_extra_noise is only relevant for method='phot'",
method=method,
sensor=sensor,
max_extra_noise=max_extra_noise,
)
if poisson_flux is not None:
raise GalSimIncompatibleValuesError(
"poisson_flux is only relevant for method='phot'",
Expand Down Expand Up @@ -1078,6 +1094,8 @@ def _drawKImage(

@implements(_galsim.GSObject._calculate_nphotons)
def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng):
if max_extra_noise is None:
max_extra_noise = 0.0
n_photons, g, _rng = calculate_n_photons(
self.flux,
self._flux_per_photon,
Expand All @@ -1096,17 +1114,17 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng):
lax_description="""\
The JAX-GalSim version of ``makePhot``

- does little to no error checking on the inputs
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0``
to indicate no limit on the extra noise
- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs
""",
)
def makePhot(
self,
n_photons=None,
rng=None,
max_extra_noise=0.0,
max_extra_noise=None,
poisson_flux=None,
photon_ops=(),
local_wcs=None,
Expand Down Expand Up @@ -1168,6 +1186,8 @@ def makePhot(
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0``
- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs
- requires that the ``maxN`` option must be a constant
""",
)
Expand All @@ -1178,7 +1198,7 @@ def drawPhot(
add_to_image=False,
n_photons=None,
rng=None,
max_extra_noise=0.0,
max_extra_noise=None,
poisson_flux=None,
sensor=None,
photon_ops=(),
Expand Down Expand Up @@ -1208,6 +1228,8 @@ def drawPhot(
elif not isinstance(sensor, Sensor):
raise TypeError("The sensor provided is not a Sensor instance")

gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.")

if n_photons is not None:
# n_photons is the length of an array so it is a python int and
# and thus a constant wrt to JIT
Expand Down
13 changes: 7 additions & 6 deletions jax_galsim/integ.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import equinox
import galsim as _galsim
import jax.lax
import jax.numpy as jnp
Expand All @@ -17,8 +18,8 @@

- This implementation is different than the one in GalSim and lacks some features that
greatly enhance galsim's accuracy.
- The JAX-GalSim implementation returns NaN on error/non-convergence instead of
rasing an exception.
- The JAX-GalSim implementation raises a generic ``Exception`` on error/non-convergence
instead of rasing a ``galsim.GalSimError`` exception.
"""
),
)
Expand Down Expand Up @@ -72,8 +73,8 @@ def _base_integration():
_base_integration,
)

return jax.lax.cond(
status == 0,
lambda: val,
lambda: jnp.nan,
return equinox.error_if(
val,
status != 0,
"`jax_galsim.int1d` failed to converge!",
)
16 changes: 16 additions & 0 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
from functools import partial

import equinox
import galsim as _galsim
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -59,6 +60,8 @@ def __dir__(cls):
- the pad_image options
- depixelize
- most of the bounds checks, type checks, and dtype casts done by galsim
- raises a generic ``Exception`` instead of a more specific one for some
initialization errors
"""


Expand Down Expand Up @@ -118,6 +121,11 @@ def __init__(
elif not isinstance(image, Image):
raise TypeError("Supplied image must be an Image or file name")

if not (image.dtype == jnp.float32 or image.dtype == jnp.float64):
raise GalSimValueError(
"Interpolated images must use a float-type image.", image.dtype
)

self._jax_children = (
image,
dict(
Expand Down Expand Up @@ -506,6 +514,14 @@ def __init__(
image=self._jax_children[0],
)

if calculate_stepk or calculate_maxk or flux is not None:
image = equinox.error_if(
image,
image.array.sum() == 0.0,
"This input image has zero total flux. It does not define a "
"valid surface brightness profile.",
)

@doc_inherit
def withGSParams(self, gsparams=None, **kwargs):
if gsparams == self.gsparams:
Expand Down
21 changes: 16 additions & 5 deletions jax_galsim/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jax_galsim.core.utils import (
cast_to_float,
cast_to_int,
check_is_int_then_cast,
ensure_hashable,
implements,
)
Expand Down Expand Up @@ -208,15 +208,26 @@ def _check_scalar(self, other, op):
raise TypeError("Can only %s a PositionD by float values" % op)


@implements(_galsim.PositionI)
@implements(
_galsim.PositionI,
lax_description=(
"The ``jax_galsim.PositionI`` class will raise generic "
"``Exception``s instead of a more specific exception for invalid "
"inputs."
),
)
@register_pytree_node_class
class PositionI(Position):
def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)

# inputs must be ints
self.x = cast_to_int(self.x)
self.y = cast_to_int(self.y)
# validate input is int
self.x = check_is_int_then_cast(
self.x, "PositionI must be initialized with integer values"
)
self.y = check_is_int_then_cast(
self.y, "PositionI must be initialized with integer values"
)

def _check_scalar(self, other, op):
try:
Expand Down
Loading