From ec8023712f5915aa90183bb7900b15201c5f0814 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 10:55:36 -0600 Subject: [PATCH 1/9] Modernize type annotations, require __future__ annotations --- arraycontext/__init__.py | 1 + arraycontext/container/dataclass.py | 1 + arraycontext/context.py | 20 ++++++++------ arraycontext/fake_numpy.py | 3 +++ arraycontext/impl/__init__.py | 3 +++ arraycontext/impl/jax/__init__.py | 2 ++ arraycontext/impl/jax/fake_numpy.py | 3 +++ arraycontext/impl/numpy/fake_numpy.py | 3 +++ arraycontext/impl/pyopencl/fake_numpy.py | 3 +++ .../impl/pyopencl/taggable_cl_array.py | 7 ++--- arraycontext/impl/pytato/compile.py | 17 +++++++----- arraycontext/impl/pytato/fake_numpy.py | 3 +++ arraycontext/impl/pytato/utils.py | 5 +++- arraycontext/loopy.py | 2 ++ arraycontext/metadata.py | 1 + arraycontext/pytest.py | 2 ++ arraycontext/transform_metadata.py | 2 ++ arraycontext/version.py | 2 ++ doc/make_numpy_coverage_table.py | 1 + pyproject.toml | 6 +++-- test/test_arraycontext.py | 12 +++++---- test/test_pytato_arraycontext.py | 2 ++ test/test_utils.py | 27 ++++++++++++------- 23 files changed, 92 insertions(+), 36 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index c40117e8..1c2ae451 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -2,6 +2,7 @@ An array context is an abstraction that helps you dispatch between multiple implementations of :mod:`numpy`-like :math:`n`-dimensional arrays. """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index ae9ab486..9495b344 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -4,6 +4,7 @@ .. currentmodule:: arraycontext .. autofunction:: dataclass_array_container """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/context.py b/arraycontext/context.py index 398f8aa3..f6dc70bf 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -134,6 +134,9 @@ :canonical: arraycontext.ArrayOrContainerOrScalarT """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -160,7 +163,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union from warnings import warn import numpy as np @@ -204,14 +207,14 @@ def size(self) -> int: ... @property - def dtype(self) -> "np.dtype[Any]": + def dtype(self) -> np.dtype[Any]: ... # Covering all the possible index variations is hard and (kind of) futile. # If you'd like to see how, try changing the Any to # AxisIndex = slice | int | "Array" # Index = AxisIndex |tuple[AxisIndex] - def __getitem__(self, index: Any) -> "Array": + def __getitem__(self, index: Any) -> Array: ... @@ -220,9 +223,10 @@ def __getitem__(self, index: Any) -> "Array": ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrContainer = Union[Array, "ArrayContainer"] +ArrayOrScalar: TypeAlias = "Array | ScalarLike" +ArrayOrContainer: TypeAlias = "Array | ArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) -ArrayOrContainerOrScalar = Union[Array, "ArrayContainer", ScalarLike] +ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" ArrayOrContainerOrScalarT = TypeVar( "ArrayOrContainerOrScalarT", bound=ArrayOrContainerOrScalar) @@ -295,7 +299,7 @@ def __hash__(self) -> int: def zeros(self, shape: int | tuple[int, ...], - dtype: "np.dtype[Any]") -> Array: + dtype: np.dtype[Any]) -> Array: warn(f"{type(self).__name__}.zeros is deprecated and will stop " "working in 2025. Use actx.np.zeros instead.", DeprecationWarning, stacklevel=2) @@ -329,7 +333,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, - t_unit: "loopy.TranslationUnit", + t_unit: loopy.TranslationUnit, **kwargs: Any) -> dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. @@ -414,7 +418,7 @@ def tag_axis(self, @memoize_method def _get_einsum_prg(self, spec: str, arg_names: tuple[str, ...], - tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": + tagged: ToTagSetConvertible) -> loopy.TranslationUnit: import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 58215617..6c5fb158 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/__init__.py b/arraycontext/impl/__init__.py index ac0e47a3..53030a2b 100644 --- a/arraycontext/impl/__init__.py +++ b/arraycontext/impl/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 0b6cd727..a70cbaa2 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -2,6 +2,8 @@ .. currentmodule:: arraycontext .. autoclass:: EagerJAXArrayContext """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 094e8cf2..1a4e790f 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index f345edc9..582ccda9 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index ae340ca9..4b96e475 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -2,6 +2,9 @@ .. currentmodule:: arraycontext .. autoclass:: PyOpenCLArrayContext """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 7de76113..39f92586 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -4,6 +4,7 @@ .. autofunction:: to_tagged_cl_array """ +from __future__ import annotations from dataclasses import dataclass from typing import Any @@ -25,7 +26,7 @@ class Axis(Taggable): tags: frozenset[Tag] - def _with_new_tags(self, tags: frozenset[Tag]) -> "Axis": + def _with_new_tags(self, tags: frozenset[Tag]) -> Axis: from dataclasses import replace return replace(self, tags=tags) @@ -109,12 +110,12 @@ def copy(self, queue=cla._copy_queue): return type(self)(None, tags=self.tags, axes=self.axes, **_unwrap_cl_array(ary)) - def _with_new_tags(self, tags: frozenset[Tag]) -> "TaggableCLArray": + def _with_new_tags(self, tags: frozenset[Tag]) -> TaggableCLArray: return type(self)(None, tags=tags, axes=self.axes, **_unwrap_cl_array(self)) def with_tagged_axis(self, iaxis: int, - tags: ToTagSetConvertible) -> "TaggableCLArray": + tags: ToTagSetConvertible) -> TaggableCLArray: """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. """ diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 952761bf..e77c1091 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -5,6 +5,9 @@ .. autoclass:: CompiledFunction .. autoclass:: FromArrayContextCompile """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -261,7 +264,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor], - "CompiledFunction"] = field(default_factory=lambda: {}) + CompiledFunction] = field(default_factory=lambda: {}) # {{{ abstract interface @@ -270,11 +273,11 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: raise NotImplementedError @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: raise NotImplementedError # }}} @@ -383,11 +386,11 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: return CompiledPyOpenCLFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: return CompiledPyOpenCLFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -482,11 +485,11 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays): class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: return CompiledJAXFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: return CompiledJAXFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 0692eb7e..d7072855 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2d624d9a..c031e29b 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __doc__ = """ .. autofunction:: transfer_from_numpy .. autofunction:: transfer_to_numpy @@ -127,7 +130,7 @@ def __init__(self, limit_arg_size_nbytes: int) -> None: self.limit_arg_size_nbytes = limit_arg_size_nbytes @memoize_method - def get_loopy_target(self) -> "lp.PyOpenCLTarget": + def get_loopy_target(self) -> lp.PyOpenCLTarget: from loopy import PyOpenCLTarget return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index da717846..d6f90783 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -2,6 +2,8 @@ .. currentmodule:: arraycontext .. autofunction:: make_loopy_program """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/metadata.py b/arraycontext/metadata.py index 756999f7..5f0633f1 100644 --- a/arraycontext/metadata.py +++ b/arraycontext/metadata.py @@ -1,6 +1,7 @@ """ .. autoclass:: NameHint """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index f1f62a71..760fc103 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -6,6 +6,8 @@ .. autofunction:: pytest_generate_tests_for_array_contexts """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/transform_metadata.py b/arraycontext/transform_metadata.py index 2e0942e9..ccfcfba3 100644 --- a/arraycontext/transform_metadata.py +++ b/arraycontext/transform_metadata.py @@ -4,6 +4,8 @@ .. autoclass:: CommonSubexpressionTag .. autoclass:: ElementwiseMapKernelTag """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/version.py b/arraycontext/version.py index d33045f0..90305a22 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from importlib import metadata diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py index 19d09d4a..1a5782e4 100644 --- a/doc/make_numpy_coverage_table.py +++ b/doc/make_numpy_coverage_table.py @@ -13,6 +13,7 @@ python make_numpy_support_table.py numpy_coverage.rst """ +from __future__ import annotations import pathlib diff --git a/pyproject.toml b/pyproject.toml index 0daaa214..a9c1df48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,6 @@ extend-ignore = [ "E221", # multiple spaces before operator "E226", # missing whitespace around arithmetic operator "E402", # module-level import not at top of file - "UP006", # updated annotations due to __future__ import - "UP007", # updated annotations due to __future__ import ] [tool.ruff.lint.flake8-quotes] @@ -101,6 +99,10 @@ known-local-folder = [ "arraycontext", ] lines-after-imports = 2 +required-imports = ["from __future__ import annotations"] + +[tool.ruff.lint.per-file-ignores] +"doc/conf.py" = ["I002"] [tool.mypy] python_version = "3.10" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 47d8e941..050bfc8d 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = "Copyright (C) 2020-21 University of Illinois Board of Trustees" __license__ = """ @@ -23,7 +26,6 @@ import logging from dataclasses import dataclass from functools import partial -from typing import Union import numpy as np import pytest @@ -216,9 +218,9 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @dataclass(frozen=True) class MyContainer: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray __array_ufunc__ = None @@ -241,9 +243,9 @@ def array_context(self): @dataclass(frozen=True) class MyContainerDOFBcast: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray @property def array_context(self): diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index a14df50f..a4050380 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -1,4 +1,6 @@ """ PytatoArrayContext specific tests""" +from __future__ import annotations + __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" diff --git a/test/test_utils.py b/test/test_utils.py index db9ed825..a6aa2714 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,7 @@ """Testing for internal utilities.""" +from __future__ import annotations + +from typing import cast __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" @@ -49,7 +52,6 @@ def test_pt_actx_key_stringification_uniqueness(): def test_dataclass_array_container() -> None: from dataclasses import dataclass, field - from typing import Optional, Tuple # noqa: UP035 from arraycontext import Array, dataclass_array_container @@ -58,7 +60,7 @@ def test_dataclass_array_container() -> None: @dataclass class ArrayContainerWithStringTypes: x: np.ndarray - y: "np.ndarray" + y: np.ndarray with pytest.raises(TypeError, match="String annotation on field 'y'"): # NOTE: cannot have string annotations in container @@ -71,7 +73,7 @@ class ArrayContainerWithStringTypes: @dataclass class ArrayContainerWithOptional: x: np.ndarray - y: Optional[np.ndarray] + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) @@ -84,7 +86,7 @@ class ArrayContainerWithOptional: @dataclass class ArrayContainerWithTuple: x: Array - y: Tuple[Array, Array] + y: tuple[Array, Array] with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): dataclass_array_container(ArrayContainerWithTuple) @@ -131,7 +133,6 @@ class ArrayContainerWithArray: def test_dataclass_container_unions() -> None: from dataclasses import dataclass - from typing import Union from arraycontext import Array, dataclass_array_container @@ -140,7 +141,7 @@ def test_dataclass_container_unions() -> None: @dataclass class ArrayContainerWithUnion: x: np.ndarray - y: Union[np.ndarray, Array] + y: np.ndarray | Array dataclass_array_container(ArrayContainerWithUnion) @@ -158,7 +159,7 @@ class ArrayContainerWithUnionAlt: @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray - y: Union[np.ndarray, float] + y: np.ndarray | float with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): # NOTE: float is not an ArrayContainer, so y should fail @@ -217,9 +218,15 @@ class SomeOtherContainer: extent: float rng = np.random.default_rng(seed=42) - a = ArrayWrapper(ary=rng.random(10)) - d = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) - c = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + a = ArrayWrapper(ary=cast(Array, rng.random(10))) + d = SomeContainer( + points=cast(Array, rng.random((2, 10))), + radius=rng.random(), + centers=a) + c = SomeContainer( + points=cast(Array, rng.random((2, 10))), + radius=rng.random(), + centers=a) ary = SomeOtherContainer( disk=d, circle=c, has_disk=True, From aca58996949b4ccec06ded5108f955a023c95fd0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 10:56:29 -0600 Subject: [PATCH 2/9] is_array_container: Any -> object --- arraycontext/container/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 6c4fb671..75eee2a6 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -81,7 +81,7 @@ from collections.abc import Hashable, Sequence from functools import singledispatch -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. @@ -219,7 +219,7 @@ def is_array_container_type(cls: type) -> bool: is not serialize_container.__wrapped__)) # type:ignore[attr-defined] -def is_array_container(ary: Any) -> bool: +def is_array_container(ary: object) -> bool: """ :returns: *True* if the instance *ary* has a registered implementation of :func:`serialize_container`. From 65dd08e8e9571050e988e41bde21566c123d1a04 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 10:58:35 -0600 Subject: [PATCH 3/9] Array: require some basic arithmetic --- arraycontext/context.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/arraycontext/context.py b/arraycontext/context.py index f6dc70bf..60277970 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -167,6 +167,7 @@ from warnings import warn import numpy as np +from typing_extensions import Self from pytools import memoize_method from pytools.tag import ToTagSetConvertible @@ -196,6 +197,8 @@ class Array(Protocol): .. attribute:: size .. attribute:: dtype .. attribute:: __getitem__ + + In addition, arrays are expected to support basic arithmetic. """ @property @@ -217,8 +220,21 @@ def dtype(self) -> np.dtype[Any]: def __getitem__(self, index: Any) -> Array: ... + # some basic arithmetic that's supposed to work + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... + def __add__(self, other: Self | ScalarLike) -> Self: ... + def __radd__(self, other: Self | ScalarLike) -> Self: ... + def __sub__(self, other: Self | ScalarLike) -> Self: ... + def __rsub__(self, other: Self | ScalarLike) -> Self: ... + def __mul__(self, other: Self | ScalarLike) -> Self: ... + def __rmul__(self, other: Self | ScalarLike) -> Self: ... + def __truediv__(self, other: Self | ScalarLike) -> Self: ... + def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... + # deprecated, use ScalarLike instead +ScalarLike: TypeAlias = int | float | complex | np.generic Scalar = ScalarLike From f949702a70767fee883df40c9d712c35cf75e6f2 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 12:59:36 -0600 Subject: [PATCH 4/9] dataclass_array_container: support string annotations --- arraycontext/container/dataclass.py | 55 +++++-- pyproject.toml | 4 + test/test_arraycontext.py | 182 +---------------------- test/test_utils.py | 22 +-- test/testlib.py | 216 ++++++++++++++++++++++++++++ 5 files changed, 270 insertions(+), 209 deletions(-) create mode 100644 test/testlib.py diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 9495b344..5ff9dfd8 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,6 +31,7 @@ THE SOFTWARE. """ +from collections.abc import Mapping, Sequence from dataclasses import Field, fields, is_dataclass from typing import Union, get_args, get_origin @@ -58,13 +59,21 @@ def dataclass_array_container(cls: type) -> type: * a :class:`typing.Union` of array containers is considered an array container. * other type annotations, e.g. :class:`typing.Optional`, are not considered array containers, even if they wrap one. + + .. note:: + + When type annotations are strings (e.g. because of + ``from __future__ import annotations``), + this function relies on :func:`inspect.get_annotations` + (with ``eval_str=True``) to obtain type annotations. This + means that *cls* must live in a module that is importable. """ from types import GenericAlias, UnionType assert is_dataclass(cls) - def is_array_field(f: Field) -> bool: + def is_array_field(f: Field, field_type: type) -> bool: # NOTE: unions of array containers are treated separately to handle # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as # they can work seamlessly with arithmetic and traversal. @@ -77,17 +86,17 @@ def is_array_field(f: Field) -> bool: # # This is not set in stone, but mostly driven by current usage! - origin = get_origin(f.type) + origin = get_origin(field_type) # NOTE: `UnionType` is returned when using `Type1 | Type2` if origin in (Union, UnionType): - if all(is_array_type(arg) for arg in get_args(f.type)): + if all(is_array_type(arg) for arg in get_args(field_type)): return True else: raise TypeError( f"Field '{f.name}' union contains non-array container " "arguments. All arguments must be array containers.") - if isinstance(f.type, str): + if isinstance(field_type, str): raise TypeError( f"String annotation on field '{f.name}' not supported. " "(this may be due to 'from __future__ import annotations')") @@ -105,33 +114,49 @@ def is_array_field(f: Field) -> bool: _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm): + if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " - f"'{f.type!r}'") + f"'{field_type!r}'") - if not isinstance(f.type, type): + if not isinstance(field_type, type): raise TypeError( f"Field '{f.name}' not an instance of 'type': " - f"'{f.type!r}'") + f"'{field_type!r}'") + + return is_array_type(field_type) + + from inspect import get_annotations - return is_array_type(f.type) + array_fields: list[Field] = [] + non_array_fields: list[Field] = [] + cls_ann: Mapping[str, type] | None = None + for field in fields(cls): + field_type_or_str = field.type + if isinstance(field_type_or_str, str): + if cls_ann is None: + cls_ann = get_annotations(cls, eval_str=True) + field_type = cls_ann[field.name] + else: + field_type = field_type_or_str - from pytools import partition - array_fields, non_array_fields = partition(is_array_field, fields(cls)) + if is_array_field(field, field_type): + array_fields.append(field) + else: + non_array_fields.append(field) if not array_fields: raise ValueError(f"'{cls}' must have fields with array container type " "in order to use the 'dataclass_array_container' decorator") - return inject_dataclass_serialization(cls, array_fields, non_array_fields) + return _inject_dataclass_serialization(cls, array_fields, non_array_fields) -def inject_dataclass_serialization( +def _inject_dataclass_serialization( cls: type, - array_fields: tuple[Field, ...], - non_array_fields: tuple[Field, ...]) -> type: + array_fields: Sequence[Field], + non_array_fields: Sequence[Field]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. diff --git a/pyproject.toml b/pyproject.toml index a9c1df48..d715981a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,12 +97,16 @@ known-first-party = [ ] known-local-folder = [ "arraycontext", + "testlib", ] lines-after-imports = 2 required-imports = ["from __future__ import annotations"] [tool.ruff.lint.per-file-ignores] "doc/conf.py" = ["I002"] +# To avoid a requirement of array container definitions being someplace importable +# from @dataclass_array_container. +"test/test_utils.py" = ["I002"] [tool.mypy] python_version = "3.10" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 050bfc8d..ab263304 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,18 +34,14 @@ from pytools.tag import Tag from arraycontext import ( - ArrayContainer, - ArrayContext, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, - deserialize_container, pytest_generate_tests_for_array_contexts, serialize_container, tag_axes, - with_array_context, with_container_arithmetic, ) from arraycontext.pytest import ( @@ -55,6 +51,7 @@ _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, ) +from testlib import DOFArray, MyContainer, MyContainerDOFBcast, Velocity2D logger = logging.getLogger(__name__) @@ -116,147 +113,10 @@ def _acf(): # }}} -# {{{ stand-in DOFArray implementation - -@with_container_arithmetic( - bcasts_across_obj_array=True, - bitwise=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -class DOFArray: - def __init__(self, actx, data): - if not (actx is None or isinstance(actx, ArrayContext)): - raise TypeError("actx must be of type ArrayContext") - - if not isinstance(data, tuple): - raise TypeError("'data' argument must be a tuple") - - self.array_context = actx - self.data = data - - # prevent numpy broadcasting - __array_ufunc__ = None - - def __bool__(self): - if len(self) == 1 and self.data[0].size == 1: - return bool(self.data[0]) - - raise ValueError( - "The truth value of an array with more than one element is " - "ambiguous. Use actx.np.any(x) or actx.np.all(x)") - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - def __repr__(self): - return f"DOFArray({self.data!r})" - - @classmethod - def _serialize_init_arrays_code(cls, instance_name): - return {"_": - (f"{instance_name}_i", f"{instance_name}")} - - @classmethod - def _deserialize_init_arrays_code(cls, template_instance_name, args): - (_, arg), = args.items() - # Why tuple([...])? https://stackoverflow.com/a/48592299 - return (f"{template_instance_name}.array_context, tuple([{arg}])") - - @property - def size(self): - return sum(ary.size for ary in self.data) - - @property - def real(self): - return DOFArray(self.array_context, tuple(subary.real for subary in self)) - - @property - def imag(self): - return DOFArray(self.array_context, tuple(subary.imag for subary in self)) - - -@serialize_container.register(DOFArray) -def _serialize_dof_container(ary: DOFArray): - return list(enumerate(ary.data)) - - -@deserialize_container.register(DOFArray) -# https://github.com/python/mypy/issues/13040 -def _deserialize_dof_container( # type: ignore[misc] - template, iterable): - def _raise_index_inconsistency(i, stream_i): - raise ValueError( - "out-of-sequence indices supplied in DOFArray deserialization " - f"(expected {i}, received {stream_i})") - - return type(template)( - template.array_context, - data=tuple( - v if i == stream_i else _raise_index_inconsistency(i, stream_i) - for i, (stream_i, v) in enumerate(iterable))) - - -@with_array_context.register(DOFArray) -# https://github.com/python/mypy/issues/13040 -def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: ignore[misc] - return type(ary)(actx, ary.data) - -# }}} - - -# {{{ nested containers - -@with_container_arithmetic(bcasts_across_obj_array=False, - eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class MyContainer: - name: str - mass: DOFArray | np.ndarray - momentum: np.ndarray - enthalpy: DOFArray | np.ndarray - - __array_ufunc__ = None - - @property - def array_context(self): - if isinstance(self.mass, np.ndarray): - return next(iter(self.mass)).array_context - else: - return self.mass.array_context - - -@with_container_arithmetic( - bcasts_across_obj_array=False, - bcast_container_types=(DOFArray, np.ndarray), - matmul=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class MyContainerDOFBcast: - name: str - mass: DOFArray | np.ndarray - momentum: np.ndarray - enthalpy: DOFArray | np.ndarray - - @property - def array_context(self): - if isinstance(self.mass, np.ndarray): - return next(iter(self.mass)).array_context - else: - return self.mass.array_context - - def _get_test_containers(actx, ambient_dim=2, shapes=50_000): from numbers import Number + + from testlib import DOFArray, MyContainer, MyContainerDOFBcast if isinstance(shapes, Number | tuple): shapes = [shapes] @@ -286,8 +146,6 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs, bcast_dataclass_of_dofs) -# }}} - # {{{ assert_close_to_numpy* @@ -1224,21 +1082,6 @@ def test_norm_ord_none(actx_factory, ndim): # {{{ test_actx_compile helpers -@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) -@dataclass_array_container -@dataclass(frozen=True) -class Velocity2D: - u: ArrayContainer - v: ArrayContainer - array_context: ArrayContext - - -@with_array_context.register(Velocity2D) -# https://github.com/python/mypy/issues/13040 -def _with_actx_velocity_2d(ary, actx): # type: ignore[misc] - return type(ary)(ary.u, ary.v, actx) - - def scale_and_orthogonalize(alpha, vel): from arraycontext import rec_map_array_container actx = vel.array_context @@ -1353,25 +1196,8 @@ def test_container_equality(actx_factory): # {{{ test_no_leaf_array_type_broadcasting -@with_container_arithmetic( - bcasts_across_obj_array=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class Foo: - u: DOFArray - - # prevent numpy arithmetic from taking precedence - __array_ufunc__ = None - - @property - def array_context(self): - return self.u.array_context - - def test_no_leaf_array_type_broadcasting(actx_factory): + from testlib import Foo # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() diff --git a/test/test_utils.py b/test/test_utils.py index a6aa2714..807d652d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,7 +1,8 @@ """Testing for internal utilities.""" -from __future__ import annotations -from typing import cast +# Do not add +# from __future__ import annotations +# to allow the non-string annotations below to work. __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" @@ -26,6 +27,7 @@ THE SOFTWARE. """ import logging +from typing import Optional, cast import numpy as np import pytest @@ -55,25 +57,13 @@ def test_dataclass_array_container() -> None: from arraycontext import Array, dataclass_array_container - # {{{ string fields - - @dataclass - class ArrayContainerWithStringTypes: - x: np.ndarray - y: np.ndarray - - with pytest.raises(TypeError, match="String annotation on field 'y'"): - # NOTE: cannot have string annotations in container - dataclass_array_container(ArrayContainerWithStringTypes) - - # }}} - # {{{ optional fields @dataclass class ArrayContainerWithOptional: x: np.ndarray - y: np.ndarray | None + # Deliberately left as Optional to test compatibility. + y: Optional[np.ndarray] # noqa: UP007 with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) diff --git a/test/testlib.py b/test/testlib.py new file mode 100644 index 00000000..3f085207 --- /dev/null +++ b/test/testlib.py @@ -0,0 +1,216 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2020-21 University of Illinois Board of Trustees" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from dataclasses import dataclass + +import numpy as np + +from arraycontext import ( + ArrayContainer, + ArrayContext, + dataclass_array_container, + deserialize_container, + serialize_container, + with_array_context, + with_container_arithmetic, +) + + +# Containers live here, because in order for get_annotations to work, they must +# live somewhere importable. +# See https://docs.python.org/3.12/library/inspect.html#inspect.get_annotations + + +# {{{ stand-in DOFArray implementation + +@with_container_arithmetic( + bcasts_across_obj_array=True, + bitwise=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +class DOFArray: + def __init__(self, actx, data): + if not (actx is None or isinstance(actx, ArrayContext)): + raise TypeError("actx must be of type ArrayContext") + + if not isinstance(data, tuple): + raise TypeError("'data' argument must be a tuple") + + self.array_context = actx + self.data = data + + # prevent numpy broadcasting + __array_ufunc__ = None + + def __bool__(self): + if len(self) == 1 and self.data[0].size == 1: + return bool(self.data[0]) + + raise ValueError( + "The truth value of an array with more than one element is " + "ambiguous. Use actx.np.any(x) or actx.np.all(x)") + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def __repr__(self): + return f"DOFArray({self.data!r})" + + @classmethod + def _serialize_init_arrays_code(cls, instance_name): + return {"_": + (f"{instance_name}_i", f"{instance_name}")} + + @classmethod + def _deserialize_init_arrays_code(cls, template_instance_name, args): + (_, arg), = args.items() + # Why tuple([...])? https://stackoverflow.com/a/48592299 + return (f"{template_instance_name}.array_context, tuple([{arg}])") + + @property + def size(self): + return sum(ary.size for ary in self.data) + + @property + def real(self): + return DOFArray(self.array_context, tuple(subary.real for subary in self)) + + @property + def imag(self): + return DOFArray(self.array_context, tuple(subary.imag for subary in self)) + + +@serialize_container.register(DOFArray) +def _serialize_dof_container(ary: DOFArray): + return list(enumerate(ary.data)) + + +@deserialize_container.register(DOFArray) +# https://github.com/python/mypy/issues/13040 +def _deserialize_dof_container( # type: ignore[misc] + template, iterable): + def _raise_index_inconsistency(i, stream_i): + raise ValueError( + "out-of-sequence indices supplied in DOFArray deserialization " + f"(expected {i}, received {stream_i})") + + return type(template)( + template.array_context, + data=tuple( + v if i == stream_i else _raise_index_inconsistency(i, stream_i) + for i, (stream_i, v) in enumerate(iterable))) + + +@with_array_context.register(DOFArray) +# https://github.com/python/mypy/issues/13040 +def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: ignore[misc] + return type(ary)(actx, ary.data) + +# }}} + + +# {{{ nested containers + +@with_container_arithmetic(bcasts_across_obj_array=False, + eq_comparison=False, rel_comparison=False, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainer: + name: str + mass: DOFArray | np.ndarray + momentum: np.ndarray + enthalpy: DOFArray | np.ndarray + + __array_ufunc__ = None + + @property + def array_context(self): + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context + + +@with_container_arithmetic( + bcasts_across_obj_array=False, + bcast_container_types=(DOFArray, np.ndarray), + matmul=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainerDOFBcast: + name: str + mass: DOFArray | np.ndarray + momentum: np.ndarray + enthalpy: DOFArray | np.ndarray + + @property + def array_context(self): + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context + +# }}} + + +@with_container_arithmetic( + bcasts_across_obj_array=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class Foo: + u: DOFArray + + # prevent numpy arithmetic from taking precedence + __array_ufunc__ = None + + @property + def array_context(self): + return self.u.array_context + + +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class Velocity2D: + u: ArrayContainer + v: ArrayContainer + array_context: ArrayContext + + +@with_array_context.register(Velocity2D) +# https://github.com/python/mypy/issues/13040 +def _with_actx_velocity_2d(ary, actx): # type: ignore[misc] + return type(ary)(ary.u, ary.v, actx) From 7dad0ceb0f35e9683b446250251f4c7777a0dbe2 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 13:04:34 -0600 Subject: [PATCH 5/9] Introduce ArithArrayContainer --- arraycontext/__init__.py | 10 ++++++++ arraycontext/container/__init__.py | 27 +++++++++++++++++++- arraycontext/context.py | 40 ++++++++++++++---------------- doc/conf.py | 5 ++++ pyproject.toml | 2 ++ 5 files changed, 62 insertions(+), 22 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1c2ae451..674a229d 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -30,6 +30,7 @@ """ from .container import ( + ArithArrayContainer, ArrayContainer, ArrayContainerT, NotAnArrayContainerError, @@ -73,6 +74,10 @@ from .context import ( Array, ArrayContext, + ArrayOrArithContainer, + ArrayOrArithContainerOrScalar, + ArrayOrArithContainerOrScalarT, + ArrayOrArithContainerT, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, @@ -96,10 +101,15 @@ __all__ = ( + "ArithArrayContainer", "Array", "ArrayContainer", "ArrayContainerT", "ArrayContext", + "ArrayOrArithContainer", + "ArrayOrArithContainerOrScalar", + "ArrayOrArithContainerOrScalarT", + "ArrayOrArithContainerT", "ArrayOrContainer", "ArrayOrContainerOrScalar", "ArrayOrContainerOrScalarT", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 75eee2a6..afe4a406 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -4,6 +4,7 @@ .. currentmodule:: arraycontext .. autoclass:: ArrayContainer +.. autoclass:: ArithArrayContainer .. class:: ArrayContainerT A type variable with a lower bound of :class:`ArrayContainer`. @@ -87,8 +88,9 @@ # what 'np' is. import numpy import numpy as np +from typing_extensions import Self -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, ArrayOrScalar if TYPE_CHECKING: @@ -145,6 +147,29 @@ class ArrayContainer(Protocol): # that are container-typed. +class ArithArrayContainer(ArrayContainer, Protocol): + """ + A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic. + """ + + # This is loose and permissive, assuming that any array can be added + # to any container. The alternative would be to plaster type-ignores + # on all those uses. Achieving typing precision on what broadcasting is + # allowable seems like a huge endeavor and is likely not feasible without + # a mypy plugin. Maybe some day? -AK, November 2024 + + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... + def __add__(self, other: ArrayOrScalar | Self) -> Self: ... + def __radd__(self, other: ArrayOrScalar | Self) -> Self: ... + def __sub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... + + ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) diff --git a/arraycontext/context.py b/arraycontext/context.py index 60277970..0d0595c3 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -87,37 +87,30 @@ .. autoclass:: Array -.. class:: ArrayT +.. autodata:: ArrayT A type variable with a lower bound of :class:`Array`. -.. class:: ScalarLike +.. autodata:: ScalarLike A type annotation for scalar types commonly usable with arrays. See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`. -.. class:: ArrayOrContainer +.. autodata:: ArrayOrContainer -.. class:: ArrayOrContainerT +.. autodata:: ArrayOrContainerT A type variable with a lower bound of :class:`ArrayOrContainer`. -.. class:: ArrayOrContainerOrScalar +.. autodata:: ArrayOrContainerOrScalar -.. class:: ArrayOrContainerOrScalarT +.. autodata:: ArrayOrContainerOrScalarT A type variable with a lower bound of :class:`ArrayOrContainerOrScalar`. -Internal typing helpers (do not import) ---------------------------------------- - .. currentmodule:: arraycontext.context -This is only here because the documentation tool wants it. - -.. class:: SelfType - Canonical locations for type annotations ---------------------------------------- @@ -176,16 +169,11 @@ if TYPE_CHECKING: import loopy - from arraycontext.container import ArrayContainer + from arraycontext.container import ArithArrayContainer, ArrayContainer # {{{ typing -ScalarLike = int | float | complex | np.generic - -SelfType = TypeVar("SelfType") - - class Array(Protocol): """A :class:`~typing.Protocol` for the array type supported by :class:`ArrayContext`. @@ -236,16 +224,26 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... # deprecated, use ScalarLike instead ScalarLike: TypeAlias = int | float | complex | np.generic Scalar = ScalarLike - +ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) ArrayT = TypeVar("ArrayT", bound=Array) ArrayOrScalar: TypeAlias = "Array | ScalarLike" ArrayOrContainer: TypeAlias = "Array | ArrayContainer" +ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) +ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer) ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" +ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike" ArrayOrContainerOrScalarT = TypeVar( "ArrayOrContainerOrScalarT", bound=ArrayOrContainerOrScalar) +ArrayOrArithContainerOrScalarT = TypeVar( + "ArrayOrArithContainerOrScalarT", + bound=ArrayOrContainerOrScalar) + + +ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") + NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike] @@ -494,7 +492,7 @@ def einsum(self, return self.tag(tagged, out_ary) @abstractmethod - def clone(self: SelfType) -> SelfType: + def clone(self) -> Self: """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible. diff --git a/doc/conf.py b/doc/conf.py index 0ba49301..0042ae57 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,3 +38,8 @@ sys._BUILDING_SPHINX_DOCS = True + + +nitpick_ignore_regex = [ + ["py:class", r"arraycontext\.context\.ContainerOrScalarT"], + ] diff --git a/pyproject.toml b/pyproject.toml index d715981a..2e515865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + # for Self + "typing_extensions>=4", ] [project.optional-dependencies] From 3729b10c222877ff61b16c8312937c442ce2260a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 13:04:58 -0600 Subject: [PATCH 6/9] {to,from}_numpy: Use overloads for more precise type info --- arraycontext/context.py | 18 +++++++++++++++++- arraycontext/impl/numpy/__init__.py | 27 ++++++++++++++++++++++----- arraycontext/impl/pytato/utils.py | 5 +---- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 0d0595c3..5c7651c2 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -156,7 +156,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload from warnings import warn import numpy as np @@ -320,6 +320,14 @@ def zeros(self, return self.np.zeros(shape, dtype) + @overload + def from_numpy(self, array: np.ndarray) -> Array: + ... + + @overload + def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + @abstractmethod def from_numpy(self, array: NumpyOrContainerOrScalar @@ -333,6 +341,14 @@ def from_numpy(self, intact. """ + @overload + def to_numpy(self, array: Array) -> np.ndarray: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + @abstractmethod def to_numpy(self, array: ArrayOrContainerOrScalar diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index c2f884a6..f9d6c541 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,7 +1,4 @@ -from __future__ import annotations - - -__doc__ = """ +""" .. currentmodule:: arraycontext A :mod:`numpy`-based array context. @@ -9,6 +6,9 @@ .. autoclass:: NumpyArrayContext """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ @@ -33,7 +33,7 @@ THE SOFTWARE. """ -from typing import Any +from typing import Any, overload import numpy as np @@ -46,6 +46,7 @@ ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, + ContainerOrScalarT, NumpyOrContainerOrScalar, UntransformedCodeWarning, ) @@ -84,11 +85,27 @@ def _get_fake_numpy_namespace(self): def clone(self): return type(self)() + @overload + def from_numpy(self, array: np.ndarray) -> Array: + ... + + @overload + def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + def from_numpy(self, array: NumpyOrContainerOrScalar ) -> ArrayOrContainerOrScalar: return array + @overload + def to_numpy(self, array: Array) -> np.ndarray: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + def to_numpy(self, array: ArrayOrContainerOrScalar ) -> NumpyOrContainerOrScalar: diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index c031e29b..6441527a 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -163,10 +163,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: # https://github.com/pylint-dev/pylint/issues/3893 # pylint: disable=unexpected-keyword-arg - # type-ignore: discussed at - # https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967 - # possibly related: https://github.com/python/mypy/issues/17375 - return DataWrapper( # type: ignore[call-arg] + return DataWrapper( data=new_dw.data, shape=expr.shape, axes=expr.axes, From 24f516d504a0806dcc7a1859893fa47197cb8fc3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 15:40:13 -0600 Subject: [PATCH 7/9] Introduce "*Tc" constrained versions of the ArrayOr* type variables --- arraycontext/context.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 5c7651c2..98bab8d7 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -226,12 +226,25 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... Scalar = ScalarLike ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) +# NOTE: I'm kind of not sure about the *Tc versions of these type variables. +# mypy seems better at understanding arithmetic performed on the *Tc versions +# than the *T, versions, whereas pyright doesn't seem to care. +# +# This issue seems to be part of it: +# https://github.com/python/mypy/issues/18203 +# but there is likely other stuff lurking. +# +# For now, they're purposefully not in the main arraycontext.* name space. ArrayT = TypeVar("ArrayT", bound=Array) ArrayOrScalar: TypeAlias = "Array | ScalarLike" ArrayOrContainer: TypeAlias = "Array | ArrayContainer" ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) +ArrayOrContainerTc = TypeVar("ArrayOrContainerTc", + Array, "ArrayContainer", "ArithArrayContainer") ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer) +ArrayOrArithContainerTc = TypeVar("ArrayOrArithContainerTc", + Array, "ArithArrayContainer") ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike" ArrayOrContainerOrScalarT = TypeVar( @@ -239,7 +252,13 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... bound=ArrayOrContainerOrScalar) ArrayOrArithContainerOrScalarT = TypeVar( "ArrayOrArithContainerOrScalarT", - bound=ArrayOrContainerOrScalar) + bound=ArrayOrArithContainerOrScalar) +ArrayOrContainerOrScalarTc = TypeVar( + "ArrayOrContainerOrScalarTc", + ScalarLike, Array, "ArrayContainer", "ArithArrayContainer") +ArrayOrArithContainerOrScalarTc = TypeVar( + "ArrayOrArithContainerOrScalarTc", + ScalarLike, Array, "ArithArrayContainer") ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") From 3ce0c5731a622539c90436dddb68de669c31907a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Nov 2024 15:22:20 -0600 Subject: [PATCH 8/9] Fix a name in TransferFromNumpyMapper --- arraycontext/impl/pytato/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 6441527a..2457e297 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -158,13 +158,13 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: # Ideally, this code should just do # return self.actx.from_numpy(expr.data).tagged(expr.tags), # but there seems to be no way to transfer the non_equality_tags in that case. - new_dw = self.actx.from_numpy(expr.data) - assert isinstance(new_dw, DataWrapper) + actx_ary = self.actx.from_numpy(expr.data) + assert isinstance(actx_ary, DataWrapper) # https://github.com/pylint-dev/pylint/issues/3893 # pylint: disable=unexpected-keyword-arg return DataWrapper( - data=new_dw.data, + data=actx_ary.data, shape=expr.shape, axes=expr.axes, tags=expr.tags, From 5e8afadbda52937f56df68f70040397d809907de Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Nov 2024 15:26:05 -0600 Subject: [PATCH 9/9] Github: Limit PR CI concurrency --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd011830..28923da7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,6 +7,10 @@ on: schedule: - cron: '17 3 * * 0' +concurrency: + group: ${{ github.head_ref || github.ref_name }} + cancel-in-progress: true + jobs: typos: name: Typos