66 changes: 60 additions & 6 deletions xrspatial/geotiff/_writer.py
@@ -1368,6 +1368,21 @@ def write_streaming(dask_data, path: str, *,
tiles_per_segment = max(
1, streaming_buffer_bytes // bytes_per_tile_col)

# Hoist the compression thread pool over the entire tiled
# write. Re-creating the executor per segment paid the
# thread-startup cost on every horizontal stripe and
# offset the parallel speedup on wide rasters; a single
# pool reused across all segments avoids that overhead.
# Skip the pool when the codec is ``COMPRESSION_NONE``
# (no C-level work that would release the GIL) or when
# the host has only one usable core.
from concurrent.futures import ThreadPoolExecutor
_pool_workers = min(tiles_per_segment, os.cpu_count() or 4)
_use_pool = (comp_tag != COMPRESSION_NONE
and _pool_workers > 1)
tile_pool = (ThreadPoolExecutor(max_workers=_pool_workers)
if _use_pool else None)

for tr in range(tiles_down):
r0 = tr * th
r1 = min(r0 + th, height)
@@ -1400,6 +1415,8 @@ def write_streaming(dask_data, path: str, *,
seg_np = seg_np.copy()
seg_np[nan_mask] = seg_np.dtype.type(nodata)

# Build tile arrays for this segment
seg_tile_arrs = []
for tc in range(seg_start, seg_end):
c0 = tc * tw
c1 = min(c0 + tw, width)
@@ -1420,17 +1437,54 @@ def write_streaming(dask_data, path: str, *,
else:
tile_arr = np.ascontiguousarray(tile_slice)

compressed = _compress_block(
tile_arr, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)

seg_tile_arrs.append(tile_arr)

# Parallel compress on the hoisted ``tile_pool``
# when it exists. zlib/zstd/LZW release the GIL,
# so threading actually parallelises the C-level
# work. Peak memory while the segment is in
# flight covers BOTH the uncompressed
# ``seg_tile_arrs`` (one full tile per column,
# released after the futures resolve) AND the
# compressed buffers ``seg_compressed`` (held
# until the sequential write loop drains them).
# Both lists are bounded by ``tiles_per_segment``,
# which the streaming buffer cap determines; fall
# back to the serial path when the pool is None
# (no compression / single core) or when the
# segment holds only one tile.
n_seg_tiles = len(seg_tile_arrs)
if tile_pool is None or n_seg_tiles <= 1:
seg_compressed = [
_compress_block(
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)
for ta in seg_tile_arrs
]
else:
futures = [
tile_pool.submit(
_compress_block,
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)
for ta in seg_tile_arrs
]
seg_compressed = [
fut.result() for fut in futures]

# Sequential file write to preserve on-disk tile order
for compressed in seg_compressed:
actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
f.write(compressed)
current_offset += len(compressed)

del seg_np
del seg_np, seg_tile_arrs, seg_compressed

if tile_pool is not None:
tile_pool.shutdown(wait=True)
else:
# Strip layout
for i in range(n_entries):
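For readers skimming the diff, a minimal, self-contained sketch of the hoisted-pool pattern the hunk above applies. ``fake_compress`` and ``write_segments`` are hypothetical stand-ins for illustration, not the writer's real ``_compress_block`` signature or segment loop:

import os
import zlib
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def fake_compress(tile):
    # zlib (like zstd and LZW in the writer) releases the GIL while
    # compressing, so pool threads run the C-level work in parallel.
    return zlib.compress(np.ascontiguousarray(tile).tobytes(), 6)


def write_segments(segments, out, tiles_per_segment):
    # One pool for the whole write; re-creating it per segment pays
    # the thread-startup cost on every horizontal stripe.
    workers = min(tiles_per_segment, os.cpu_count() or 4)
    pool = ThreadPoolExecutor(max_workers=workers) if workers > 1 else None
    offsets, counts, offset = [], [], 0
    try:
        for seg_tiles in segments:  # one horizontal segment at a time
            if pool is None or len(seg_tiles) <= 1:
                blobs = [fake_compress(t) for t in seg_tiles]
            else:
                futs = [pool.submit(fake_compress, t) for t in seg_tiles]
                blobs = [f.result() for f in futs]
            for blob in blobs:  # sequential write preserves tile order
                offsets.append(offset)
                counts.append(len(blob))
                out.write(blob)
                offset += len(blob)
    finally:
        if pool is not None:
            pool.shutdown(wait=True)
    return offsets, counts

As an illustration of the memory bound the hunk comments describe: peak in-flight memory per segment is roughly ``tiles_per_segment * (tile_bytes + compressed_bytes)``. With 256x256 single-band float32 tiles (256 KiB each), a hypothetical 64 MiB streaming buffer would cap a segment at 256 uncompressed tiles, with the compressed blobs (typically much smaller under these lossless codecs) held on top until the sequential write drains them.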
189 changes: 189 additions & 0 deletions xrspatial/geotiff/tests/test_streaming_write_parallel.py
@@ -0,0 +1,189 @@
"""Parallel per-tile compression in the streaming write path (P4).

The streaming tile-write path used to compress tiles serially inside each
horizontal segment. The non-streaming path already fans compression out to a
``ThreadPoolExecutor`` sized at ``os.cpu_count()``. These tests confirm
the streaming path now follows the same pattern: round-trip pixels are
unchanged, more than one worker thread participates, and a moderate
deflate write completes within a generous wall-time budget.
"""
from __future__ import annotations

import inspect
import os
import threading
import time

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import open_geotiff, to_geotiff
from xrspatial.geotiff import _writer as writer_mod


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_dataarray(shape, dtype=np.float32, seed=20260508):
rng = np.random.default_rng(seed)
if np.issubdtype(np.dtype(dtype), np.floating):
arr = rng.random(shape, dtype=dtype)
nodata = -9999.0
else:
info = np.iinfo(dtype)
arr = rng.integers(info.min // 2 if info.min < 0 else 1,
info.max // 2,
size=shape, dtype=dtype)
# Pick a sentinel that is in-range for the integer dtype.
nodata = 0 if info.min >= 0 else -1
if len(shape) == 2:
h, w = shape
else:
h, w = shape[:2]
y = np.linspace(41.0, 40.0, h)
x = np.linspace(-106.0, -105.0, w)
return xr.DataArray(arr, dims=['y', 'x'][:len(shape)],
coords={'y': y, 'x': x},
attrs={'crs': 4326, 'nodata': nodata})


# ---------------------------------------------------------------------------
# Round-trip correctness across dtype/compression/tile-size combos
# ---------------------------------------------------------------------------

@pytest.mark.parametrize(
'dtype,compression,tile_size',
[
(np.float32, 'deflate', 256),
(np.float32, 'zstd', 128),
(np.float32, 'lzw', 256),
(np.uint16, 'deflate', 256),
(np.uint8, 'deflate', 128),
(np.float32, 'none', 256),
],
)
def test_streaming_write_round_trip_unchanged(
dtype, compression, tile_size, tmp_path):
"""Streaming write must be bit-exact vs. the eager write (and vs. input).

Picks shapes that force multiple horizontal segments by setting a
small ``streaming_buffer_bytes`` so the parallel-compress branch
actually fires.
"""
shape = (600, 700)
da = _make_dataarray(shape, dtype=dtype)
# small chunks so dask path is taken; small buffer so segments split
dask_da = da.chunk({'y': 200, 'x': 200})

eager_path = str(tmp_path / 'eager.tif')
stream_path = str(tmp_path / 'stream.tif')

to_geotiff(da, eager_path,
compression=compression, tile_size=tile_size)

# Force multi-segment by limiting buffer to a couple of tile columns,
# if the underlying write_streaming exposes the kwarg.
sig = inspect.signature(writer_mod.write_streaming)
extra = {}
if 'streaming_buffer_bytes' in sig.parameters:
# ~3 tile columns of 256x256 float32 (~256 KiB each), so the
# write splits into several segments that each hold >1 tile.
extra['streaming_buffer_bytes'] = 3 * 256 * 1024
to_geotiff(dask_da, stream_path,
compression=compression, tile_size=tile_size, **extra)

eager = open_geotiff(eager_path)
stream = open_geotiff(stream_path)

# Bit-exact pixel match: streaming vs eager
np.testing.assert_array_equal(eager.values, stream.values)
# And streaming vs the source array (lossless codecs only)
np.testing.assert_array_equal(stream.values, da.values)


# ---------------------------------------------------------------------------
# Parallelism is observable: multiple thread IDs participate
# ---------------------------------------------------------------------------

def test_streaming_write_parallelism_observed(monkeypatch, tmp_path):
"""Confirm more than one worker thread runs ``_compress_block``.

Strategy: wrap the real function so each call records the current
thread ID, then write a raster sized to produce many tiles inside a
single segment. After the write, assert at least 2 unique thread IDs
were observed.

Force ``os.cpu_count()`` to 4 in the global ``os`` module for the
duration of the test so the assertion stays deterministic on
single-core CI containers (where the real cpu_count would size the
pool to 1 and the test would fail for environment reasons). The
writer calls ``os.cpu_count()`` at write time through the shared
``os`` module object, so patching that attribute is sufficient.
"""
monkeypatch.setattr(os, 'cpu_count', lambda: 4)

real_compress = writer_mod._compress_block
seen_threads: set[int] = set()
lock = threading.Lock()

def recording_compress(*args, **kwargs):
# Record which pool worker thread handled this call.
tid = threading.get_ident()
with lock:
seen_threads.add(tid)
# Tiny synthetic delay so concurrent submissions overlap.
time.sleep(0.005)
return real_compress(*args, **kwargs)

monkeypatch.setattr(writer_mod, '_compress_block', recording_compress)

# Many tiles per segment: 16 tile columns x 16 tile rows = 256
# tiles, well above any cpu count.
shape = (16 * 256, 16 * 256)
da = _make_dataarray(shape, dtype=np.float32, seed=42)
dask_da = da.chunk({'y': 256 * 4, 'x': 256 * 4})

path = str(tmp_path / 'parallel_check.tif')
to_geotiff(dask_da, path,
compression='deflate', tile_size=256)

assert len(seen_threads) > 1, (
f"Expected >1 worker threads in streaming compress, "
f"saw {len(seen_threads)}: {seen_threads}")


# ---------------------------------------------------------------------------
# Optional wall-clock regression guard (env-gated)
# ---------------------------------------------------------------------------

def test_streaming_write_perf_sanity(tmp_path):
"""Opt-in tripwire: a 4096x4096 deflate streaming write under a
generous wall-clock threshold. Skipped by default because shared CI
runners, CPU throttling, debug builds, and slow filesystems all
make absolute timings flaky. Set ``XRSPATIAL_RUN_PERF_TESTS=1`` to
enable. The deterministic regression coverage lives in
:func:`test_streaming_write_parallelism_observed` -- this is just a
sanity bound on top.
"""
if os.environ.get('XRSPATIAL_RUN_PERF_TESTS') != '1':
pytest.skip(
"set XRSPATIAL_RUN_PERF_TESTS=1 to run wall-clock perf tests")

shape = (4096, 4096)
da = _make_dataarray(shape, dtype=np.float32, seed=2026)
dask_da = da.chunk({'y': 1024, 'x': 1024})

path = str(tmp_path / 'perf_4k.tif')

t0 = time.perf_counter()
to_geotiff(dask_da, path, compression='deflate', tile_size=256)
elapsed = time.perf_counter() - t0

result = open_geotiff(path)
assert result.shape == shape

assert elapsed < 5.0, (
f"Streaming 4096x4096 deflate write took {elapsed:.2f}s, "
f"expected <5s (regression guard)")
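To run these tests locally, plain ``pytest`` on the new file covers the round-trip and parallelism checks; the wall-clock guard additionally needs the opt-in flag, e.g. ``XRSPATIAL_RUN_PERF_TESTS=1 pytest xrspatial/geotiff/tests/test_streaming_write_parallel.py`` (the exact invocation depends on the repo's test setup).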