diff --git a/grudge/reductions.py b/grudge/reductions.py index c719be712..4d38e4178 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -57,12 +57,12 @@ """ -from functools import reduce, partial +from functools import partial from arraycontext import ( make_loopy_program, map_array_container, - serialize_container, + get_container_context_recursively, DeviceScalar ) from arraycontext.container import ArrayOrContainerT @@ -94,7 +94,6 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar": if dd is None: dd = dof_desc.DD_VOLUME - from arraycontext import get_container_context_recursively actx = get_container_context_recursively(vec) dd = dof_desc.as_dofdesc(dd) @@ -128,7 +127,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": # NOTE: Don't move this from mpi4py import MPI - actx = vec.array_context + actx = get_container_context_recursively(vec) return actx.from_numpy( comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM)) @@ -143,15 +142,13 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :class:`~arraycontext.container.ArrayContainer` of them. :returns: a scalar denoting the rank-local nodal sum. """ - if not isinstance(vec, DOFArray): - return sum( - nodal_sum_loc(dcoll, dd, comp) - for _, comp in serialize_container(vec) - ) - - actx = vec.array_context - - return sum([actx.np.sum(grp_ary) for grp_ary in vec]) + actx = get_container_context_recursively(vec) + result = actx.np.sum(vec) + # Fix actx._force_device_scalars == False case + if np.isscalar(result): + return actx.from_numpy(result) + else: + return result def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": @@ -169,7 +166,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": # NOTE: Don't move this from mpi4py import MPI - actx = vec.array_context + actx = get_container_context_recursively(vec) return actx.from_numpy( comm.allreduce(actx.to_numpy(nodal_min_loc(dcoll, dd, vec)), op=MPI.MIN)) @@ -185,17 +182,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :class:`~arraycontext.container.ArrayContainer` of them. :returns: a scalar denoting the rank-local nodal minimum. """ - if not isinstance(vec, DOFArray): - return min( - nodal_min_loc(dcoll, dd, comp) - for _, comp in serialize_container(vec) - ) - - actx = vec.array_context - - return reduce( - lambda acc, grp_ary: actx.np.minimum(acc, actx.np.min(grp_ary)), - vec, actx.from_numpy(np.array(np.inf))) + actx = get_container_context_recursively(vec) + result = actx.np.min(vec) + # Fix actx._force_device_scalars == False case + if np.isscalar(result): + return actx.from_numpy(result) + else: + return result def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": @@ -213,7 +206,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": # NOTE: Don't move this from mpi4py import MPI - actx = vec.array_context + actx = get_container_context_recursively(vec) return actx.from_numpy( comm.allreduce(actx.to_numpy(nodal_max_loc(dcoll, dd, vec)), op=MPI.MAX)) @@ -229,17 +222,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :class:`~arraycontext.container.ArrayContainer`. :returns: a scalar denoting the rank-local nodal maximum. """ - if not isinstance(vec, DOFArray): - return max( - nodal_max_loc(dcoll, dd, comp) - for _, comp in serialize_container(vec) - ) - - actx = vec.array_context - - return reduce( - lambda acc, grp_ary: actx.np.maximum(acc, actx.np.max(grp_ary)), - vec, actx.from_numpy(np.array(-np.inf))) + actx = get_container_context_recursively(vec) + result = actx.np.max(vec) + # Fix actx._force_device_scalars == False case + if np.isscalar(result): + return actx.from_numpy(result) + else: + return result def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": @@ -253,9 +242,10 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": """ from grudge.op import _apply_mass_operator + actx = get_container_context_recursively(vec) dd = dof_desc.as_dofdesc(dd) - ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0 + ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0 return nodal_sum( dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones) ) @@ -295,7 +285,7 @@ def _apply_elementwise_reduction( partial(_apply_elementwise_reduction, op_name, dcoll, dd), vec ) - actx = vec.array_context + actx = get_container_context_recursively(vec) if actx.supports_nonscalar_broadcasting: return DOFArray( @@ -456,11 +446,12 @@ def elementwise_integral( else: raise TypeError("invalid number of arguments") + actx = get_container_context_recursively(vec) dd = dof_desc.as_dofdesc(dd) from grudge.op import _apply_mass_operator - ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0 + ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0 return elementwise_sum( dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones) ) diff --git a/requirements.txt b/requirements.txt index abf6d39fa..5aa45bfea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ git+https://github.com/inducer/dagrt.git#egg=dagrt git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy git+https://github.com/inducer/modepy.git#egg=modepy -git+https://github.com/inducer/arraycontext.git#egg=arraycontext +git+https://github.com/majosm/arraycontext.git@empty-subcontainers#egg=arraycontext git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis