diff --git a/conda_package/mpas_tools/mesh/mask.py b/conda_package/mpas_tools/mesh/mask.py index 793a36cb2..c147bdacd 100644 --- a/conda_package/mpas_tools/mesh/mask.py +++ b/conda_package/mpas_tools/mesh/mask.py @@ -21,14 +21,16 @@ """ import argparse -from functools import partial +from concurrent.futures import ThreadPoolExecutor import numpy -import progressbar +import shapely import shapely.geometry import xarray as xr from geometric_features import read_feature_collection from igraph import Graph +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import connected_components from scipy.spatial import KDTree from shapely.geometry import GeometryCollection, MultiPolygon, Polygon, box from shapely.strtree import STRtree @@ -110,10 +112,8 @@ def compute_mpas_region_masks( if logger is not None: logger.info(f' Computing {maskType} masks:') - # create shapely geometry for lon and lat - points = [ - shapely.geometry.Point(x, y) for x, y in zip(lon, lat, strict=True) - ] + # create shapely geometry for lon and lat (vectorized) + points = list(shapely.points(numpy.stack([lon, lat], axis=-1))) regionNames, masks, properties = _compute_region_masks( fcMask, points, @@ -709,10 +709,8 @@ def compute_lon_lat_region_masks( Lon = Lon.ravel() Lat = Lat.ravel() - # create shapely geometry for lon and lat - points = [ - shapely.geometry.Point(x, y) for x, y in zip(Lon, Lat, strict=True) - ] + # create shapely geometry for lon and lat (vectorized) + points = list(shapely.points(numpy.stack([Lon, Lat], axis=-1))) regionNames, masks, properties = _compute_region_masks( fcMask, points, @@ -934,11 +932,10 @@ def compute_projection_grid_region_masks( ny, nx = lon.shape - # create shapely geometry for lon and lat - points = [ - shapely.geometry.Point(x, y) - for x, y in zip(lon.ravel(), lat.ravel(), strict=True) - ] + # create shapely geometry for lon and lat (vectorized) + points = list( + shapely.points(numpy.stack([lon.ravel(), lat.ravel()], axis=-1)) + ) regionNames, masks, properties = _compute_region_masks( fcMask, points, @@ -1104,52 +1101,6 @@ def entry_point_compute_projection_grid_region_masks(): ) -def _compute_mask_from_shapes( - shapes1, shapes2, func, pool, chunkSize, showProgress -): - """ - If multiprocessing, break shapes2 into chunks and use multiprocessing to - apply the given function one chunk at a time - """ - nShapes2 = len(shapes2) - if pool is None: - mask = func(shapes1, shapes2) - else: - nChunks = int(numpy.ceil(nShapes2 / chunkSize)) - chunks = [] - indices = [0] - for iChunk in range(nChunks): - start = iChunk * chunkSize - end = min((iChunk + 1) * chunkSize, nShapes2) - chunks.append(shapes2[start:end]) - indices.append(end) - - partial_func = partial(func, shapes1) - if showProgress: - widgets = [ - ' ', - progressbar.Percentage(), - ' ', - progressbar.Bar(), - ' ', - progressbar.ETA(), - ] - bar = progressbar.ProgressBar( - widgets=widgets, maxval=nChunks - ).start() - else: - bar = None - - mask = numpy.zeros((nShapes2,), bool) - for iChunk, maskChunk in enumerate(pool.imap(partial_func, chunks)): - mask[indices[iChunk] : indices[iChunk + 1]] = maskChunk - if showProgress: - bar.update(iChunk + 1) - if showProgress: - bar.finish() - return mask - - def _add_properties(ds, properties, dim): """ Add properties to the dataset from a dictionary of properties @@ -1210,40 +1161,31 @@ def _compute_region_masks( regionNames, properties = _get_region_names_and_properties(fcMask) - masks = [] + # Build the spatial index once for all points + tree = STRtree(points) + nPoints = len(points) - for feature in fcMask.features: + def _query(feature): name = feature['properties']['name'] - if logger is not None: logger.info(f' {name}') - shape = shapely.geometry.shape(feature['geometry']) - shapes = _katana(shape, threshold=threshold) - - mask = _compute_mask_from_shapes( - shapes1=shapes, - shapes2=points, - func=_contains, - pool=pool, - chunkSize=chunkSize, - showProgress=showProgress, - ) - - masks.append(mask) + katana_shapes = _katana(shape, threshold=threshold) + mask = numpy.zeros(nPoints, dtype=bool) + for s in katana_shapes: + mask[tree.query(s, predicate='covers')] = True + return mask + + if pool is not None and len(fcMask.features) > 1: + n_workers = getattr(pool, '_processes', None) or 1 + with ThreadPoolExecutor(max_workers=n_workers) as executor: + masks = list(executor.map(_query, fcMask.features)) + else: + masks = [_query(f) for f in fcMask.features] return regionNames, masks, properties -def _contains(shapes, points): - tree = STRtree(points) - mask = numpy.zeros(len(points), dtype=bool) - for shape in shapes: - indicesInShape = tree.query(shape, predicate='covers') - mask[indicesInShape] = True - return mask - - def _katana(geometry, threshold, count=0, maxcount=250): """ From https://snorfalorpagus.net/blog/2016/03/13/splitting-large-polygons-for-faster-intersections/ @@ -1332,12 +1274,9 @@ def _compute_transect_masks( transectNames, properties = _get_region_names_and_properties(fcMask) - masks = [] shapes = [] - for feature in fcMask.features: name = feature['properties']['name'] - if logger is not None: logger.info(f' {name}') @@ -1373,27 +1312,25 @@ def _compute_transect_masks( else: shape = shapely.geometry.MultiLineString(new_coords) - mask = _compute_mask_from_shapes( - shapes1=shape, - shapes2=polygons, - func=_intersects, - pool=pool, - chunkSize=chunkSize, - showProgress=showProgress, - ) - - masks.append(mask) shapes.append(shape) - return transectNames, masks, properties, shapes + # Build the spatial index once for all polygons + tree = STRtree(polygons) + nPolygons = len(polygons) + def _query(shape): + mask = numpy.zeros(nPolygons, dtype=bool) + mask[tree.query(shape, predicate='intersects')] = True + return mask -def _intersects(shape, polygons): - tree = STRtree(polygons) - mask = numpy.zeros(len(polygons), dtype=bool) - indicesInShape = tree.query(shape, predicate='intersects') - mask[indicesInShape] = True - return mask + if pool is not None and len(shapes) > 1: + n_workers = getattr(pool, '_processes', None) or 1 + with ThreadPoolExecutor(max_workers=n_workers) as executor: + masks = list(executor.map(_query, shapes)) + else: + masks = [_query(s) for s in shapes] + + return transectNames, masks, properties, shapes def _get_polygons(dsMesh, maskType): @@ -1476,10 +1413,8 @@ def _get_polygons(dsMesh, maskType): nPolygons = len(lonCenter) - polygons = [] - for index in range(lon.shape[0]): - coords = zip(lon[index, :], lat[index, :], strict=True) - polygons.append(shapely.geometry.Polygon(coords)) + coords_3d = numpy.stack([lon, lat], axis=-1) + polygons = list(shapely.polygons(coords_3d)) return polygons, nPolygons, duplicatePolygons @@ -1544,30 +1479,41 @@ def _compute_seed_mask(fcSeed, lon, lat, workers): def _flood_fill_mask(seedMask, growMask, cellsOnCell): """ - Flood fill starting with a mask of seed points - """ + Flood fill starting with a mask of seed points. + Uses scipy connected components to identify all cells reachable from the + seed cells through regions where growMask == 1. + """ + nCells = cellsOnCell.shape[0] maxNeighbors = cellsOnCell.shape[1] - while True: - neighbors = cellsOnCell[seedMask == 1, :] - maskCount = 0 - for iNeighbor in range(maxNeighbors): - indices = neighbors[:, iNeighbor] - # we only want to mask valid neighbors, locations that aren't - # already masked, and locations that we're allowed to flood - indices = indices[indices >= 0] - localMask = numpy.logical_and( - seedMask[indices] == 0, growMask[indices] == 1 - ) - maskCount += numpy.count_nonzero(localMask) - indices = indices[localMask] - seedMask[indices] = 1 - - if maskCount == 0: - break - - return seedMask + # Build edge list for the subgraph of growMask==1 cells + i_flat = numpy.repeat(numpy.arange(nCells), maxNeighbors) + j_flat = cellsOnCell.ravel() + valid = (j_flat >= 0) & (growMask[i_flat] == 1) & (growMask[j_flat] == 1) + adj = csr_matrix( + (numpy.ones(valid.sum(), dtype=bool), (i_flat[valid], j_flat[valid])), + shape=(nCells, nCells), + ) + _, labels = connected_components(adj, directed=False) + + seed_cells = numpy.flatnonzero(seedMask) + + # Collect component IDs that contain a seed cell inside growMask + seeds_in_grow = seed_cells[growMask[seed_cells] == 1] + seed_comps = set(labels[seeds_in_grow].tolist()) + + # Seeds outside growMask can still prime their growMask==1 neighbors + for seed in seed_cells[growMask[seed_cells] == 0]: + for nb in cellsOnCell[seed]: + if nb >= 0 and growMask[nb] == 1: + seed_comps.add(int(labels[nb])) + + result = numpy.zeros(nCells, dtype=numpy.int32) + if seed_comps: + result[numpy.isin(labels, list(seed_comps))] = 1 + result[seed_cells] = 1 # original seeds always included + return result def _compute_edge_sign(dsMesh, edgeMask, shape):