diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 074a762e..59ef2cca 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError @@ -10,9 +11,11 @@ @register_pytree_node_class @implements( _galsim.Shear, - lax_description="""\ -The jax_galsim implementation of ``Shear`` does not perform range checking of the \ -shear (e.g., ``|g| <= 1``) upon construction.""", + lax_description=( + "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " + "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " + "exceptions instead of ``galsim.GalSimRangeError`` exceptions." + ), ) class Shear(object): def __init__(self, *args, **kwargs): @@ -45,15 +48,25 @@ def __init__(self, *args, **kwargs): # g1,g2 elif "g1" in kwargs or "g2" in kwargs: - g1 = kwargs.pop("g1", 0.0) - g2 = kwargs.pop("g2", 0.0) + g1 = jnp.array(kwargs.pop("g1", 0.0)) + g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 + self._g = equinox.error_if( + self._g, + jnp.abs(self._g) > 1.0, + "Requested shear exceeds 1.", + ) # e1,e2 elif "e1" in kwargs or "e2" in kwargs: - e1 = kwargs.pop("e1", 0.0) - e2 = kwargs.pop("e2", 0.0) + e1 = jnp.array(kwargs.pop("e1", 0.0)) + e2 = jnp.array(kwargs.pop("e2", 0.0)) absesq = e1**2 + e2**2 + absesq = equinox.error_if( + absesq, + absesq > 1.0, + "Requested distortion exceeds 1.", + ) self._g = (e1 + 1j * e2) * self._e2g(absesq) # eta1,eta2 @@ -75,7 +88,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - g = kwargs.pop("g") + g = jnp.array(kwargs.pop("g")) + g = equinox.error_if( + g, + g > 1 or g < 0, + "Requested |shear| is outside [0,1].", + ) self._g = g * jnp.exp(2j * beta.rad) # e,beta @@ -89,7 +107,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - e = kwargs.pop("e") + e = jnp.array(kwargs.pop("e")) + e = equinox.error_if( + e, + (e > 1) | (e < 0), + "Requested distortion is outside [0,1].", + ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) # eta,beta @@ -103,7 +126,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - eta = kwargs.pop("eta") + eta = jnp.array(kwargs.pop("eta")) + eta = equinox.error_if( + eta, + eta < 0, + "Requested eta is below 0.", + ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) # q,beta @@ -117,7 +145,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - q = kwargs.pop("q") + q = jnp.array(kwargs.pop("q")) + q = equinox.error_if( + q, + (q <= 0) | (q > 1), + "Cannot use requested axis ratio.", + ) eta = -jnp.log(q) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) diff --git a/pyproject.toml b/pyproject.toml index 2ad2d9aa..f8ca4441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "The modular galaxy image simulation toolkit, but in JAX" dynamic = ["version"] license = { file = "LICENSE" } readme = "README.md" -dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax", "equinox"] [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] diff --git a/tests/GalSim b/tests/GalSim index a5afbf51..e5ee4016 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit a5afbf510dc747f5667f61c742b9dd3630643988 +Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94