diff --git a/README.md b/README.md index 881566f4..013282f1 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e | Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | |:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:| | [Allocation](xrspatial/proximity.py) | Assigns each cell to the identity of the nearest source feature | ✅️ | ✅ | | | +| [Cost Distance](xrspatial/cost_distance.py) | Computes minimum accumulated cost to the nearest source through a friction surface | ✅️ | ✅ | | | | [Direction](xrspatial/proximity.py) | Computes the direction from each cell to the nearest source feature | ✅️ | ✅ | | | | [Proximity](xrspatial/proximity.py) | Computes the distance from each cell to the nearest source feature | ✅️ | ✅ | | | diff --git a/examples/user_guide/9_Cost_Distance.ipynb b/examples/user_guide/9_Cost_Distance.ipynb new file mode 100644 index 00000000..7f8a39c9 --- /dev/null +++ b/examples/user_guide/9_Cost_Distance.ipynb @@ -0,0 +1,487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Xarray-spatial\n", + "### User Guide: Cost Distance (Weighted Proximity)\n", + "-----\n", + "\n", + "The `cost_distance` function computes the **minimum accumulated traversal cost** through a friction surface to reach the nearest target pixel. This is the raster equivalent of GRASS `r.cost` / ArcGIS *Cost Distance*.\n", + "\n", + "Unlike `proximity`, which measures geometric (straight-line) distance ignoring terrain, `cost_distance` accounts for a **friction surface** — a raster where each cell's value represents how costly it is to traverse that cell.\n", + "\n", + "**Contents:**\n", + "- [Setup](#Setup)\n", + "- [1. Uniform friction: cost_distance vs proximity](#1.-Uniform-friction:-cost_distance-vs-proximity)\n", + "- [2. 
Variable friction surface](#2.-Variable-friction-surface)\n", + "- [3. Barriers (impassable cells)](#3.-Barriers-(impassable-cells))\n", + "- [4. max_cost truncation](#4.-max_cost-truncation)\n", + "- [5. Multiple sources](#5.-Multiple-sources)\n", + "- [6. 4-connectivity vs 8-connectivity](#6.-4-connectivity-vs-8-connectivity)\n", + "- [7. Dask support for large rasters](#7.-Dask-support-for-large-rasters)\n", + "\n", + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.colors import Normalize\n", + "\n", + "from xrspatial import proximity, cost_distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_raster(data, res=1.0):\n", + " \"\"\"Helper: create a DataArray with y/x coordinates.\"\"\"\n", + " h, w = data.shape\n", + " raster = xr.DataArray(\n", + " data.astype(np.float64),\n", + " dims=['y', 'x'],\n", + " attrs={'res': (res, res)},\n", + " )\n", + " raster['y'] = np.arange(h) * res\n", + " raster['x'] = np.arange(w) * res\n", + " return raster\n", + "\n", + "\n", + "def plot_comparison(arrays, titles, cmaps=None, figsize=None):\n", + " \"\"\"Plot multiple arrays side by side.\"\"\"\n", + " n = len(arrays)\n", + " if figsize is None:\n", + " figsize = (5 * n, 4)\n", + " if cmaps is None:\n", + " cmaps = ['viridis'] * n\n", + " fig, axes = plt.subplots(1, n, figsize=figsize)\n", + " if n == 1:\n", + " axes = [axes]\n", + " for ax, arr, title, cmap in zip(axes, arrays, titles, cmaps):\n", + " data = arr.values if hasattr(arr, 'values') else arr\n", + " im = ax.imshow(data, cmap=cmap, origin='upper')\n", + " ax.set_title(title)\n", + " plt.colorbar(im, ax=ax, shrink=0.8)\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Uniform friction: cost_distance vs proximity\n", + "\n", + "When friction is uniform (all 1s), `cost_distance` approximates Euclidean `proximity` — the accumulated cost equals the geometric distance along the grid path.\n", + "\n", + "The two functions use different algorithms: `proximity` computes true Euclidean (straight-line) distance via a 4-pass DP sweep, while `cost_distance` uses Dijkstra on an 8-connected grid. On a grid, the shortest 8-connected path is slightly longer than the true Euclidean distance for non-axis-aligned directions (a well-known grid discretization artifact). The maximum error is bounded by `(sqrt(2) - 1) * cellsize` per step, and typically under 1 unit for small grids." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a 21x21 grid with a single source at the centre\n", + "source = np.zeros((21, 21))\n", + "source[10, 10] = 1.0\n", + "\n", + "raster = make_raster(source)\n", + "friction_uniform = make_raster(np.ones((21, 21)))\n", + "\n", + "# Compute both\n", + "prox_result = proximity(raster)\n", + "cost_result = cost_distance(raster, friction_uniform)\n", + "\n", + "plot_comparison(\n", + " [prox_result, cost_result, prox_result - cost_result],\n", + " ['proximity (Euclidean)', 'cost_distance (friction=1)', 'Difference'],\n", + " cmaps=['magma', 'magma', 'RdBu'],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Grid discretization artifact: Dijkstra on an 8-connected grid overestimates\n", + "# true Euclidean distance for non-axis-aligned directions.\n", + "# For a 21x21 grid this is well under 1 unit.\n", + "diff = np.abs(prox_result.values - cost_result.values)\n", + "print(f\"Max absolute difference: {np.nanmax(diff):.6f}\")\n", + "print(f\"Mean absolute difference: {np.nanmean(diff):.6f}\")\n", + "\n", + "# 
Cardinal and diagonal distances match exactly:\n", + "print(f\"\\nCardinal neighbour — proximity: {prox_result.values[10, 11]:.4f}, cost_distance: {cost_result.values[10, 11]:.4f}\")\n", + "print(f\"Diagonal neighbour — proximity: {prox_result.values[9, 9]:.4f}, cost_distance: {cost_result.values[9, 9]:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Variable friction surface\n", + "\n", + "Now let's add a **high-friction zone** (like a river, swamp, or dense forest) across the middle of the grid. `proximity` ignores this entirely, but `cost_distance` routes around it because traversal is expensive." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a friction surface with a high-cost band across rows 8-12\n", + "friction_data = np.ones((21, 21))\n", + "friction_data[8:13, :] = 10.0 # 10x more costly to cross this zone\n", + "\n", + "# Source at top-left corner\n", + "source2 = np.zeros((21, 21))\n", + "source2[2, 2] = 1.0\n", + "\n", + "raster2 = make_raster(source2)\n", + "friction_var = make_raster(friction_data)\n", + "\n", + "prox2 = proximity(raster2)\n", + "cost2 = cost_distance(raster2, friction_var)\n", + "\n", + "plot_comparison(\n", + " [friction_var, prox2, cost2],\n", + " ['Friction surface', 'proximity (ignores friction)', 'cost_distance'],\n", + " cmaps=['YlOrRd', 'magma', 'magma'],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how:\n", + "- `proximity` shows smooth concentric circles centred on the source — it doesn't \"see\" the friction band.\n", + "- `cost_distance` shows the cost jump across the high-friction band. Pixels on the far side have much higher accumulated cost because every path must traverse that expensive zone." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Barriers (impassable cells)\n", + "\n", + "Setting friction to **NaN** or **0** makes cells completely impassable. `cost_distance` will route around barriers; if pixels are fully cut off, they get NaN.\n", + "\n", + "Compare with `proximity`, which always uses straight-line distance through everything." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a wall with a narrow gap\n", + "friction_barrier = np.ones((21, 21))\n", + "friction_barrier[5:16, 10] = np.nan # vertical wall\n", + "friction_barrier[10, 10] = 1.0 # gap in the wall at row 10\n", + "\n", + "# Source on the left side\n", + "source3 = np.zeros((21, 21))\n", + "source3[10, 3] = 1.0\n", + "\n", + "raster3 = make_raster(source3)\n", + "friction_wall = make_raster(friction_barrier)\n", + "\n", + "prox3 = proximity(raster3)\n", + "cost3 = cost_distance(raster3, friction_wall)\n", + "\n", + "# Show friction surface with barrier visible\n", + "barrier_vis = friction_barrier.copy()\n", + "barrier_vis[np.isnan(barrier_vis)] = 0 # show barrier as 0 for visualisation\n", + "\n", + "plot_comparison(\n", + " [make_raster(barrier_vis), prox3, cost3],\n", + " ['Friction (dark = barrier)', 'proximity (ignores wall)', 'cost_distance (routes through gap)'],\n", + " cmaps=['gray', 'magma', 'magma'],\n", + " figsize=(15, 4),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `cost_distance` result shows that all paths to the right side must funnel through the single gap in the wall, creating asymmetric cost contours." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. max_cost truncation\n", + "\n", + "The `max_cost` parameter limits the search radius. Pixels whose cheapest path exceeds the budget are set to NaN. This is critical for Dask scalability — it bounds the overlap region needed for each chunk." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source4 = np.zeros((21, 21))\n", + "source4[10, 10] = 1.0\n", + "\n", + "raster4 = make_raster(source4)\n", + "friction4 = make_raster(np.ones((21, 21)))\n", + "\n", + "cost_full = cost_distance(raster4, friction4)\n", + "cost_limited = cost_distance(raster4, friction4, max_cost=6.0)\n", + "\n", + "plot_comparison(\n", + " [cost_full, cost_limited],\n", + " ['cost_distance (no limit)', 'cost_distance (max_cost=6)'],\n", + " cmaps=['magma', 'magma'],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pixels beyond cost 6 are shown as white (NaN). This is useful for questions like \"which areas are reachable within a given travel budget?\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Multiple sources\n", + "\n", + "When multiple target pixels exist, `cost_distance` finds the cheapest path to the **nearest** source for each pixel — just like `proximity` finds the nearest by geometric distance.\n", + "\n", + "With non-uniform friction, the \"nearest\" source by cost may differ from the nearest by straight-line distance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Two sources, with a high-friction zone near one of them\n", + "source5 = np.zeros((21, 21))\n", + "source5[5, 5] = 1.0 # source A (top-left region)\n", + "source5[15, 15] = 2.0 # source B (bottom-right region)\n", + "\n", + "friction5_data = np.ones((21, 21))\n", + "friction5_data[3:8, 3:8] = 5.0 # high friction around source A\n", + "friction5_data[5, 5] = 1.0 # but source A itself is cheap to stand on\n", + "\n", + "raster5 = make_raster(source5)\n", + "friction5 = make_raster(friction5_data)\n", + "\n", + "prox5 = proximity(raster5)\n", + "cost5 = cost_distance(raster5, friction5)\n", + "\n", + "plot_comparison(\n", + " [friction5, prox5, cost5],\n", + " ['Friction (high near source A)', 'proximity', 'cost_distance'],\n", + " cmaps=['YlOrRd', 'magma', 'magma'],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how the cost-distance boundary between the two sources shifts: the high friction around source A makes source B \"closer\" (by cost) to more of the grid than geometric distance alone would suggest." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 4-connectivity vs 8-connectivity\n", + "\n", + "By default, `cost_distance` uses 8-connectivity (cardinal + diagonal neighbours). With `connectivity=4`, only cardinal neighbours (up/down/left/right) are considered. This means diagonal paths must take two steps instead of one, increasing costs along diagonals." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source6 = np.zeros((15, 15))\n", + "source6[7, 7] = 1.0\n", + "\n", + "raster6 = make_raster(source6)\n", + "friction6 = make_raster(np.ones((15, 15)))\n", + "\n", + "cost_8conn = cost_distance(raster6, friction6, connectivity=8)\n", + "cost_4conn = cost_distance(raster6, friction6, connectivity=4)\n", + "\n", + "plot_comparison(\n", + " [cost_8conn, cost_4conn, cost_4conn - cost_8conn],\n", + " ['8-connectivity', '4-connectivity', 'Difference (4-conn minus 8-conn)'],\n", + " cmaps=['magma', 'magma', 'Reds'],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The corner pixel shows the biggest difference\n", + "print(f\"Corner (0,0) with 8-connectivity: {cost_8conn.values[0, 0]:.4f}\")\n", + "print(f\"Corner (0,0) with 4-connectivity: {cost_4conn.values[0, 0]:.4f}\")\n", + "print(f\"Ratio: {cost_4conn.values[0, 0] / cost_8conn.values[0, 0]:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Dask support for large rasters\n", + "\n", + "For very large rasters that don't fit in memory, `cost_distance` works with Dask-backed DataArrays. 
When `max_cost` is finite, it automatically computes the required overlap radius and uses `dask.array.map_overlap` for parallel, chunk-by-chunk processing.\n", + "\n", + "The key formula: `overlap_radius = max_cost / (min_friction * cellsize)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import dask.array as da\n", + "\n", + "# Create a larger raster backed by dask\n", + "np.random.seed(42)\n", + "h, w = 100, 100\n", + "\n", + "source_data = np.zeros((h, w))\n", + "source_data[20, 20] = 1.0\n", + "source_data[80, 80] = 2.0\n", + "\n", + "friction_data = np.random.uniform(1.0, 3.0, (h, w))\n", + "\n", + "raster_dask = make_raster(source_data)\n", + "raster_dask.data = da.from_array(source_data, chunks=(50, 50))\n", + "\n", + "friction_dask = make_raster(friction_data)\n", + "friction_dask.data = da.from_array(friction_data, chunks=(50, 50))\n", + "\n", + "print(f\"Raster type: {type(raster_dask.data)}\")\n", + "print(f\"Chunks: {raster_dask.data.chunks}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute with max_cost to enable efficient chunked processing\n", + "result_dask = cost_distance(raster_dask, friction_dask, max_cost=80.0)\n", + "\n", + "print(f\"Result type: {type(result_dask.data)}\")\n", + "print(f\"Result chunks: {result_dask.data.chunks}\")\n", + "\n", + "# Compute to numpy for plotting\n", + "result_computed = result_dask.compute()\n", + "\n", + "plot_comparison(\n", + " [make_raster(friction_data), result_computed],\n", + " ['Random friction surface', 'cost_distance (dask, max_cost=80)'],\n", + " cmaps=['YlOrRd', 'magma'],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify dask matches numpy\n", + "raster_np = make_raster(source_data)\n", + "friction_np = make_raster(friction_data)\n", + "result_np = 
cost_distance(raster_np, friction_np, max_cost=80.0)\n", + "\n", + "max_diff = np.nanmax(np.abs(result_computed.values - result_np.values))\n", + "print(f\"Max difference between dask and numpy: {max_diff:.10f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| Function | What it measures | Friction-aware? | Algorithm |\n", + "|---|---|---|---|\n", + "| `proximity` | Geometric distance to nearest target | No | 4-pass DP (GDAL) |\n", + "| `cost_distance` | Accumulated traversal cost to nearest target | Yes | Multi-source Dijkstra |\n", + "\n", + "**When to use which:**\n", + "- Use `proximity` when you need simple geometric distance (e.g., \"how far is each pixel from the nearest road?\").\n", + "- Use `cost_distance` when traversal cost varies across the landscape (e.g., \"what is the cheapest path cost from each pixel to the nearest hospital, accounting for terrain difficulty?\").\n", + "\n", + "**Key parameters for `cost_distance`:**\n", + "- `friction`: A raster where each cell's value represents the per-unit-distance cost of traversal. NaN or ≤0 = impassable.\n", + "- `max_cost`: Limits the search radius. 
Critical for Dask scalability.\n", + "- `connectivity`: 4 (cardinal only) or 8 (cardinal + diagonal).\n", + "- `target_values`: Specify which pixel values in the source raster are targets.\n", + "\n", + "### References\n", + "- GRASS GIS `r.cost`: https://grass.osgeo.org/grass-stable/manuals/r.cost.html\n", + "- ArcGIS Cost Distance: https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-analyst/cost-distance.htm" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/setup.cfg b/setup.cfg index 78c0bfcf..8b83ef69 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,6 +63,7 @@ tests = pyarrow pytest pytest-cov + scipy [flake8] diff --git a/xrspatial/__init__.py b/xrspatial/__init__.py index ab1a3fc4..741a0f20 100644 --- a/xrspatial/__init__.py +++ b/xrspatial/__init__.py @@ -1,5 +1,6 @@ from xrspatial.aspect import aspect # noqa from xrspatial.bump import bump # noqa +from xrspatial.cost_distance import cost_distance # noqa from xrspatial.classify import binary # noqa from xrspatial.classify import box_plot # noqa from xrspatial.classify import head_tail_breaks # noqa diff --git a/xrspatial/cost_distance.py b/xrspatial/cost_distance.py new file mode 100644 index 00000000..697e493e --- /dev/null +++ b/xrspatial/cost_distance.py @@ -0,0 +1,425 @@ +"""Cost-distance (weighted proximity) via multi-source Dijkstra. + +Computes the minimum accumulated traversal cost through a friction surface +to reach the nearest target pixel. This is the raster equivalent of +GRASS ``r.cost`` / ArcGIS *Cost Distance*. 
@ngjit
def _heap_push(keys, rows, cols, size, key, row, col):
    """Insert ``(key, row, col)`` into the array-backed binary min-heap.

    The heap lives in three parallel arrays ordered by ``keys``; the first
    ``size`` slots are the live entries.  The new entry is written to slot
    ``size`` and then bubbled towards the root for as long as its parent
    holds a strictly greater key.

    The caller is responsible for making sure slot ``size`` exists in all
    three arrays; no capacity check is performed here.

    Returns
    -------
    int
        The heap size after the insertion (``size + 1``).
    """
    idx = size
    keys[idx] = key
    rows[idx] = row
    cols[idx] = col

    # Bubble the new entry up until the parent is no longer larger.
    while idx > 0:
        up = (idx - 1) // 2
        if not keys[up] > keys[idx]:
            break
        keys[up], keys[idx] = keys[idx], keys[up]
        rows[up], rows[idx] = rows[idx], rows[up]
        cols[up], cols[idx] = cols[idx], cols[up]
        idx = up

    return size + 1
@ngjit
def _grow_heap(keys, rows, cols):
    """Return copies of the three heap arrays with doubled capacity."""
    old_cap = keys.shape[0]
    new_cap = max(2 * old_cap, 1)
    new_keys = np.empty(new_cap, dtype=np.float64)
    new_rows = np.empty(new_cap, dtype=np.int64)
    new_cols = np.empty(new_cap, dtype=np.int64)
    new_keys[:old_cap] = keys
    new_rows[:old_cap] = rows
    new_cols[:old_cap] = cols
    return new_keys, new_rows, new_cols


@ngjit
def _cost_distance_kernel(
    source_data,
    friction_data,
    height,
    width,
    cellsize_x,
    cellsize_y,
    max_cost,
    target_values,
    dy,
    dx,
    dd,
):
    """Run multi-source Dijkstra and return float32 cost-distance array.

    Parameters
    ----------
    source_data : 2-D array
        Source raster (targets are non-zero finite, or in *target_values*).
    friction_data : 2-D array
        Friction surface.  NaN or <= 0 means impassable.
    height, width : int
        Shape of both input arrays.
    cellsize_x, cellsize_y : float
        Accepted for signature compatibility; the kernel itself only uses
        the precomputed step lengths in *dd*.
    max_cost : float
        Cost budget.  The search stops once the cheapest pending pixel
        exceeds it, and any pixel above the budget is reported as NaN.
    target_values : 1-D array
        Specific pixel values to treat as targets (empty => all non-zero
        finite pixels).
    dy, dx : 1-D int arrays
        Neighbour offsets (length = connectivity).
    dd : 1-D float array
        Geometric distance for each neighbour direction.

    Returns
    -------
    2-D float32 array
        Accumulated cost per pixel; 0 at passable sources, NaN where the
        pixel is unreachable or its cost exceeds *max_cost*.
    """
    n_values = len(target_values)
    n_neighbors = len(dy)

    # Tentative costs start at +inf ("not reached yet"); they are turned
    # into NaN only when the float32 output is assembled at the end.
    dist = np.full((height, width), np.inf, dtype=np.float64)

    # Heap arrays.  Stale entries are never removed eagerly (lazy
    # deletion): a pixel is pushed again every time its tentative cost
    # improves, which can happen once per neighbour.  height*width slots
    # are therefore only a starting capacity, enough for the seeding pass
    # (one push per pixel at most); the relaxation loop grows the arrays
    # on demand so a push can never write past the end of the buffers.
    initial_cap = max(height * width, 1)
    h_keys = np.empty(initial_cap, dtype=np.float64)
    h_rows = np.empty(initial_cap, dtype=np.int64)
    h_cols = np.empty(initial_cap, dtype=np.int64)
    h_size = 0

    visited = np.zeros((height, width), dtype=np.int8)

    # Seed all source pixels
    for r in range(height):
        for c in range(width):
            val = source_data[r, c]
            is_target = False
            if n_values == 0:
                if val != 0.0 and np.isfinite(val):
                    is_target = True
            else:
                for k in range(n_values):
                    if val == target_values[k]:
                        is_target = True
                        break
            if is_target:
                # source must also be passable
                f = friction_data[r, c]
                if np.isfinite(f) and f > 0.0:
                    dist[r, c] = 0.0
                    h_size = _heap_push(h_keys, h_rows, h_cols, h_size,
                                        0.0, r, c)

    # Dijkstra main loop
    while h_size > 0:
        cost_u, ur, uc, h_size = _heap_pop(h_keys, h_rows, h_cols, h_size)

        # Stale heap entry: this pixel was already settled at a lower cost.
        if visited[ur, uc]:
            continue
        visited[ur, uc] = 1

        # The heap yields costs in non-decreasing order, so everything
        # still pending is over budget as well.
        if cost_u > max_cost:
            break

        f_u = friction_data[ur, uc]

        for i in range(n_neighbors):
            vr = ur + dy[i]
            vc = uc + dx[i]
            if vr < 0 or vr >= height or vc < 0 or vc >= width:
                continue
            if visited[vr, vc]:
                continue

            f_v = friction_data[vr, vc]
            # impassable if NaN or non-positive friction
            if not (np.isfinite(f_v) and f_v > 0.0):
                continue

            # Edge cost = step length * mean friction of the two endpoints.
            edge_cost = dd[i] * (f_u + f_v) * 0.5
            new_cost = cost_u + edge_cost

            if new_cost < dist[vr, vc]:
                dist[vr, vc] = new_cost
                if h_size == h_keys.shape[0]:
                    h_keys, h_rows, h_cols = _grow_heap(
                        h_keys, h_rows, h_cols)
                h_size = _heap_push(h_keys, h_rows, h_cols, h_size,
                                    new_cost, vr, vc)

    # Convert unreachable / over-budget to NaN, cast to float32
    out = np.empty((height, width), dtype=np.float32)
    for r in range(height):
        for c in range(width):
            d = dist[r, c]
            if d == np.inf or d > max_cost:
                out[r, c] = np.nan
            else:
                out[r, c] = np.float32(d)
    return out
def _make_chunk_func(cellsize_x, cellsize_y, max_cost, target_values,
                     dy, dx, dd):
    """Build the per-block callable handed to ``da.map_overlap``.

    The returned function takes a source block and a friction block and
    runs ``_cost_distance_kernel`` on them, with every other kernel
    argument fixed to the values given here.
    """

    def _chunk(source_block, friction_block):
        # The kernel takes the block dimensions as explicit arguments.
        n_rows, n_cols = source_block.shape
        block_cost = _cost_distance_kernel(
            source_block,
            friction_block,
            n_rows,
            n_cols,
            cellsize_x,
            cellsize_y,
            max_cost,
            target_values,
            dy,
            dx,
            dd,
        )
        return block_cost

    return _chunk
@supports_dataset
def cost_distance(
    raster: xr.DataArray,
    friction: xr.DataArray,
    x: str = "x",
    y: str = "y",
    target_values: list | None = None,
    max_cost: float = np.inf,
    connectivity: int = 8,
) -> xr.DataArray:
    """Compute accumulated cost-distance through a friction surface.

    For every pixel, computes the minimum accumulated traversal cost
    to reach the nearest target pixel, where traversal cost along each
    edge equals ``geometric_distance * mean_friction_of_endpoints``.

    Parameters
    ----------
    raster : xr.DataArray or xr.Dataset
        2-D source raster.  Target pixels are identified by non-zero
        finite values (or values in *target_values*).
    friction : xr.DataArray
        2-D friction (cost) surface.  Must have the same shape and
        coordinates as *raster*.  Values must be positive and finite
        for passable cells; NaN or ``<= 0`` marks impassable barriers.
    x : str, default='x'
        Name of the x coordinate.
    y : str, default='y'
        Name of the y coordinate.
    target_values : list, optional
        Specific pixel values in *raster* to treat as sources.
        If ``None`` or empty, all non-zero finite pixels are sources.
    max_cost : float, default=np.inf
        Maximum accumulated cost.  Pixels whose least-cost path exceeds
        this budget are set to NaN.  A finite value enables efficient
        Dask parallelisation via ``map_overlap``.
    connectivity : int, default=8
        Pixel connectivity: 4 (cardinal only) or 8 (cardinal + diagonal).

    Returns
    -------
    xr.DataArray or xr.Dataset
        2-D array of accumulated cost-distance values (float32).
        Source pixels have cost 0.  Unreachable pixels are NaN.

    Raises
    ------
    ValueError
        If *raster* or *friction* is not 2-D, their shapes differ,
        ``raster.dims`` is not ``(y, x)``, or *connectivity* is not
        4 or 8.
    TypeError
        If *raster* is numpy-backed but *friction* is not, or if
        *raster* is backed by an unsupported array type.
    """
    # --- validation ---
    if raster.ndim != 2:
        raise ValueError("raster must be 2-D")
    if friction.ndim != 2:
        raise ValueError("friction must be 2-D")
    if raster.shape != friction.shape:
        raise ValueError("raster and friction must have the same shape")
    if raster.dims != (y, x):
        raise ValueError(
            f"raster.dims should be ({y!r}, {x!r}), got {raster.dims}"
        )
    if connectivity not in (4, 8):
        raise ValueError("connectivity must be 4 or 8")

    cellsize_x, cellsize_y = get_dataarray_resolution(raster)
    cellsize_x = abs(float(cellsize_x))
    cellsize_y = abs(float(cellsize_y))

    # None stands in for "no explicit targets"; using it as the default
    # avoids a shared mutable default argument.
    if target_values is None:
        target_values = []
    target_values = np.asarray(target_values, dtype=np.float64)
    max_cost_f = float(max_cost)

    # Build neighbour offsets and geometric distances
    if connectivity == 8:
        diagonal = sqrt(cellsize_y**2 + cellsize_x**2)
        dy = np.array([-1, -1, -1, 0, 0, 1, 1, 1], dtype=np.int64)
        dx = np.array([-1, 0, 1, -1, 1, -1, 0, 1], dtype=np.int64)
        dd = np.array([
            diagonal,    # (-1,-1)
            cellsize_y,  # (-1, 0)
            diagonal,    # (-1,+1)
            cellsize_x,  # ( 0,-1)
            cellsize_x,  # ( 0,+1)
            diagonal,    # (+1,-1)
            cellsize_y,  # (+1, 0)
            diagonal,    # (+1,+1)
        ], dtype=np.float64)
    else:
        dy = np.array([0, -1, 1, 0], dtype=np.int64)
        dx = np.array([-1, 0, 0, 1], dtype=np.int64)
        dd = np.array([cellsize_x, cellsize_y, cellsize_y, cellsize_x],
                      dtype=np.float64)

    # Ensure friction chunks match raster chunks for dask
    source_data = raster.data
    friction_data = friction.data

    if da is not None and isinstance(source_data, da.Array):
        # Rechunk friction to match raster
        if isinstance(friction_data, da.Array):
            friction_data = friction_data.rechunk(source_data.chunks)
        else:
            friction_data = da.from_array(friction_data,
                                          chunks=source_data.chunks)

    if isinstance(source_data, np.ndarray):
        if isinstance(friction_data, np.ndarray):
            result_data = _cost_distance_numpy(
                source_data, friction_data,
                cellsize_x, cellsize_y, max_cost_f,
                target_values, dy, dx, dd,
            )
        else:
            raise TypeError("friction must be numpy-backed when raster is")
    elif da is not None and isinstance(source_data, da.Array):
        result_data = _cost_distance_dask(
            source_data, friction_data,
            cellsize_x, cellsize_y, max_cost_f,
            target_values, dy, dx, dd,
        )
    else:
        raise TypeError(f"Unsupported array type: {type(source_data)}")

    return xr.DataArray(
        result_data,
        coords=raster.coords,
        dims=raster.dims,
        attrs=raster.attrs,
    )
max_distance, p): + """Query k-d tree for nearest target distance for every pixel in block.""" + if block_info is None or block_info == []: + return np.full(block.shape, np.nan, dtype=np.float32) + + y_start = block_info[0]['array-location'][0][0] + x_start = block_info[0]['array-location'][1][0] + h, w = block.shape + + chunk_ys = y_coords_1d[y_start:y_start + h] + chunk_xs = x_coords_1d[x_start:x_start + w] + yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij') + query_pts = np.column_stack([yy.ravel(), xx.ravel()]) + + dists, _ = tree.query(query_pts, p=p, + distance_upper_bound=max_distance) + dists = dists.reshape(h, w).astype(np.float32) + dists[dists == np.inf] = np.nan + return dists + + +def _process_dask_kdtree(raster, x_coords, y_coords, + target_values, max_distance, distance_metric): + """Two-phase k-d tree proximity for unbounded dask arrays.""" + p = 2 if distance_metric == EUCLIDEAN else 1 # Manhattan: p=1 + + # Phase 1: stream through chunks to collect target coordinates + target_list = [] + chunks_y, chunks_x = raster.data.chunks + y_offset = 0 + for iy, cy in enumerate(chunks_y): + x_offset = 0 + for ix, cx in enumerate(chunks_x): + chunk_data = raster.data.blocks[iy, ix].compute() + if len(target_values) == 0: + mask = np.isfinite(chunk_data) & (chunk_data != 0) + else: + mask = np.isin(chunk_data, target_values) & np.isfinite(chunk_data) + rows, cols = np.where(mask) + if len(rows) > 0: + coords = np.column_stack([ + y_coords[y_offset + rows], + x_coords[x_offset + cols], + ]) + target_list.append(coords) + x_offset += cx + y_offset += cy + + if len(target_list) == 0: + return da.full(raster.shape, np.nan, dtype=np.float32, + chunks=raster.data.chunks) + + target_coords = np.concatenate(target_list) + tree = cKDTree(target_coords) + + # Phase 2: query tree per chunk via map_blocks + chunk_fn = partial(_kdtree_chunk_fn, + y_coords_1d=y_coords, + x_coords_1d=x_coords, + tree=tree, + max_distance=max_distance if np.isfinite(max_distance) else 
np.inf, + p=p) + + result = da.map_blocks( + chunk_fn, + raster.data, + dtype=np.float32, + meta=np.array((), dtype=np.float32), + ) + return result + + def _process( raster, x, @@ -633,16 +712,26 @@ def _process_dask(raster, xs, ys): result = _process_numpy(raster.data, xs, ys) elif da is not None and isinstance(raster.data, da.Array): - # dask case - create coordinate arrays as dask arrays directly - # This avoids materializing the full arrays in memory - # Convert 1D coords to dask arrays first - x_coords_da = da.from_array(x_coords, chunks=x_coords.shape[0]) - y_coords_da = da.from_array(y_coords, chunks=y_coords.shape[0]) - xs = da.tile(x_coords_da, (raster.shape[0], 1)) - ys = da.repeat(y_coords_da, raster.shape[1]).reshape(raster.shape) - xs = xs.rechunk(raster.chunks) - ys = ys.rechunk(raster.chunks) - result = _process_dask(raster, xs, ys) + use_kdtree = ( + cKDTree is not None + and process_mode == PROXIMITY + and distance_metric in (EUCLIDEAN, MANHATTAN) + and max_distance >= max_possible_distance + ) + if use_kdtree: + result = _process_dask_kdtree( + raster, x_coords, y_coords, + target_values, max_distance, distance_metric, + ) + else: + # Existing path: build 2D coordinate arrays as dask arrays + x_coords_da = da.from_array(x_coords, chunks=x_coords.shape[0]) + y_coords_da = da.from_array(y_coords, chunks=y_coords.shape[0]) + xs = da.tile(x_coords_da, (raster.shape[0], 1)) + ys = da.repeat(y_coords_da, raster.shape[1]).reshape(raster.shape) + xs = xs.rechunk(raster.chunks) + ys = ys.rechunk(raster.chunks) + result = _process_dask(raster, xs, ys) return result diff --git a/xrspatial/tests/test_cost_distance.py b/xrspatial/tests/test_cost_distance.py new file mode 100644 index 00000000..55f3c9c8 --- /dev/null +++ b/xrspatial/tests/test_cost_distance.py @@ -0,0 +1,402 @@ +"""Tests for xrspatial.cost_distance.""" + +try: + import dask.array as da +except ImportError: + da = None + +import numpy as np +import pytest +import xarray as xr + +from 
def _make_raster(data, backend='numpy', chunks=(3, 3)):
    """Build a float64 DataArray with y/x coords, optionally dask-backed.

    Parameters
    ----------
    data : np.ndarray
        2-D values; cast to float64.
    backend : str
        Any value containing ``'dask'`` wraps the data in a dask array.
    chunks : tuple
        Chunk shape used when the raster is dask-backed.

    Returns
    -------
    xr.DataArray
        Dims ``('y', 'x')``, integer-valued float coordinates and
        ``attrs['res'] == (1.0, 1.0)``.

    Notes
    -----
    When a dask backend is requested but dask is not installed the calling
    test is skipped. Silently returning a NumPy raster would let the
    ``'dask+numpy'`` parametrizations pass without exercising dask at all.
    """
    wants_dask = 'dask' in backend
    if wants_dask and da is None:
        pytest.skip("dask not installed")

    h, w = data.shape
    raster = xr.DataArray(
        data.astype(np.float64),
        dims=['y', 'x'],
        attrs={'res': (1.0, 1.0)},
    )
    raster['y'] = np.arange(h, dtype=np.float64)
    raster['x'] = np.arange(w, dtype=np.float64)
    if wants_dask:
        raster.data = da.from_array(raster.data, chunks=chunks)
    return raster


def _compute(arr):
    """Extract numpy data from DataArray (works for numpy or dask)."""
    if da is not None and isinstance(arr.data, da.Array):
        # ``.values`` triggers the dask computation.
        return arr.values
    return arr.data
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
def test_analytic_small_grid(backend):
    """Check hand-computed costs on a 3x3 grid with one source at (0, 0)."""
    sources = np.zeros((3, 3))
    sources[0, 0] = 1.0

    friction_values = np.array([
        [1.0, 2.0, 1.0],
        [1.0, 10.0, 1.0],
        [1.0, 1.0, 1.0],
    ])

    raster = _make_raster(sources, backend=backend, chunks=(3, 3))
    friction = _make_raster(friction_values, backend=backend, chunks=(3, 3))
    out = _compute(cost_distance(raster, friction))

    # The source itself costs nothing.
    assert out[0, 0] == 0.0

    # Each step costs distance * mean friction of the two cells.
    expected = [
        # (0,1): one cardinal step, (1 + 2) / 2 = 1.5
        ((0, 1), 1.5),
        # (1,0): one cardinal step, (1 + 1) / 2 = 1.0
        ((1, 0), 1.0),
        # (1,1): the diagonal would cost sqrt(2) * 5.5 ~= 7.778, but going
        # via (1,0) costs 1.0 + 5.5 = 6.5, which is cheaper.
        ((1, 1), 6.5),
        # (2,0): via (1,0), 1.0 + 1.0 = 2.0
        ((2, 0), 2.0),
    ]
    for cell, cost in expected:
        np.testing.assert_allclose(out[cell], cost, atol=1e-5)
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
def test_multiple_sources(backend):
    """With two sources, every pixel takes the cost of the cheaper one."""
    sources = np.zeros((1, 5))
    sources[0, 0] = 1.0  # source A, left end
    sources[0, 4] = 2.0  # source B, right end

    raster = _make_raster(sources, backend=backend, chunks=(1, 5))
    friction = _make_raster(np.ones((1, 5)), backend=backend, chunks=(1, 5))
    out = _compute(cost_distance(raster, friction))

    # Both sources cost nothing.
    for col in (0, 4):
        assert out[0, col] == 0.0

    # Column 2 is equidistant (cost 2); columns 1 and 3 are one step from
    # source A and source B respectively.
    for col, cost in ((2, 2.0), (1, 1.0), (3, 1.0)):
        np.testing.assert_allclose(out[0, col], cost, atol=1e-5)
@pytest.mark.skipif(da is None, reason="dask not installed")
def test_dask_matches_numpy():
    """The dask backend must reproduce the NumPy result on a random surface."""
    np.random.seed(42)
    sources = np.zeros((10, 12))
    sources[2, 3] = 1.0
    sources[7, 9] = 2.0
    friction_values = np.random.uniform(0.5, 5.0, (10, 12))

    # Reference: plain NumPy rasters.
    result_np = cost_distance(
        _make_raster(sources, backend='numpy'),
        _make_raster(friction_values, backend='numpy'),
        max_cost=20.0,
    )

    # Same inputs split into four 5x6 chunks.
    result_da = cost_distance(
        _make_raster(sources, backend='dask+numpy', chunks=(5, 6)),
        _make_raster(friction_values, backend='dask+numpy', chunks=(5, 6)),
        max_cost=20.0,
    )

    assert isinstance(result_da.data, da.Array)
    np.testing.assert_allclose(
        result_da.values, result_np.data, equal_nan=True, atol=1e-5
    )
@pytest.mark.skipif(da is None, reason="dask not installed")
def test_dask_no_large_numpy_arrays():
    """Building the dask graph must not create raster-sized ``np.full`` arrays.

    ``np.full`` is wrapped only while ``cost_distance`` is called, i.e. while
    the lazy graph is built. Allocations made later, when ``.values``
    computes the chunks, are expected and are not tracked.
    """
    from unittest.mock import patch

    height, width = 50, 60
    source = np.zeros((height, width))
    source[10, 10] = 1.0

    friction_data = np.ones((height, width))

    raster = _make_raster(source, backend='dask+numpy', chunks=(25, 30))
    friction = _make_raster(friction_data, backend='dask+numpy',
                            chunks=(25, 30))

    original_full = np.full
    large_allocs = []

    def tracking_full(shape, *args, **kwargs):
        """Delegate to ``np.full`` and record raster-sized results."""
        result = original_full(shape, *args, **kwargs)
        if result.size >= height * width:
            large_allocs.append(('full', shape))
        return result

    # Actually install the tracker; without this the test never observed
    # any allocation.
    with patch.object(np, 'full', tracking_full):
        result = cost_distance(raster, friction, max_cost=20.0)

    assert not large_allocs, (
        f"Large numpy arrays created: {large_allocs}"
    )

    # Verify result is dask-backed
    assert isinstance(result.data, da.Array)

    # Verify correctness
    computed = result.values
    assert computed[10, 10] == 0.0
    assert computed[0, 0] > 0.0 or np.isnan(computed[0, 0])
@pytest.mark.parametrize("backend", ['numpy'])
def test_source_on_impassable_cell(backend):
    """A source sitting on a NaN-friction cell must not seed any path."""
    sources = np.zeros((3, 3))
    sources[1, 1] = 1.0  # the only source ...

    friction_values = np.ones((3, 3))
    friction_values[1, 1] = np.nan  # ... sits on an impassable cell

    out = _compute(cost_distance(
        _make_raster(sources, backend=backend),
        _make_raster(friction_values, backend=backend),
    ))

    # Nothing is reachable, so the whole output is NaN.
    assert np.all(np.isnan(out))
@pytest.mark.skipif(da is None, reason="dask is not installed")
@pytest.mark.parametrize("metric", ["EUCLIDEAN", "MANHATTAN"])
def test_proximity_dask_kdtree_matches_numpy(metric):
    """Dask and NumPy proximity must agree for the same raster and metric."""
    dask_raster = _make_kdtree_raster()

    # NumPy twin of the dask raster.
    eager_raster = dask_raster.copy()
    eager_raster.data = dask_raster.data.compute()

    kwargs = dict(x='lon', y='lat', distance_metric=metric)
    numpy_result = proximity(eager_raster, **kwargs)
    dask_result = proximity(dask_raster, **kwargs)

    assert isinstance(dask_result.data, da.Array)
    np.testing.assert_allclose(
        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
    )
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_with_target_values():
    """``target_values`` filtering gives the same result on dask and NumPy."""
    dask_raster = _make_kdtree_raster()
    eager_raster = dask_raster.copy()
    eager_raster.data = dask_raster.data.compute()

    wanted = [2, 3]
    numpy_result = proximity(eager_raster, x='lon', y='lat',
                             target_values=wanted)
    dask_result = proximity(dask_raster, x='lon', y='lat',
                            target_values=wanted)

    assert isinstance(dask_result.data, da.Array)
    np.testing.assert_allclose(
        dask_result.values, numpy_result.values, rtol=1e-5, equal_nan=True,
    )


@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_no_targets():
    """A raster without any target pixel yields an all-NaN dask result."""
    values = np.zeros((10, 10), dtype=np.float64)
    raster = xr.DataArray(values, dims=['lat', 'lon'])
    raster['lon'] = np.arange(10, dtype=np.float64)
    raster['lat'] = np.arange(10, dtype=np.float64)[::-1]
    raster.data = da.from_array(values, chunks=(5, 5))

    result = proximity(raster, x='lon', y='lat')

    assert isinstance(result.data, da.Array)
    assert np.all(np.isnan(result.values))
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_fallback_no_scipy():
    """With ``cKDTree`` unavailable, dask proximity still returns a result.

    The module attribute is replaced through ``patch.object`` so it is
    restored automatically, matching how the sibling great-circle test
    patches the same module.
    """
    import sys
    prox_mod = sys.modules['xrspatial.proximity']

    height, width = 8, 10
    data = np.zeros((height, width), dtype=np.float64)
    data[2, 3] = 1.0
    data[6, 8] = 2.0
    _lon = np.linspace(0, 9, width)
    _lat = np.linspace(7, 0, height)
    raster = xr.DataArray(data, dims=['lat', 'lon'])
    raster['lon'] = _lon
    raster['lat'] = _lat
    raster.data = da.from_array(data, chunks=(4, 5))

    # Keep the compute inside the patch so the whole run sees cKDTree=None.
    with patch.object(prox_mod, 'cKDTree', None):
        result = proximity(raster, x='lon', y='lat')
        assert isinstance(result.data, da.Array)
        # Should still produce correct results via fallback
        computed = result.values
        assert computed[2, 3] == 0.0
proximity(raster, x='lon', y='lat', + distance_metric='GREAT_CIRCLE') + + assert len(kdtree_called) == 0, "k-d tree path should not be used for GREAT_CIRCLE" + assert isinstance(result.data, da.Array)