Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions simpeg_drivers/components/topography.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
floating_active,
get_containing_cells,
get_neighbouring_cells,
mask_vertices_and_cells,
octree_extents,
)


Expand Down Expand Up @@ -101,14 +103,26 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray:
active_cells = InversionModel.obj_2_mesh(
self.params.active_cells.active_model, mesh.entity
)

else:
if any(k in self.params.inversion_type for k in ["2d", "p3d"]):
vertices = self.locations
cells = getattr(
self.params.active_cells.topography_object, "cells", None
)
else:
extent = octree_extents(mesh.entity)[:4]
vertices, cells = mask_vertices_and_cells(
extent.ravel(order="F"),
self.locations,
getattr(self.params.active_cells.topography_object, "cells", None),
)

active_cells = active_from_xyz(
mesh.entity,
self.locations,
vertices,
grid_reference="bottom" if forced_to_surface else "center",
triangulation=getattr(
self.params.active_cells.topography_object, "cells", None
),
triangulation=cells,
)

active_cells = (mesh.permutation @ active_cells).astype(bool)
Expand Down
57 changes: 56 additions & 1 deletion simpeg_drivers/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -46,6 +47,60 @@
from simpeg_drivers.driver import InversionDriver


def octree_extents(octree: Octree) -> np.ndarray:
"""
Get the true extents of an octree (min/max of the perimeter).

The octree.extents property returns min/max of the centroids

:param octree: Octree mesh object.

:returns: Array of [xmin, xmax, ymin, ymax].
"""

origin = np.array(list(octree.origin.tolist()))
span = np.array(
[
getattr(octree, f"{axis}_cell_size") * getattr(octree, f"{axis}_count")
for axis in "uvw"
]
)

return np.stack([origin, origin + span]).flatten(order="F")


def mask_vertices_and_cells(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's too bad, we are duplicating a lot of mechanics of geoh5py surface.copy_by_extent() method.
We can migrate this to geoh5py on GEOPY-2714

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I've left some notes regarding the implementation to make sure it can be used for this case

extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray]:
"""
Mask vertices and remove cells whose vertices are all outside the extent.

:param extent: Array-like object of [xmin, xmax, ymin, ymax].
:param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates.
:param cells: Array of shape (n_cells, 3) containing the indices of the vertices
that make up each cell.
"""

vertex_mask = (
(vertices[:, 0] >= extent[0])
& (vertices[:, 0] <= extent[1])
& (vertices[:, 1] >= extent[2])
& (vertices[:, 1] <= extent[3])
)
if cells is None:
return vertices[vertex_mask], None

cell_mask = np.any(vertex_mask[cells], axis=1)
vertex_mask = np.zeros_like(vertex_mask, dtype=bool)
vertex_mask[cells[cell_mask].flatten()] = True

new_cells = cells.copy()[cell_mask]
cell_map = np.arange(len(vertices))[vertex_mask]
new_cells = np.searchsorted(cell_map, new_cells)

return vertices[vertex_mask], new_cells


def calculate_2D_trend(
points: np.ndarray, values: np.ndarray, order: int = 0, method: str = "all"
):
Expand Down Expand Up @@ -500,7 +555,7 @@ def active_from_xyz(
raise ValueError("'grid_reference' must be one of 'center', 'top', or 'bottom'")

# Return the active cell array
return mask_under_horizon(locations, topo, triangulation=triangulation)
return mask_under_horizon(locations, horizon=topo, triangulation=triangulation)


def truncate_locs_depths(locs: np.ndarray, depth_core: float) -> np.ndarray:
Expand Down
64 changes: 64 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Copyright (c) 2026 Mira Geoscience Ltd. '
# '
# This file is part of simpeg-drivers package. '
# '
# simpeg-drivers is distributed under the terms and conditions of the MIT License '
# (see LICENSE file at the root of this source code package). '
# '
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

import numpy as np
from geoh5py import Workspace
from geoh5py.objects import Octree, Points
from grid_apps.octree_creation.driver import OctreeDriver
from grid_apps.octree_creation.options import OctreeOptions, RefinementOptions

from simpeg_drivers.utils.utils import mask_vertices_and_cells, octree_extents


def test_octree_extents(tmp_path):
with Workspace(tmp_path / "test.geoh5") as ws:
X, Y = np.meshgrid(np.linspace(0, 1000, 51), np.linspace(0, 1000, 51))
Z = np.zeros_like(X)
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
pts = Points.create(ws, name="points", vertices=vertices)
options = OctreeOptions(
geoh5=ws,
objects=pts,
refinements=[
RefinementOptions(
refinement_object=pts, levels=[4, 2], horizon=False, distance=100
),
],
)
octree = OctreeDriver.octree_from_params(options)

extents = octree_extents(octree)
assert np.allclose(extents, [-1112.5, 2087.5, -1112.5, 2087.5, -1062.5, 537.5])


def test_mask_vertices_and_cells():
X, Y = np.meshgrid(np.arange(3), np.arange(3))
Z = np.zeros_like(X)
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
cells = np.array(
[
[0, 1, 3],
[3, 1, 4],
[1, 2, 4],
[4, 2, 5],
[3, 4, 6],
[6, 4, 7],
[4, 5, 7],
[7, 5, 8],
]
)
extent = [0.5, 2, 0, 2, 0, 1]
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
assert len(masked_vertices) == len(vertices)
assert len(masked_cells) == len(cells)
extent = [1.5, 2, 0, 2, 0, 1]
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
assert len(masked_vertices) == 6
assert len(masked_cells) == 4