Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 77 additions & 131 deletions conda_package/mpas_tools/mesh/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
"""

import argparse
from functools import partial
from concurrent.futures import ThreadPoolExecutor

import numpy
import progressbar
import shapely
import shapely.geometry
import xarray as xr
from geometric_features import read_feature_collection
from igraph import Graph
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial import KDTree
from shapely.geometry import GeometryCollection, MultiPolygon, Polygon, box
from shapely.strtree import STRtree
Expand Down Expand Up @@ -110,10 +112,8 @@ def compute_mpas_region_masks(
if logger is not None:
logger.info(f' Computing {maskType} masks:')

# create shapely geometry for lon and lat
points = [
shapely.geometry.Point(x, y) for x, y in zip(lon, lat, strict=True)
]
# create shapely geometry for lon and lat (vectorized)
points = list(shapely.points(numpy.stack([lon, lat], axis=-1)))
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
Expand Down Expand Up @@ -709,10 +709,8 @@ def compute_lon_lat_region_masks(
Lon = Lon.ravel()
Lat = Lat.ravel()

# create shapely geometry for lon and lat
points = [
shapely.geometry.Point(x, y) for x, y in zip(Lon, Lat, strict=True)
]
# create shapely geometry for lon and lat (vectorized)
points = list(shapely.points(numpy.stack([Lon, Lat], axis=-1)))
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
Expand Down Expand Up @@ -934,11 +932,10 @@ def compute_projection_grid_region_masks(

ny, nx = lon.shape

# create shapely geometry for lon and lat
points = [
shapely.geometry.Point(x, y)
for x, y in zip(lon.ravel(), lat.ravel(), strict=True)
]
# create shapely geometry for lon and lat (vectorized)
points = list(
shapely.points(numpy.stack([lon.ravel(), lat.ravel()], axis=-1))
)
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
Expand Down Expand Up @@ -1104,52 +1101,6 @@ def entry_point_compute_projection_grid_region_masks():
)


def _compute_mask_from_shapes(
shapes1, shapes2, func, pool, chunkSize, showProgress
):
"""
If multiprocessing, break shapes2 into chunks and use multiprocessing to
apply the given function one chunk at a time
"""
nShapes2 = len(shapes2)
if pool is None:
mask = func(shapes1, shapes2)
else:
nChunks = int(numpy.ceil(nShapes2 / chunkSize))
chunks = []
indices = [0]
for iChunk in range(nChunks):
start = iChunk * chunkSize
end = min((iChunk + 1) * chunkSize, nShapes2)
chunks.append(shapes2[start:end])
indices.append(end)

partial_func = partial(func, shapes1)
if showProgress:
widgets = [
' ',
progressbar.Percentage(),
' ',
progressbar.Bar(),
' ',
progressbar.ETA(),
]
bar = progressbar.ProgressBar(
widgets=widgets, maxval=nChunks
).start()
else:
bar = None

mask = numpy.zeros((nShapes2,), bool)
for iChunk, maskChunk in enumerate(pool.imap(partial_func, chunks)):
mask[indices[iChunk] : indices[iChunk + 1]] = maskChunk
if showProgress:
bar.update(iChunk + 1)
if showProgress:
bar.finish()
return mask


def _add_properties(ds, properties, dim):
"""
Add properties to the dataset from a dictionary of properties
Expand Down Expand Up @@ -1210,40 +1161,31 @@ def _compute_region_masks(

regionNames, properties = _get_region_names_and_properties(fcMask)

masks = []
# Build the spatial index once for all points
tree = STRtree(points)
nPoints = len(points)

for feature in fcMask.features:
def _query(feature):
name = feature['properties']['name']

if logger is not None:
logger.info(f' {name}')

shape = shapely.geometry.shape(feature['geometry'])
shapes = _katana(shape, threshold=threshold)

mask = _compute_mask_from_shapes(
shapes1=shapes,
shapes2=points,
func=_contains,
pool=pool,
chunkSize=chunkSize,
showProgress=showProgress,
)

masks.append(mask)
katana_shapes = _katana(shape, threshold=threshold)
mask = numpy.zeros(nPoints, dtype=bool)
for s in katana_shapes:
mask[tree.query(s, predicate='covers')] = True
return mask

if pool is not None and len(fcMask.features) > 1:
n_workers = getattr(pool, '_processes', None) or 1
with ThreadPoolExecutor(max_workers=n_workers) as executor:
masks = list(executor.map(_query, fcMask.features))
else:
masks = [_query(f) for f in fcMask.features]

return regionNames, masks, properties


def _contains(shapes, points):
    """
    Return a boolean mask over ``points`` that is ``True`` for each point
    covered by any of the given shapes, using an STRtree spatial index of
    the points for fast lookup.
    """
    covered = numpy.zeros(len(points), dtype=bool)
    index = STRtree(points)
    for poly in shapes:
        covered[index.query(poly, predicate='covers')] = True
    return covered


def _katana(geometry, threshold, count=0, maxcount=250):
"""
From https://snorfalorpagus.net/blog/2016/03/13/splitting-large-polygons-for-faster-intersections/
Expand Down Expand Up @@ -1332,12 +1274,9 @@ def _compute_transect_masks(

transectNames, properties = _get_region_names_and_properties(fcMask)

masks = []
shapes = []

for feature in fcMask.features:
name = feature['properties']['name']

if logger is not None:
logger.info(f' {name}')

Expand Down Expand Up @@ -1373,27 +1312,25 @@ def _compute_transect_masks(
else:
shape = shapely.geometry.MultiLineString(new_coords)

mask = _compute_mask_from_shapes(
shapes1=shape,
shapes2=polygons,
func=_intersects,
pool=pool,
chunkSize=chunkSize,
showProgress=showProgress,
)

masks.append(mask)
shapes.append(shape)

return transectNames, masks, properties, shapes
# Build the spatial index once for all polygons
tree = STRtree(polygons)
nPolygons = len(polygons)

def _query(shape):
mask = numpy.zeros(nPolygons, dtype=bool)
mask[tree.query(shape, predicate='intersects')] = True
return mask

def _intersects(shape, polygons):
    """
    Return a boolean mask over ``polygons`` that is ``True`` for each
    polygon intersecting the given shape, using an STRtree spatial index
    of the polygons for fast lookup.
    """
    index = STRtree(polygons)
    hits = index.query(shape, predicate='intersects')
    result = numpy.zeros(len(polygons), dtype=bool)
    result[hits] = True
    return result
if pool is not None and len(shapes) > 1:
n_workers = getattr(pool, '_processes', None) or 1
with ThreadPoolExecutor(max_workers=n_workers) as executor:
masks = list(executor.map(_query, shapes))
else:
masks = [_query(s) for s in shapes]

return transectNames, masks, properties, shapes


def _get_polygons(dsMesh, maskType):
Expand Down Expand Up @@ -1476,10 +1413,8 @@ def _get_polygons(dsMesh, maskType):

nPolygons = len(lonCenter)

polygons = []
for index in range(lon.shape[0]):
coords = zip(lon[index, :], lat[index, :], strict=True)
polygons.append(shapely.geometry.Polygon(coords))
coords_3d = numpy.stack([lon, lat], axis=-1)
polygons = list(shapely.polygons(coords_3d))

return polygons, nPolygons, duplicatePolygons

Expand Down Expand Up @@ -1544,30 +1479,41 @@ def _compute_seed_mask(fcSeed, lon, lat, workers):

def _flood_fill_mask(seedMask, growMask, cellsOnCell):
"""
Flood fill starting with a mask of seed points
"""
Flood fill starting with a mask of seed points.

Uses scipy connected components to identify all cells reachable from the
seed cells through regions where growMask == 1.
"""
nCells = cellsOnCell.shape[0]
maxNeighbors = cellsOnCell.shape[1]

while True:
neighbors = cellsOnCell[seedMask == 1, :]
maskCount = 0
for iNeighbor in range(maxNeighbors):
indices = neighbors[:, iNeighbor]
# we only want to mask valid neighbors, locations that aren't
# already masked, and locations that we're allowed to flood
indices = indices[indices >= 0]
localMask = numpy.logical_and(
seedMask[indices] == 0, growMask[indices] == 1
)
maskCount += numpy.count_nonzero(localMask)
indices = indices[localMask]
seedMask[indices] = 1

if maskCount == 0:
break

return seedMask
# Build edge list for the subgraph of growMask==1 cells
i_flat = numpy.repeat(numpy.arange(nCells), maxNeighbors)
j_flat = cellsOnCell.ravel()
valid = (j_flat >= 0) & (growMask[i_flat] == 1) & (growMask[j_flat] == 1)
adj = csr_matrix(
(numpy.ones(valid.sum(), dtype=bool), (i_flat[valid], j_flat[valid])),
shape=(nCells, nCells),
)
_, labels = connected_components(adj, directed=False)

seed_cells = numpy.flatnonzero(seedMask)

# Collect component IDs that contain a seed cell inside growMask
seeds_in_grow = seed_cells[growMask[seed_cells] == 1]
seed_comps = set(labels[seeds_in_grow].tolist())

# Seeds outside growMask can still prime their growMask==1 neighbors
for seed in seed_cells[growMask[seed_cells] == 0]:
for nb in cellsOnCell[seed]:
if nb >= 0 and growMask[nb] == 1:
seed_comps.add(int(labels[nb]))

result = numpy.zeros(nCells, dtype=numpy.int32)
if seed_comps:
result[numpy.isin(labels, list(seed_comps))] = 1
result[seed_cells] = 1 # original seeds always included
return result


def _compute_edge_sign(dsMesh, edgeMask, shape):
Expand Down
Loading