diff --git a/simpeg_drivers/components/topography.py b/simpeg_drivers/components/topography.py index 1e251f62..40f20a39 100644 --- a/simpeg_drivers/components/topography.py +++ b/simpeg_drivers/components/topography.py @@ -38,6 +38,8 @@ floating_active, get_containing_cells, get_neighbouring_cells, + mask_vertices_and_cells, + octree_extents, ) @@ -101,14 +103,26 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray: active_cells = InversionModel.obj_2_mesh( self.params.active_cells.active_model, mesh.entity ) + else: + if any(k in self.params.inversion_type for k in ["2d", "p3d"]): + vertices = self.locations + cells = getattr( + self.params.active_cells.topography_object, "cells", None + ) + else: + extent = octree_extents(mesh.entity)[:4] + vertices, cells = mask_vertices_and_cells( + extent.ravel(order="F"), + self.locations, + getattr(self.params.active_cells.topography_object, "cells", None), + ) + active_cells = active_from_xyz( mesh.entity, - self.locations, + vertices, grid_reference="bottom" if forced_to_surface else "center", - triangulation=getattr( - self.params.active_cells.topography_object, "cells", None - ), + triangulation=cells, ) active_cells = (mesh.permutation @ active_cells).astype(bool) diff --git a/simpeg_drivers/utils/utils.py b/simpeg_drivers/utils/utils.py index 9312b581..37eafb8c 100644 --- a/simpeg_drivers/utils/utils.py +++ b/simpeg_drivers/utils/utils.py @@ -11,6 +11,7 @@ from __future__ import annotations +from collections.abc import Sequence from copy import deepcopy from typing import TYPE_CHECKING @@ -46,6 +47,60 @@ from simpeg_drivers.driver import InversionDriver +def octree_extents(octree: Octree) -> np.ndarray: + """ + Get the true extents of an octree (min/max of the perimeter). + + The octree.extents property returns min/max of the centroids + + :param octree: Octree mesh object. + + :returns: Array of [xmin, xmax, ymin, ymax]. + """ + + origin = np.array(list(octree.origin.tolist())) + span = np.array( + [ + getattr(octree, f"{axis}_cell_size") * getattr(octree, f"{axis}_count") + for axis in "uvw" + ] + ) + + return np.stack([origin, origin + span]).flatten(order="F") + + +def mask_vertices_and_cells( + extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None +) -> tuple[np.ndarray, np.ndarray]: + """ + Mask vertices and remove cells whose vertices are all outside the extent. + + :param extent: Array-like object of [xmin, xmax, ymin, ymax]. + :param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates. + :param cells: Array of shape (n_cells, 3) containing the indices of the vertices + that make up each cell. + """ + + vertex_mask = ( + (vertices[:, 0] >= extent[0]) + & (vertices[:, 0] <= extent[1]) + & (vertices[:, 1] >= extent[2]) + & (vertices[:, 1] <= extent[3]) + ) + if cells is None: + return vertices[vertex_mask], None + + cell_mask = np.any(vertex_mask[cells], axis=1) + vertex_mask = np.zeros_like(vertex_mask, dtype=bool) + vertex_mask[cells[cell_mask].flatten()] = True + + new_cells = cells.copy()[cell_mask] + cell_map = np.arange(len(vertices))[vertex_mask] + new_cells = np.searchsorted(cell_map, new_cells) + + return vertices[vertex_mask], new_cells + + def calculate_2D_trend( points: np.ndarray, values: np.ndarray, order: int = 0, method: str = "all" ): @@ -500,7 +555,7 @@ def active_from_xyz( raise ValueError("'grid_reference' must be one of 'center', 'top', or 'bottom'") # Return the active cell array - return mask_under_horizon(locations, topo, triangulation=triangulation) + return mask_under_horizon(locations, horizon=topo, triangulation=triangulation) def truncate_locs_depths(locs: np.ndarray, depth_core: float) -> np.ndarray: diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 00000000..fdfb8014 --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,64 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +import numpy as np +from geoh5py import Workspace +from geoh5py.objects import Octree, Points +from grid_apps.octree_creation.driver import OctreeDriver +from grid_apps.octree_creation.options import OctreeOptions, RefinementOptions + +from simpeg_drivers.utils.utils import mask_vertices_and_cells, octree_extents + + +def test_octree_extents(tmp_path): + with Workspace(tmp_path / "test.geoh5") as ws: + X, Y = np.meshgrid(np.linspace(0, 1000, 51), np.linspace(0, 1000, 51)) + Z = np.zeros_like(X) + vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()]) + pts = Points.create(ws, name="points", vertices=vertices) + options = OctreeOptions( + geoh5=ws, + objects=pts, + refinements=[ + RefinementOptions( + refinement_object=pts, levels=[4, 2], horizon=False, distance=100 + ), + ], + ) + octree = OctreeDriver.octree_from_params(options) + + extents = octree_extents(octree) + assert np.allclose(extents, [-1112.5, 2087.5, -1112.5, 2087.5, -1062.5, 537.5]) + + +def test_mask_vertices_and_cells(): + X, Y = np.meshgrid(np.arange(3), np.arange(3)) + Z = np.zeros_like(X) + vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()]) + cells = np.array( + [ + [0, 1, 3], + [3, 1, 4], + [1, 2, 4], + [4, 2, 5], + [3, 4, 6], + [6, 4, 7], + [4, 5, 7], + [7, 5, 8], + ] + ) + extent = [0.5, 2, 0, 2, 0, 1] + masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells) + assert len(masked_vertices) == len(vertices) + assert len(masked_cells) == len(cells) + extent = [1.5, 2, 0, 2, 0, 1] + masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells) + assert len(masked_vertices) == 6 + assert len(masked_cells) == 4