From 889444542abadd1d7e97e2cbebb0182f4136ec27 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 21 Sep 2022 09:49:52 -0500 Subject: [PATCH 1/7] set up connections between volumes --- grudge/discretization.py | 281 +++++++++++++++++++++++++++++---------- 1 file changed, 208 insertions(+), 73 deletions(-) diff --git a/grudge/discretization.py b/grudge/discretization.py index 8e57ca503..fd4c39728 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -7,6 +7,7 @@ .. autofunction:: make_discretization_collection .. currentmodule:: grudge.discretization +.. autoclass:: PartID """ __copyright__ = """ @@ -34,10 +35,12 @@ THE SOFTWARE. """ -from typing import Mapping, Optional, Union, TYPE_CHECKING, Any +from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any from pytools import memoize_method, single_valued +from dataclasses import dataclass, replace + from grudge.dof_desc import ( VTAG_ALL, DD_VOLUME_ALL, @@ -71,6 +74,75 @@ import mpi4py.MPI +@dataclass(frozen=True) +class PartID: + """Unique identifier for a piece of a partitioned mesh. + + .. attribute:: volume_tag + + The volume of the part. + + .. attribute:: rank + + The (optional) MPI rank of the part. + + """ + volume_tag: VolumeTag + rank: Optional[int] = None + + +# {{{ part ID normalization + +def _normalize_mesh_part_ids( + mesh: Mesh, + self_volume_tag: VolumeTag, + all_volume_tags: Sequence[VolumeTag], + mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None): + """Convert a mesh's configuration-dependent "part ID" into a fixed type.""" + from numbers import Integral + if mpi_communicator is not None: + # Accept PartID or rank (assume intra-volume for the latter) + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif isinstance(mesh_part_id, Integral): + return PartID(self_volume_tag, int(mesh_part_id)) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Accept PartID or volume tag + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif mesh_part_id in all_volume_tags: + return PartID(mesh_part_id) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + + facial_adjacency_groups = mesh.facial_adjacency_groups + + new_facial_adjacency_groups = [] + + from meshmode.mesh import InterPartAdjacencyGroup + for grp_list in facial_adjacency_groups: + new_grp_list = [] + for fagrp in grp_list: + if isinstance(fagrp, InterPartAdjacencyGroup): + part_id = as_part_id(fagrp.part_id) + new_fagrp = replace( + fagrp, + boundary_tag=BTAG_PARTITION(part_id), + part_id=part_id) + else: + new_fagrp = fagrp + new_grp_list.append(new_fagrp) + new_facial_adjacency_groups.append(new_grp_list) + + return mesh.copy(facial_adjacency_groups=new_facial_adjacency_groups) + +# }}} + + # {{{ discr_tag_to_group_factory normalization def _normalize_discr_tag_to_group_factory( @@ -156,6 +228,9 @@ def __init__(self, array_context: ArrayContext, discr_tag_to_group_factory: Optional[ Mapping[DiscretizationTag, ElementGroupFactory]] = None, mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None, + inter_part_connections: Optional[ + Mapping[Tuple[PartID, PartID], + DiscretizationConnection]] = None, ) -> None: """ :arg discr_tag_to_group_factory: A mapping from discretization tags @@ -206,6 +281,9 @@ def __init__(self, array_context: ArrayContext, mesh = volume_discrs + mesh = _normalize_mesh_part_ids( + mesh, VTAG_ALL, [VTAG_ALL], mpi_communicator=mpi_communicator) + 
discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory( dim=mesh.dim, discr_tag_to_group_factory=discr_tag_to_group_factory, @@ -219,17 +297,32 @@ def __init__(self, array_context: ArrayContext, del mesh + if inter_part_connections is not None: + raise TypeError("may not pass inter_part_connections when " + "DiscretizationCollection constructor is called in " + "legacy mode") + + self._inter_part_connections = \ + _set_up_inter_part_connections( + array_context=self._setup_actx, + mpi_communicator=mpi_communicator, + volume_discrs=volume_discrs, + base_group_factory=( + discr_tag_to_group_factory[DISCR_TAG_BASE])) + # }}} else: assert discr_tag_to_group_factory is not None self._discr_tag_to_group_factory = discr_tag_to_group_factory - self._volume_discrs = volume_discrs + if inter_part_connections is None: + raise TypeError("inter_part_connections must be passed when " + "DiscretizationCollection constructor is called in " + "'modern' mode") + + self._inter_part_connections = inter_part_connections - self._dist_boundary_connections = { - vtag: self._set_up_distributed_communication( - vtag, mpi_communicator, array_context) - for vtag in self._volume_discrs.keys()} + self._volume_discrs = volume_discrs # }}} @@ -252,71 +345,6 @@ def is_management_rank(self): return self.mpi_communicator.Get_rank() \ == self.get_management_rank_index() - # {{{ distributed - - def _set_up_distributed_communication( - self, vtag, mpi_communicator, array_context): - from_dd = DOFDesc(VolumeDomainTag(vtag), DISCR_TAG_BASE) - - boundary_connections = {} - - from meshmode.distributed import get_connected_partitions - connected_parts = get_connected_partitions(self._volume_discrs[vtag].mesh) - - if connected_parts: - if mpi_communicator is None: - raise RuntimeError("must supply an MPI communicator when using a " - "distributed mesh") - - grp_factory = \ - self.group_factory_for_discretization_tag(DISCR_TAG_BASE) - - local_boundary_connections = {} - for i_remote_part in connected_parts: - local_boundary_connections[i_remote_part] = self.connection_from_dds( - from_dd, from_dd.trace(BTAG_PARTITION(i_remote_part))) - - from meshmode.distributed import MPIBoundaryCommSetupHelper - with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, - local_boundary_connections, grp_factory) as bdry_setup_helper: - while True: - conns = bdry_setup_helper.complete_some() - if not conns: - break - for i_remote_part, conn in conns.items(): - boundary_connections[i_remote_part] = conn - - return boundary_connections - - def distributed_boundary_swap_connection(self, dd): - """Provides a mapping from the base volume discretization - to the exterior boundary restriction on a parallel boundary - partition described by *dd*. This connection is used to - communicate across element boundaries in different parallel - partitions during distributed runs. - - :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value - convertible to one. The domain tag must be a subclass - of :class:`grudge.dof_desc.BoundaryDomainTag` with an - associated :class:`meshmode.mesh.BTAG_PARTITION` - corresponding to a particular communication rank. - """ - if dd.discretization_tag is not DISCR_TAG_BASE: - # FIXME - raise NotImplementedError( - "Distributed communication with discretization tag " - f"{dd.discretization_tag} is not implemented." 
- ) - - assert isinstance(dd.domain_tag, BoundaryDomainTag) - assert isinstance(dd.domain_tag.tag, BTAG_PARTITION) - - vtag = dd.domain_tag.volume_tag - - return self._dist_boundary_connections[vtag][dd.domain_tag.tag.part_nr] - - # }}} - # {{{ discr_from_dd @memoize_method @@ -772,6 +800,105 @@ def normal(self, dd): # }}} +# {{{ distributed/multi-volume setup + +def _set_up_inter_part_connections( + array_context: ArrayContext, + mpi_communicator: Optional["mpi4py.MPI.Intracomm"], + volume_discrs: Mapping[VolumeTag, Discretization], + base_group_factory: ElementGroupFactory, + ) -> Mapping[ + Tuple[PartID, PartID], + DiscretizationConnection]: + + from meshmode.distributed import (get_connected_parts, + make_remote_group_infos, InterRankBoundaryInfo, + MPIBoundaryCommSetupHelper) + + rank = mpi_communicator.Get_rank() if mpi_communicator is not None else None + + # Save boundary restrictions as they're created to avoid potentially creating + # them twice in the loop below + cached_part_bdry_restrictions: Mapping[ + Tuple[PartID, PartID], + DiscretizationConnection] = {} + + def get_part_bdry_restriction(self_part_id, other_part_id): + cached_result = cached_part_bdry_restrictions.get( + (self_part_id, other_part_id), None) + if cached_result is not None: + return cached_result + return cached_part_bdry_restrictions.setdefault( + (self_part_id, other_part_id), + make_face_restriction( + array_context, volume_discrs[self_part_id.volume_tag], + base_group_factory, + boundary_tag=BTAG_PARTITION(other_part_id))) + + inter_part_conns: Mapping[ + Tuple[PartID, PartID], + DiscretizationConnection] = {} + + irbis = [] + + for vtag, volume_discr in volume_discrs.items(): + part_id = PartID(vtag, rank) + connected_part_ids = get_connected_parts(volume_discr.mesh) + for connected_part_id in connected_part_ids: + bdry_restr = get_part_bdry_restriction( + self_part_id=part_id, other_part_id=connected_part_id) + + if connected_part_id.rank == rank: + # {{{ rank-local interface between multiple volumes + + connected_bdry_restr = get_part_bdry_restriction( + self_part_id=connected_part_id, other_part_id=part_id) + + from meshmode.discretization.connection import \ + make_partition_connection + inter_part_conns[connected_part_id, part_id] = \ + make_partition_connection( + array_context, + local_bdry_conn=bdry_restr, + remote_bdry_discr=connected_bdry_restr.to_discr, + remote_group_infos=make_remote_group_infos( + array_context, part_id, connected_bdry_restr)) + + # }}} + else: + # {{{ cross-rank interface + + if mpi_communicator is None: + raise RuntimeError("must supply an MPI communicator " + "when using a distributed mesh") + + irbis.append( + InterRankBoundaryInfo( + local_part_id=part_id, + remote_part_id=connected_part_id, + remote_rank=connected_part_id.rank, + local_boundary_connection=bdry_restr)) + + # }}} + + if irbis: + assert mpi_communicator is not None + + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + irbis, base_group_factory) as bdry_setup_helper: + while True: + conns = bdry_setup_helper.complete_some() + if not conns: + # We're done. 
+ break + + inter_part_conns.update(conns) + + return inter_part_conns + +# }}} + + # {{{ modal group factory def _generate_modal_group_factory(nodal_group_factory): @@ -860,6 +987,8 @@ def make_discretization_collection( del order + mpi_communicator = getattr(array_context, "mpi_communicator", None) + if any( isinstance(mesh_or_discr, Discretization) for mesh_or_discr in volumes.values()): @@ -868,14 +997,20 @@ def make_discretization_collection( volume_discrs = { vtag: Discretization( array_context, - mesh, + _normalize_mesh_part_ids( + mesh, vtag, volumes.keys(), mpi_communicator=mpi_communicator), discr_tag_to_group_factory[DISCR_TAG_BASE]) for vtag, mesh in volumes.items()} return DiscretizationCollection( array_context=array_context, volume_discrs=volume_discrs, - discr_tag_to_group_factory=discr_tag_to_group_factory) + discr_tag_to_group_factory=discr_tag_to_group_factory, + inter_part_connections=_set_up_inter_part_connections( + array_context=array_context, + mpi_communicator=mpi_communicator, + volume_discrs=volume_discrs, + base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE])) # }}} From deb2e46e006767d980318cda15a98e13fbd2b784 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 21 Sep 2022 09:51:54 -0500 Subject: [PATCH 2/7] add inter-volume communication --- grudge/eager.py | 3 +- grudge/op.py | 10 +- grudge/trace_pair.py | 627 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 522 insertions(+), 118 deletions(-) diff --git a/grudge/eager.py b/grudge/eager.py index 626e15592..08cf08f2a 100644 --- a/grudge/eager.py +++ b/grudge/eager.py @@ -87,7 +87,8 @@ def nodal_max(self, dd, vec): return op.nodal_max(self, dd, vec) -connected_ranks = op.connected_ranks +# FIXME: Deprecate connected_ranks instead of removing +connected_parts = op.connected_parts interior_trace_pair = op.interior_trace_pair cross_rank_trace_pairs = op.cross_rank_trace_pairs diff --git a/grudge/op.py b/grudge/op.py index f5781f4be..a6cef8ffa 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -118,8 +118,11 @@ interior_trace_pair, interior_trace_pairs, local_interior_trace_pair, - connected_ranks, + connected_parts, + inter_volume_trace_pairs, + local_inter_volume_trace_pairs, cross_rank_trace_pairs, + cross_rank_inter_volume_trace_pairs, bdry_trace_pair, bv_trace_pair ) @@ -147,8 +150,11 @@ "interior_trace_pair", "interior_trace_pairs", "local_interior_trace_pair", - "connected_ranks", + "connected_parts", + "inter_volume_trace_pairs", + "local_inter_volume_trace_pairs", "cross_rank_trace_pairs", + "cross_rank_inter_volume_trace_pairs", "bdry_trace_pair", "bv_trace_pair", diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 1f49ae0d6..0b0400f12 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -18,12 +18,15 @@ .. autofunction:: bdry_trace_pair .. autofunction:: bv_trace_pair -Interior and cross-rank trace functions ---------------------------------------- +Interior, cross-rank, and inter-volume traces +--------------------------------------------- .. autofunction:: interior_trace_pairs .. autofunction:: local_interior_trace_pair +.. autofunction:: inter_volume_trace_pairs +.. autofunction:: local_inter_volume_trace_pairs .. autofunction:: cross_rank_trace_pairs +.. 
autofunction:: cross_rank_inter_volume_trace_pairs """ __copyright__ = """ @@ -52,17 +55,18 @@ from warnings import warn -from typing import List, Hashable, Optional, Type, Any +from typing import List, Hashable, Optional, Tuple, Type, Any, Sequence, Mapping from pytools.persistent_dict import KeyBuilder from arraycontext import ( ArrayContainer, + ArrayContext, with_container_arithmetic, dataclass_array_container, - get_container_context_recursively, - flatten, to_numpy, - unflatten, from_numpy, + get_container_context_recursively_opt, + to_numpy, + from_numpy, ArrayOrContainer ) @@ -72,7 +76,7 @@ from pytools import memoize_on_first_arg -from grudge.discretization import DiscretizationCollection +from grudge.discretization import DiscretizationCollection, PartID from grudge.projection import project from meshmode.mesh import BTAG_PARTITION @@ -82,7 +86,7 @@ import grudge.dof_desc as dof_desc from grudge.dof_desc import ( DOFDesc, DD_VOLUME_ALL, FACE_RESTR_INTERIOR, DISCR_TAG_BASE, - VolumeDomainTag, + VolumeTag, VolumeDomainTag, BoundaryDomainTag, ConvertibleToDOFDesc, ) @@ -360,6 +364,124 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *, # }}} +# {{{ inter-volume trace pairs + +def local_inter_volume_trace_pairs( + dcoll: DiscretizationCollection, + pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]] + ) -> Mapping[Tuple[DOFDesc, DOFDesc], TracePair]: + for vol_dd_pair in pairwise_volume_data.keys(): + for vol_dd in vol_dd_pair: + if not isinstance(vol_dd.domain_tag, VolumeDomainTag): + raise ValueError( + "pairwise_volume_data keys must describe volumes, " + f"got '{vol_dd}'") + if vol_dd.discretization_tag != DISCR_TAG_BASE: + raise ValueError( + "expected base-discretized DOFDesc in pairwise_volume_data, " + f"got '{vol_dd}'") + + rank = ( + dcoll.mpi_communicator.Get_rank() + if dcoll.mpi_communicator is not None + else None) + + result: Mapping[Tuple[DOFDesc, DOFDesc], TracePair] = {} + + for vol_dd_pair, vol_data_pair in pairwise_volume_data.items(): + from meshmode.mesh import mesh_has_boundary + if not mesh_has_boundary( + dcoll.discr_from_dd(vol_dd_pair[0]).mesh, + BTAG_PARTITION(PartID(vol_dd_pair[1].domain_tag.tag, rank))): + continue + + directional_vol_dd_pairs = [ + (vol_dd_pair[1], vol_dd_pair[0]), + (vol_dd_pair[0], vol_dd_pair[1])] + + trace_dd_pair = tuple( + self_vol_dd.trace( + BTAG_PARTITION( + PartID(other_vol_dd.domain_tag.tag, rank))) + for other_vol_dd, self_vol_dd in directional_vol_dd_pairs) + + # Pre-compute the projections out here to avoid doing it twice inside + # the loop below + trace_data = { + trace_dd: project(dcoll, vol_dd, trace_dd, vol_data) + for vol_dd, trace_dd, vol_data in zip( + vol_dd_pair, trace_dd_pair, vol_data_pair)} + + for other_vol_dd, self_vol_dd in directional_vol_dd_pairs: + self_part_id = PartID(self_vol_dd.domain_tag.tag, rank) + other_part_id = PartID(other_vol_dd.domain_tag.tag, rank) + + self_trace_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id)) + other_trace_dd = other_vol_dd.trace(BTAG_PARTITION(self_part_id)) + + self_trace_data = trace_data[self_trace_dd] + unswapped_other_trace_data = trace_data[other_trace_dd] + + other_to_self = dcoll._inter_part_connections[ + other_part_id, self_part_id] + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return other_to_self(ary) # noqa: B023 + + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + other_trace_data = 
rec_map_array_container( + get_opposite_trace, + unswapped_other_trace_data, + leaf_class=DOFArray) + + result[other_vol_dd, self_vol_dd] = TracePair( + self_trace_dd, + interior=self_trace_data, + exterior=other_trace_data) + + return result + + +def inter_volume_trace_pairs(dcoll: DiscretizationCollection, + pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]], + comm_tag: Hashable = None) -> Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]]: + """ + Note that :func:`local_inter_volume_trace_pairs` provides the rank-local + contributions if those are needed in isolation. Similarly, + :func:`cross_rank_inter_volume_trace_pairs` provides only the trace pairs + defined on cross-rank boundaries. + """ + # TODO documentation + + result: Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]] = {} + + local_tpairs = local_inter_volume_trace_pairs(dcoll, pairwise_volume_data) + cross_rank_tpairs = cross_rank_inter_volume_trace_pairs( + dcoll, pairwise_volume_data, comm_tag=comm_tag) + + for directional_vol_dd_pair, tpair in local_tpairs.items(): + result[directional_vol_dd_pair] = [tpair] + + for directional_vol_dd_pair, tpairs in cross_rank_tpairs.items(): + result.setdefault(directional_vol_dd_pair, []).extend(tpairs) + + return result + +# }}} + + # {{{ distributed: helper functions class _TagKeyBuilder(KeyBuilder): @@ -367,16 +489,21 @@ def update_for_type(self, key_hash, key: Type[Any]): self.rec(key_hash, (key.__module__, key.__name__, key.__name__,)) +# FIXME: Deprecate connected_ranks instead of removing @memoize_on_first_arg -def connected_ranks( +def connected_parts( dcoll: DiscretizationCollection, - volume_dd: Optional[DOFDesc] = None): - if volume_dd is None: - volume_dd = DD_VOLUME_ALL + self_volume_tag: VolumeTag, + other_volume_tag: VolumeTag + ) -> Sequence[PartID]: + result: List[PartID] = [ + connected_part_id + for connected_part_id, part_id in dcoll._inter_part_connections.keys() + if ( + part_id.volume_tag == self_volume_tag + and connected_part_id.volume_tag == other_volume_tag)] - from meshmode.distributed import get_connected_partitions - return get_connected_partitions( - dcoll._volume_discrs[volume_dd.domain_tag.tag].mesh) + return result def _sym_tag_to_num_tag(comm_tag: Optional[Hashable]) -> Optional[int]: @@ -414,24 +541,33 @@ class _RankBoundaryCommunicationEager: base_comm_tag = 1273 def __init__(self, - dcoll: DiscretizationCollection, - array_container: ArrayOrContainer, - remote_rank, comm_tag: Optional[int] = None, - volume_dd=DD_VOLUME_ALL): - actx = get_container_context_recursively(array_container) - bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank)) - - local_bdry_data = project(dcoll, volume_dd, bdry_dd, array_container) + actx: ArrayContext, + dcoll: DiscretizationCollection, + *, + local_part_id: PartID, + remote_part_id: PartID, + local_bdry_data: ArrayOrContainer, + remote_bdry_data_template: ArrayOrContainer, + comm_tag: Optional[Hashable] = None): + comm = dcoll.mpi_communicator assert comm is not None + remote_rank = remote_part_id.rank + assert remote_rank is not None + self.dcoll = dcoll self.array_context = actx - self.remote_bdry_dd = bdry_dd - self.bdry_discr = dcoll.discr_from_dd(bdry_dd) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id + self.local_bdry_dd = DOFDesc( + BoundaryDomainTag( + BTAG_PARTITION(remote_part_id), + volume_tag=local_part_id.volume_tag), + DISCR_TAG_BASE) + self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd) 
self.local_bdry_data = local_bdry_data - self.local_bdry_data_np = \ - to_numpy(flatten(self.local_bdry_data, actx), actx) + self.remote_bdry_data_template = remote_bdry_data_template self.comm_tag = self.base_comm_tag comm_tag = _sym_tag_to_num_tag(comm_tag) @@ -439,55 +575,80 @@ def __init__(self, self.comm_tag += comm_tag del comm_tag - # Here, we initialize both send and recieve operations through - # mpi4py `Request` (MPI_Request) instances for comm.Isend (MPI_Isend) - # and comm.Irecv (MPI_Irecv) respectively. These initiate non-blocking - # point-to-point communication requests and require explicit management - # via the use of wait (MPI_Wait, MPI_Waitall, MPI_Waitany, MPI_Waitsome), - # test (MPI_Test, MPI_Testall, MPI_Testany, MPI_Testsome), and cancel - # (MPI_Cancel). The rank-local data `self.local_bdry_data_np` will have its - # associated memory buffer sent across connected ranks and must not be - # modified at the Python level during this process. Completion of the - # requests is handled in :meth:`finish`. - # - # For more details on the mpi4py semantics, see: - # https://mpi4py.readthedocs.io/en/stable/overview.html#nonblocking-communications - # # NOTE: mpi4py currently (2021-11-03) holds a reference to the send # memory buffer for (i.e. `self.local_bdry_data_np`) until the send # requests is complete, however it is not clear that this is documented # behavior. We hold on to the buffer (via the instance attribute) # as well, just in case. - self.send_req = comm.Isend(self.local_bdry_data_np, - remote_rank, - tag=self.comm_tag) - self.remote_data_host_numpy = np.empty_like(self.local_bdry_data_np) - self.recv_req = comm.Irecv(self.remote_data_host_numpy, - remote_rank, - tag=self.comm_tag) + self.send_reqs = [] + self.send_data = [] + + def send_single_array(key, local_subary): + if not isinstance(local_subary, Number): + local_subary_np = to_numpy(local_subary, actx) + self.send_reqs.append( + comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag)) + self.send_data.append(local_subary_np) + return local_subary + + self.recv_reqs = [] + self.recv_data = {} + + def recv_single_array(key, remote_subary_template): + if not isinstance(remote_subary_template, Number): + remote_subary_np = np.empty( + remote_subary_template.shape, + remote_subary_template.dtype) + self.recv_reqs.append( + comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag)) + self.recv_data[key] = remote_subary_np + return remote_subary_template + + from arraycontext.container.traversal import rec_keyed_map_array_container + rec_keyed_map_array_container(send_single_array, local_bdry_data) + rec_keyed_map_array_container(recv_single_array, remote_bdry_data_template) def finish(self): - # Wait for the nonblocking receive request to complete before + from mpi4py import MPI + + # Wait for the nonblocking receive requests to complete before # accessing the data - self.recv_req.Wait() - - # Nonblocking receive is complete, we can now access the data and apply - # the boundary-swap connection - actx = self.array_context - remote_bdry_data_flat = from_numpy(self.remote_data_host_numpy, actx) - remote_bdry_data = unflatten(self.local_bdry_data, - remote_bdry_data_flat, actx) - bdry_conn = self.dcoll.distributed_boundary_swap_connection( - self.remote_bdry_dd) - swapped_remote_bdry_data = bdry_conn(remote_bdry_data) - - # Complete the nonblocking send request associated with communicating - # `self.local_bdry_data_np` - self.send_req.Wait() - - return TracePair(self.remote_bdry_dd, - 
interior=self.local_bdry_data, - exterior=swapped_remote_bdry_data) + MPI.Request.waitall(self.recv_reqs) + + def finish_single_array(key, remote_subary_template): + if isinstance(remote_subary_template, Number): + # NOTE: Assumes that the same number is passed on every rank + return remote_subary_template + else: + return from_numpy(self.recv_data[key], self.array_context) + + from arraycontext.container.traversal import rec_keyed_map_array_container + unswapped_remote_bdry_data = rec_keyed_map_array_container( + finish_single_array, self.remote_bdry_data_template) + + remote_to_local = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return remote_to_local(ary) + + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + remote_bdry_data = rec_map_array_container( + get_opposite_trace, + unswapped_remote_bdry_data, + leaf_class=DOFArray) + + # Complete the nonblocking send requests + MPI.Request.waitall(self.send_reqs) + + return TracePair( + self.local_bdry_dd, + interior=self.local_bdry_data, + exterior=remote_bdry_data) # }}} @@ -496,51 +657,112 @@ def finish(self): class _RankBoundaryCommunicationLazy: def __init__(self, - dcoll: DiscretizationCollection, - array_container: ArrayOrContainer, - remote_rank: int, comm_tag: Hashable, - volume_dd=DD_VOLUME_ALL): + actx: ArrayContext, + dcoll: DiscretizationCollection, + *, + local_part_id: PartID, + remote_part_id: PartID, + local_bdry_data: ArrayOrContainer, + remote_bdry_data_template: ArrayOrContainer, + comm_tag: Optional[Hashable] = None) -> None: + if comm_tag is None: - raise ValueError("lazy communication requires 'tag' to be supplied") + raise ValueError("lazy communication requires 'comm_tag' to be supplied") - bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank)) + remote_rank = remote_part_id.rank + assert remote_rank is not None self.dcoll = dcoll - self.array_context = get_container_context_recursively(array_container) - self.remote_bdry_dd = bdry_dd - self.bdry_discr = dcoll.discr_from_dd(self.remote_bdry_dd) - - self.local_bdry_data = project( - dcoll, volume_dd, bdry_dd, array_container) - - from pytato import make_distributed_recv, staple_distributed_send - - def communicate_single_array(key, local_bdry_ary): - ary_tag = (comm_tag, key) - return staple_distributed_send( - local_bdry_ary, dest_rank=remote_rank, comm_tag=ary_tag, - stapled_to=make_distributed_recv( + self.array_context = actx + self.local_bdry_dd = DOFDesc( + BoundaryDomainTag( + BTAG_PARTITION(remote_part_id), + volume_tag=local_part_id.volume_tag), + DISCR_TAG_BASE) + self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id + + from pytato import ( + make_distributed_recv, + make_distributed_send, + DistributedSendRefHolder) + + # TODO: This currently assumes that local_bdry_data and + # remote_bdry_data_template have the same structure. This is not true + # in general. 
Find a way to staple the sends appropriately when the number + # of recvs is not equal to the number of sends + # FIXME: Overly restrictive (just needs to be the same structure) + assert type(local_bdry_data) == type(remote_bdry_data_template) + + sends = {} + + def send_single_array(key, local_subary): + if isinstance(local_subary, Number): + return + else: + ary_tag = (comm_tag, key) + sends[key] = make_distributed_send( + local_subary, dest_rank=remote_rank, comm_tag=ary_tag) + + def recv_single_array(key, remote_subary_template): + if isinstance(remote_subary_template, Number): + # NOTE: Assumes that the same number is passed on every rank + return remote_subary_template + else: + ary_tag = (comm_tag, key) + return DistributedSendRefHolder( + sends[key], + make_distributed_recv( src_rank=remote_rank, comm_tag=ary_tag, - shape=local_bdry_ary.shape, dtype=local_bdry_ary.dtype, - axes=local_bdry_ary.axes)) + shape=remote_subary_template.shape, + dtype=remote_subary_template.dtype, + axes=remote_subary_template.axes)) from arraycontext.container.traversal import rec_keyed_map_array_container - self.remote_data = rec_keyed_map_array_container( - communicate_single_array, self.local_bdry_data) - def finish(self): - bdry_conn = self.dcoll.distributed_boundary_swap_connection( - self.remote_bdry_dd) + rec_keyed_map_array_container(send_single_array, local_bdry_data) + self.local_bdry_data = local_bdry_data - return TracePair(self.remote_bdry_dd, - interior=self.local_bdry_data, - exterior=bdry_conn(self.remote_data)) + self.unswapped_remote_bdry_data = rec_keyed_map_array_container( + recv_single_array, remote_bdry_data_template) + + def finish(self): + remote_to_local = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return remote_to_local(ary) + + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + remote_bdry_data = rec_map_array_container( + get_opposite_trace, + self.unswapped_remote_bdry_data, + leaf_class=DOFArray) + + return TracePair( + self.local_bdry_dd, + interior=self.local_bdry_data, + exterior=remote_bdry_data) # }}} # {{{ cross_rank_trace_pairs +def _replace_dof_arrays(array_container, dof_array): + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + return rec_map_array_container( + lambda x: dof_array if isinstance(x, DOFArray) else x, + array_container, + leaf_class=DOFArray) + + def cross_rank_trace_pairs( dcoll: DiscretizationCollection, ary: ArrayOrContainer, tag: Hashable = None, @@ -549,9 +771,9 @@ def cross_rank_trace_pairs( r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. For each partition boundary, the field data values in *ary* are - communicated to/from the neighboring partition. Presumably, this - communication is MPI (but strictly speaking, may not be, and this - routine is agnostic to the underlying communication). + communicated to/from the neighboring part. Presumably, this communication + is MPI (but strictly speaking, may not be, and this routine is agnostic to + the underlying communication). 
For each face on each partition boundary, a :class:`TracePair` is created with the locally, and @@ -596,14 +818,36 @@ def cross_rank_trace_pairs( # }}} - if isinstance(ary, Number): - # NOTE: Assumed that the same number is passed on every rank - return [TracePair( - volume_dd.trace(BTAG_PARTITION(remote_rank)), - interior=ary, exterior=ary) - for remote_rank in connected_ranks(dcoll, volume_dd=volume_dd)] + if dcoll.mpi_communicator is None: + return [] + + rank = dcoll.mpi_communicator.Get_rank() + + local_part_id = PartID(volume_dd.domain_tag.tag, rank) + + connected_part_ids = connected_parts( + dcoll, self_volume_tag=volume_dd.domain_tag.tag, + other_volume_tag=volume_dd.domain_tag.tag) + + remote_part_ids = [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] + + # This asserts that there is only one data exchange per rank, so that + # there is no risk of mismatched data reaching the wrong recipient. + # (Since we have only a single tag.) + assert len(remote_part_ids) == len({part_id.rank for part_id in remote_part_ids}) - actx = get_container_context_recursively(ary) + actx = get_container_context_recursively_opt(ary) + + if actx is None: + # NOTE: Assumes that the same number is passed on every rank + return [ + TracePair( + volume_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=ary, exterior=ary) + for remote_part_id in remote_part_ids] from grudge.array_context import MPIPytatoArrayContextBase @@ -612,14 +856,167 @@ def cross_rank_trace_pairs( else: rbc_class = _RankBoundaryCommunicationEager - # Initialize and post all sends/receives - rank_bdry_communcators = [ - rbc_class(dcoll, ary, remote_rank, comm_tag=comm_tag, volume_dd=volume_dd) - for remote_rank in connected_ranks(dcoll, volume_dd=volume_dd) - ] + rank_bdry_communicators = [] + + for remote_part_id in remote_part_ids: + bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_part_id)) + + local_bdry_data = project(dcoll, volume_dd, bdry_dd, ary) + + from arraycontext import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + remote_bdry_zeros = tag_axes( + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + dcoll._inter_part_connections[ + remote_part_id, local_part_id].from_discr.zeros(actx)) + + remote_bdry_data_template = _replace_dof_arrays( + local_bdry_data, remote_bdry_zeros) + + rank_bdry_communicators.append( + rbc_class(actx, dcoll, + local_part_id=local_part_id, + remote_part_id=remote_part_id, + local_bdry_data=local_bdry_data, + remote_bdry_data_template=remote_bdry_data_template, + comm_tag=comm_tag)) + + return [rbc.finish() for rbc in rank_bdry_communicators] + +# }}} + + +# {{{ cross_rank_inter_volume_trace_pairs + +def cross_rank_inter_volume_trace_pairs( + dcoll: DiscretizationCollection, + pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]], + *, comm_tag: Hashable = None, + ) -> Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]]: + # FIXME: Should this interface take in boundary data instead? + # TODO: Docs + r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. + + :arg comm_tag: a hashable object used to match sent and received data + across ranks. Communication will only match if both endpoints specify + objects that compare equal. A generalization of MPI communication + tags to arbitary, potentially composite objects. + + :returns: a :class:`list` of :class:`TracePair` objects. 
+ """ + # {{{ process arguments + + for vol_dd_pair in pairwise_volume_data.keys(): + for vol_dd in vol_dd_pair: + if not isinstance(vol_dd.domain_tag, VolumeDomainTag): + raise ValueError( + "pairwise_volume_data keys must describe volumes, " + f"got '{vol_dd}'") + if vol_dd.discretization_tag != DISCR_TAG_BASE: + raise ValueError( + "expected base-discretized DOFDesc in pairwise_volume_data, " + f"got '{vol_dd}'") + + # }}} + + if dcoll.mpi_communicator is None: + return {} + + rank = dcoll.mpi_communicator.Get_rank() + + for vol_data_pair in pairwise_volume_data.values(): + for vol_data in vol_data_pair: + actx = get_container_context_recursively_opt(vol_data) + if actx is not None: + break + if actx is not None: + break + + def get_remote_connected_parts(local_vol_dd, remote_vol_dd): + connected_part_ids = connected_parts( + dcoll, self_volume_tag=local_vol_dd.domain_tag.tag, + other_volume_tag=remote_vol_dd.domain_tag.tag) + return [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] + + if actx is None: + # NOTE: Assumes that the same number is passed on every rank for a + # given volume + return { + (remote_vol_dd, local_vol_dd): [ + TracePair( + local_vol_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=local_vol_ary, exterior=remote_vol_ary) + for remote_part_id in get_remote_connected_parts( + local_vol_dd, remote_vol_dd)] + for (remote_vol_dd, local_vol_dd), (remote_vol_ary, local_vol_ary) + in pairwise_volume_data.items()} + + from grudge.array_context import MPIPytatoArrayContextBase + + if isinstance(actx, MPIPytatoArrayContextBase): + rbc_class = _RankBoundaryCommunicationLazy + else: + rbc_class = _RankBoundaryCommunicationEager - # Complete send/receives and return communicated data - return [rc.finish() for rc in rank_bdry_communcators] + rank_bdry_communicators = {} + + for vol_dd_pair, vol_data_pair in pairwise_volume_data.items(): + directional_volume_data = { + (vol_dd_pair[0], vol_dd_pair[1]): (vol_data_pair[0], vol_data_pair[1]), + (vol_dd_pair[1], vol_dd_pair[0]): (vol_data_pair[1], vol_data_pair[0])} + + for dd_pair, data_pair in directional_volume_data.items(): + other_vol_dd, self_vol_dd = dd_pair + other_vol_data, self_vol_data = data_pair + + self_part_id = PartID(self_vol_dd.domain_tag.tag, rank) + other_part_ids = get_remote_connected_parts(self_vol_dd, other_vol_dd) + + rbcs = [] + + for other_part_id in other_part_ids: + self_bdry_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id)) + self_bdry_data = project( + dcoll, self_vol_dd, self_bdry_dd, self_vol_data) + + from arraycontext import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + other_bdry_zeros = tag_axes( + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + dcoll._inter_part_connections[ + other_part_id, self_part_id].from_discr.zeros(actx)) + + other_bdry_data_template = _replace_dof_arrays( + other_vol_data, other_bdry_zeros) + + rbcs.append( + rbc_class(actx, dcoll, + local_part_id=self_part_id, + remote_part_id=other_part_id, + local_bdry_data=self_bdry_data, + remote_bdry_data_template=other_bdry_data_template, + comm_tag=comm_tag)) + + rank_bdry_communicators[other_vol_dd, self_vol_dd] = rbcs + + return { + directional_vol_dd_pair: [rbc.finish() for rbc in rbcs] + for directional_vol_dd_pair, rbcs in rank_bdry_communicators.items()} # }}} From d16767acb711c967b54d73b25e4c9e9b31469ae4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 3 Nov 2022 09:20:20 -0700 
Subject: [PATCH 3/7] add fixme --- grudge/trace_pair.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 0b0400f12..84dedf386 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -525,6 +525,7 @@ def _sym_tag_to_num_tag(comm_tag: Optional[Hashable]) -> Optional[int]: num_tag = sum(ord(ch) << i for i, ch in enumerate(digest)) % tag_ub + # FIXME: This prints the wrong numerical tag because of base_comm_tag below warn("Encountered unknown symbolic tag " f"'{comm_tag}', assigning a value of '{num_tag}'. " "This is a temporary workaround, please ensure that " From a3810cec76f01b3ac0377855650bdba6573b73ae Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 3 Nov 2022 10:07:09 -0700 Subject: [PATCH 4/7] check for heterogeneous inter-volume data --- grudge/trace_pair.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 84dedf386..7358e5af8 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -915,7 +915,7 @@ def cross_rank_inter_volume_trace_pairs( """ # {{{ process arguments - for vol_dd_pair in pairwise_volume_data.keys(): + for vol_dd_pair, vol_data_pair in pairwise_volume_data.items(): for vol_dd in vol_dd_pair: if not isinstance(vol_dd.domain_tag, VolumeDomainTag): raise ValueError( @@ -925,6 +925,9 @@ def cross_rank_inter_volume_trace_pairs( raise ValueError( "expected base-discretized DOFDesc in pairwise_volume_data, " f"got '{vol_dd}'") + # FIXME: This check could probably be made more robust + if type(vol_data_pair[0]) != type(vol_data_pair[1]): # noqa: E721 + raise ValueError("heterogeneous inter-volume data not supported.") # }}} From bfad1f77dc2aaf0b8340190e4d39b21a3ad7b297 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 3 Nov 2022 10:07:26 -0700 Subject: [PATCH 5/7] tag communication by destination volume --- grudge/trace_pair.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 7358e5af8..acc086505 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -570,10 +570,17 @@ def __init__(self, self.local_bdry_data = local_bdry_data self.remote_bdry_data_template = remote_bdry_data_template - self.comm_tag = self.base_comm_tag - comm_tag = _sym_tag_to_num_tag(comm_tag) - if comm_tag is not None: - self.comm_tag += comm_tag + def _generate_num_comm_tag(sym_comm_tag): + result = self.base_comm_tag + num_comm_tag = _sym_tag_to_num_tag(sym_comm_tag) + if num_comm_tag is not None: + result += num_comm_tag + return result + + send_sym_comm_tag = (remote_part_id.volume_tag, comm_tag) + recv_sym_comm_tag = (local_part_id.volume_tag, comm_tag) + self.send_comm_tag = _generate_num_comm_tag(send_sym_comm_tag) + self.recv_comm_tag = _generate_num_comm_tag(recv_sym_comm_tag) del comm_tag # NOTE: mpi4py currently (2021-11-03) holds a reference to the send @@ -588,7 +595,7 @@ def send_single_array(key, local_subary): if not isinstance(local_subary, Number): local_subary_np = to_numpy(local_subary, actx) self.send_reqs.append( - comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag)) + comm.Isend(local_subary_np, remote_rank, tag=self.send_comm_tag)) self.send_data.append(local_subary_np) return local_subary @@ -601,7 +608,8 @@ def recv_single_array(key, remote_subary_template): remote_subary_template.shape, remote_subary_template.dtype) self.recv_reqs.append( - comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag)) + 
comm.Irecv(remote_subary_np, remote_rank, + tag=self.recv_comm_tag)) self.recv_data[key] = remote_subary_np return remote_subary_template @@ -702,7 +710,7 @@ def send_single_array(key, local_subary): if isinstance(local_subary, Number): return else: - ary_tag = (comm_tag, key) + ary_tag = (remote_part_id.volume_tag, comm_tag, key) sends[key] = make_distributed_send( local_subary, dest_rank=remote_rank, comm_tag=ary_tag) @@ -711,7 +719,7 @@ def recv_single_array(key, remote_subary_template): # NOTE: Assumes that the same number is passed on every rank return remote_subary_template else: - ary_tag = (comm_tag, key) + ary_tag = (local_part_id.volume_tag, comm_tag, key) return DistributedSendRefHolder( sends[key], make_distributed_recv( From c112288a0cc2058060a6cf72bbb5e2efc987356c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 3 Apr 2023 15:28:38 -0500 Subject: [PATCH 6/7] add filter_part_boundaries eases setting up boundaries when calling operators on only one volume (i.e., uncoupled) --- grudge/discretization.py | 44 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/grudge/discretization.py b/grudge/discretization.py index fd4c39728..25d5fa797 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -8,6 +8,7 @@ .. currentmodule:: grudge.discretization .. autoclass:: PartID +.. autofunction:: filter_part_boundaries """ __copyright__ = """ @@ -35,7 +36,8 @@ THE SOFTWARE. """ -from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any +from typing import ( + Sequence, Mapping, Optional, Union, List, Tuple, TYPE_CHECKING, Any) from pytools import memoize_method, single_valued @@ -1015,4 +1017,44 @@ def make_discretization_collection( # }}} +# {{{ filter_part_boundaries + +def filter_part_boundaries( + dcoll: DiscretizationCollection, + *, + volume_dd: DOFDesc = DD_VOLUME_ALL, + neighbor_volume_dd: Optional[DOFDesc] = None, + neighbor_rank: Optional[int] = None) -> List[DOFDesc]: + """ + Retrieve tags of part boundaries that match *neighbor_volume_dd* and/or + *neighbor_rank*. 
+ """ + vol_mesh = dcoll.discr_from_dd(volume_dd).mesh + + from meshmode.mesh import InterPartAdjacencyGroup + filtered_part_bdry_dds = [ + volume_dd.trace(fagrp.boundary_tag) + for fagrp_list in vol_mesh.facial_adjacency_groups + for fagrp in fagrp_list + if isinstance(fagrp, InterPartAdjacencyGroup)] + + if neighbor_volume_dd is not None: + filtered_part_bdry_dds = [ + bdry_dd + for bdry_dd in filtered_part_bdry_dds + if ( + bdry_dd.domain_tag.tag.part_id.volume_tag + == neighbor_volume_dd.domain_tag.tag)] + + if neighbor_rank is not None: + filtered_part_bdry_dds = [ + bdry_dd + for bdry_dd in filtered_part_bdry_dds + if bdry_dd.domain_tag.tag.part_id.rank == neighbor_rank] + + return filtered_part_bdry_dds + +# }}} + + # vim: foldmethod=marker From 949d5a08608231bcb567a311381c3a6b27d33c89 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 26 Jan 2024 13:57:41 -0600 Subject: [PATCH 7/7] use make_distributed_send_ref_holder --- grudge/trace_pair.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 1f1cbc26e..e1a9bc69b 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -709,7 +709,7 @@ def __init__(self, from pytato import ( make_distributed_recv, make_distributed_send, - DistributedSendRefHolder) + make_distributed_send_ref_holder) # TODO: This currently assumes that local_bdry_data and # remote_bdry_data_template have the same structure. This is not true @@ -734,7 +734,7 @@ def recv_single_array(key, remote_subary_template): return remote_subary_template else: ary_tag = (local_part_id.volume_tag, comm_tag, key) - return DistributedSendRefHolder( + return make_distributed_send_ref_holder( sends[key], make_distributed_recv( src_rank=remote_rank, comm_tag=ary_tag,
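
Usage sketch (not part of the patch series): the hunks above add a multi-volume API in which PartID identifies a (volume tag, rank) piece of the partitioned mesh, make_discretization_collection builds the inter-part connections for one mesh per volume tag, and inter_volume_trace_pairs / filter_part_boundaries expose the coupling to operators. The sketch below shows how that API might be called; the volume tags "fluid" and "wall", the meshes, the array context actx, the temperature arrays, and _TemperatureCommTag are hypothetical placeholders, and argument names not visible in the hunks (e.g. order=) are assumed from grudge's existing make_discretization_collection:

    from grudge.discretization import (
        make_discretization_collection, filter_part_boundaries)
    from grudge.dof_desc import DOFDesc, VolumeDomainTag, DISCR_TAG_BASE
    import grudge.op as op

    class _TemperatureCommTag:  # symbolic comm tag, matched across ranks
        pass

    dd_vol_fluid = DOFDesc(VolumeDomainTag("fluid"), DISCR_TAG_BASE)
    dd_vol_wall = DOFDesc(VolumeDomainTag("wall"), DISCR_TAG_BASE)

    # One mesh per volume tag; _set_up_inter_part_connections wires up both
    # rank-local inter-volume interfaces and cross-rank interfaces.
    dcoll = make_discretization_collection(
        actx, {"fluid": fluid_mesh, "wall": wall_mesh}, order=3)

    # Exchange per-volume data across the shared interface. The result maps
    # directed (other_volume_dd, self_volume_dd) pairs to lists of TracePairs,
    # combining rank-local and cross-rank contributions.
    interface_tpairs = op.inter_volume_trace_pairs(
        dcoll,
        {(dd_vol_fluid, dd_vol_wall): (fluid_temperature, wall_temperature)},
        comm_tag=_TemperatureCommTag)

    # Part-boundary DOFDescs of the fluid volume that abut the wall volume:
    fluid_wall_bdry_dds = filter_part_boundaries(
        dcoll, volume_dd=dd_vol_fluid, neighbor_volume_dd=dd_vol_wall)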
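
A second sketch for the conventional single-volume distributed case: the reworked cross_rank_trace_pairs keeps its calling convention, but the symbolic comm_tag is now additionally keyed by the destination volume (patch 5) and is required under the lazy (pytato) array context. Here dcoll, velocity, and _VelocityCommTag are hypothetical placeholders:

    import grudge.op as op

    class _VelocityCommTag:  # symbolic tag; lazy (pytato) runs require one
        pass

    # TracePairs on cross-rank partition boundaries of the default volume
    remote_tpairs = op.cross_rank_trace_pairs(
        dcoll, velocity, comm_tag=_VelocityCommTag)

    # Rank-local interior-face pairs plus the cross-rank pairs, as one list
    all_tpairs = op.interior_trace_pairs(
        dcoll, velocity, comm_tag=_VelocityCommTag)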