From ffbe93ec12cbf4a674c1cd1fd4986a93b236ed5c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 24 May 2023 19:20:33 -0700 Subject: [PATCH 1/3] add axes for reshapes, etc. in direction connection --- meshmode/discretization/connection/direct.py | 67 ++++++++++++-------- meshmode/transform_metadata.py | 10 +++ 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index a84b2282b..37acf0090 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -30,7 +30,8 @@ import loopy as lp from meshmode.transform_metadata import ( ConcurrentElementInameTag, ConcurrentDOFInameTag, - DiscretizationElementAxisTag, DiscretizationDOFAxisTag) + DiscretizationElementAxisTag, DiscretizationDOFAxisTag, + DiscretizationDOFPickListAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError, @@ -553,17 +554,22 @@ def _per_target_group_pick_info( _FromGroupPickData( from_group_index=source_group_index, dof_pick_lists=actx.freeze( - actx.tag(NameHint("dof_pick_lists"), - actx.from_numpy(dof_pick_lists))), + actx.tag_axis(0, DiscretizationDOFPickListAxisTag(), + actx.tag(NameHint("dof_pick_lists"), + actx.from_numpy(dof_pick_lists)))), dof_pick_list_indices=actx.freeze( - actx.tag(NameHint("dof_pick_list_indices"), - actx.from_numpy(dof_pick_list_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("dof_pick_list_indices"), + actx.from_numpy(dof_pick_list_indices)))), from_el_present=actx.freeze( - actx.tag(NameHint("from_el_present"), - actx.from_numpy(from_el_present.astype(np.int8)))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_present"), + actx.from_numpy( + from_el_present.astype(np.int8))))), from_element_indices=actx.freeze( - actx.tag(NameHint("from_el_indices"), - actx.from_numpy(from_el_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_indices"), + actx.from_numpy(from_el_indices)))), is_surjective=from_el_present.all() )) @@ -732,25 +738,27 @@ def group_pick_knl(is_surjective: bool): group_pick_info = None if group_pick_info is not None: - group_array_contributions = [] - if actx.permits_advanced_indexing and not _force_use_loopy: for fgpd in group_pick_info: from_element_indices = actx.thaw(fgpd.from_element_indices) if ary[fgpd.from_group_index].size: grp_ary_contrib = ary[fgpd.from_group_index][ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), - actx.thaw(fgpd.dof_pick_lists)[ - actx.thaw(fgpd.dof_pick_list_indices)] - ] + actx, from_element_indices, (-1, 1))), + actx.thaw(fgpd.dof_pick_lists)[ + actx.thaw(fgpd.dof_pick_list_indices)] + ] if not fgpd.is_surjective: from_el_present = actx.thaw(fgpd.from_el_present) grp_ary_contrib = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1))), grp_ary_contrib, 0) @@ -800,8 +808,10 @@ def group_pick_knl(is_surjective: bool): mat = self._resample_matrix(actx, i_tgrp, i_batch) if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1))), actx.einsum("ij,ej->ei", mat, grp_ary[from_element_indices]), 0) @@ -822,11 +832,15 @@ def group_pick_knl(is_surjective: bool): if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), - from_vec[ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), + actx, from_el_present, (-1, 1))), + from_vec[ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_element_indices, (-1, 1))), pick_list], 0) else: @@ -853,10 +867,13 @@ def group_pick_knl(is_surjective: bool): else: # If no batched data at all, return zeros for this # particular group array - group_array = actx.zeros( + group_array = tag_axes(actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + actx.zeros( shape=(self.to_discr.groups[i_tgrp].nelements, self.to_discr.groups[i_tgrp].nunit_dofs), - dtype=ary.entry_dtype) + dtype=ary.entry_dtype)) group_arrays.append(group_array) diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index f622310b1..2dfd6febd 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -8,6 +8,7 @@ .. autoclass:: DiscretizationDOFAxisTag .. autoclass:: DiscretizationAmbientDimAxisTag .. autoclass:: DiscretizationTopologicalDimAxisTag +.. autoclass:: DiscretizationDOFPickListAxisTag """ __copyright__ = """ @@ -121,3 +122,12 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag): Array dimensions tagged with this tag type describe an axis indexing over the discretization's physical coordinate dimensions. """ + + +@tag_dataclass +class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag): + """ + Array dimensions tagged with this tag type describe an axis indexing over + DOF pick lists. See :mod:`meshmode.discretization.connection.direct` for + details. + """ From a3cff0d5fa3448d01ca6aabdd19c621bcd530715 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 31 May 2023 08:52:28 -0700 Subject: [PATCH 2/3] avoid referencing undocumented module --- meshmode/transform_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index 2dfd6febd..54db12b14 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -128,6 +128,6 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag): class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag): """ Array dimensions tagged with this tag type describe an axis indexing over - DOF pick lists. See :mod:`meshmode.discretization.connection.direct` for - details. + DOF pick lists in + :class:`meshmode.discretization.connection.DirectDiscretizationConnection`. """ From 9e8d6fa3ac7460dc881af49ae66178077a9b7661 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 31 May 2023 11:19:22 -0700 Subject: [PATCH 3/3] tag a few more axes --- meshmode/discretization/connection/direct.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 37acf0090..0d55a7c10 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -167,12 +167,14 @@ def _global_from_element_indices( np_full_from_element_indices[~np_from_el_present] = 0 from_el_present = actx.freeze( - actx.tag(NameHint("from_el_present"), - actx.from_numpy( - np_from_el_present.astype(np.int8)))) + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_present"), + actx.from_numpy( + np_from_el_present.astype(np.int8))))) full_from_element_indices = actx.freeze( - actx.tag(NameHint("from_el_indices"), - actx.from_numpy(np_full_from_element_indices))) + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_indices"), + actx.from_numpy(np_full_from_element_indices)))) self._global_from_element_indices_cache = ( from_el_present, full_from_element_indices)