From 99e085b0cf073800f96b4cd0315058fe90026fe4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 3 Apr 2026 08:20:18 -0500 Subject: [PATCH] ENH: allow running on torch.mps device, which does not have float64/complex128 --- array_api_tests/hypothesis_helpers.py | 4 ++-- array_api_tests/test_signatures.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 4669421c..e6aa1da5 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -341,8 +341,8 @@ def finite_matrices(draw, shape=matrix_shapes(), dtype=floating_dtypes, bound=No rtol_shared_matrix_shapes = shared(matrix_shapes()) -# Should we set a max_value here? -_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0) +# Arbitrary max_value for rtols, to avoid overflows when float64 is not available +_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0, max_value=42) rtols = one_of(floats(**_rtol_float_kw), arrays(dtype=real_floating_dtypes, shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]), diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index ab1440cf..defaedbb 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -156,7 +156,10 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str: array_argnames -= set(func_to_specified_arg_exprs[func_name].keys()) if len(array_argnames) > 0: in_dtypes = dh.func_in_dtypes[func_name] - for dtype_name in ["float64", "bool", "int64", "complex128"]: + # use "float64" if available, "float32" otherwise; ditto for complex128/complex64 + float_name = dh.dtype_to_name[dh.widest_real_dtype] + cmplx_name = dh.dtype_to_name[dh.widest_complex_dtype] + for dtype_name in [float_name, "bool", "int64", cmplx_name]: # We try float64 first because uninspectable numerical functions # tend to support float inputs first-and-foremost (i.e. PyTorch) try: