diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 244c31c6..23b2cbeb 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -194,7 +194,7 @@ def _read_geo_info(source, *, overview_level: int | None = None): overview_level : int or None Overview IFD index (0 = full resolution). """ - from ._dtypes import tiff_dtype_to_numpy + from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy from ._geotags import extract_geo_info from ._header import parse_all_ifds, parse_header, select_overview_ifd from ._reader import _coerce_path, _is_file_like @@ -230,9 +230,7 @@ def _read_geo_info(source, *, overview_level: int | None = None): raise ValueError("No IFDs found in TIFF file") ifd = select_overview_ifd(ifds, overview_level) geo_info = extract_geo_info(ifd, data, header.byte_order) - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) n_bands = ifd.samples_per_pixel if ifd.samples_per_pixel > 1 else 0 return geo_info, ifd.height, ifd.width, file_dtype, n_bands @@ -1446,7 +1444,7 @@ def read_geotiff_gpu(source: str, *, from ._header import ( parse_header, parse_all_ifds, select_overview_ifd, validate_tile_layout, ) - from ._dtypes import tiff_dtype_to_numpy + from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy from ._geotags import extract_geo_info from ._gpu_decode import gpu_decode_tiles @@ -1469,9 +1467,7 @@ def read_geotiff_gpu(source: str, *, # Skip mask IFDs (NewSubfileType bit 2) ifd = select_overview_ifd(ifds, overview_level) - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) geo_info = extract_geo_info(ifd, data, header.byte_order) diff --git a/xrspatial/geotiff/_dtypes.py b/xrspatial/geotiff/_dtypes.py index a510061d..baa5e2e5 100644 --- a/xrspatial/geotiff/_dtypes.py +++ b/xrspatial/geotiff/_dtypes.py @@ -118,6 +118,80 @@ def tiff_dtype_to_numpy(bits_per_sample: int, sample_format: int = 1) -> np.dtyp SUB_BYTE_BPS = {1, 2, 4, 12} +_GDAL_OT_FOR_BPS = { + 8: 'Byte', + 16: 'UInt16', + 32: 'UInt32', + 64: 'Float64', +} + + +def _suggest_gdal_ot(bps_values, sample_format=None) -> str: + """Pick a sensible ``gdal_translate -ot`` value for a mixed-bps file. + + Returns a real GDAL type name (``Byte``, ``UInt16`` etc.) when the + widest band has a recognised mapping, or ```` as a + placeholder otherwise. ``sample_format`` (TIFF SampleFormat: 1=uint, + 2=int, 3=float) refines the integer choice when known. + """ + if not bps_values: + return '' + widest = max(bps_values) + if sample_format == 3 and widest in (32, 64): + return 'Float32' if widest == 32 else 'Float64' + if sample_format == 2 and widest in (8, 16, 32): + return {8: 'Int8', 16: 'Int16', 32: 'Int32'}[widest] + return _GDAL_OT_FOR_BPS.get(widest, '') + + +def resolve_bits_per_sample(bps, sample_format=None) -> int: + """Resolve a TIFF ``BitsPerSample`` tag value to a single integer. + + The TIFF spec allows ``BitsPerSample`` to be either a scalar or a + sequence with one entry per sample. xarray-spatial decodes a whole + IFD with one numpy dtype, so the per-sample widths must agree. + + Parameters + ---------- + bps : int or sequence of int + Raw value from ``IFD.bits_per_sample``. Accepts ``int``, ``tuple``, + or ``list``. + sample_format : int, optional + TIFF SampleFormat (1=uint, 2=int, 3=float). Used only to make the + ``gdal_translate`` hint in the error message more accurate when + the entries don't agree; not consulted when they do. + + Returns + ------- + int + The shared bits-per-sample value. + + Raises + ------ + ValueError + If ``bps`` is a sequence whose entries are not all equal. Files + with per-band bit depths (e.g. RGB+8-bit-alpha with + ``(16, 16, 16, 8)``) are not supported; convert with GDAL or + rasterio first. + """ + if isinstance(bps, (tuple, list)): + if len(bps) == 0: + raise ValueError("BitsPerSample tuple is empty") + first = bps[0] + for v in bps[1:]: + if v != first: + ot = _suggest_gdal_ot(bps, sample_format) + raise ValueError( + f"Mixed BitsPerSample per band is not supported: " + f"{tuple(bps)}. xarray-spatial decodes all bands with " + f"a single dtype. Convert the file to a uniform bit " + f"depth first, e.g. " + f"`gdal_translate -ot {ot} in.tif out.tif`." + ) + return int(first) + return int(bps) + + def numpy_to_tiff_dtype(dt: np.dtype) -> tuple[int, int]: """Convert a numpy dtype to (bits_per_sample, sample_format). diff --git a/xrspatial/geotiff/_geotags.py b/xrspatial/geotiff/_geotags.py index 998f5f14..401017e7 100644 --- a/xrspatial/geotiff/_geotags.py +++ b/xrspatial/geotiff/_geotags.py @@ -21,6 +21,7 @@ TAG_MODEL_TRANSFORMATION, TAG_GEO_KEY_DIRECTORY, TAG_GEO_DOUBLE_PARAMS, TAG_GEO_ASCII_PARAMS, ) +from ._dtypes import resolve_bits_per_sample # ImageDescription tag (270). Captured for round-trip but not managed # by the writer -- it flows through extra_tags pass-through. @@ -514,9 +515,7 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview, if ifd.photometric == 3: raw_cmap = ifd.colormap if raw_cmap is not None: - bps_val = ifd.bits_per_sample - if isinstance(bps_val, tuple): - bps_val = bps_val[0] + bps_val = resolve_bits_per_sample(ifd.bits_per_sample) n_colors = 1 << bps_val # 2^BitsPerSample # TIFF ColorMap: 3 * n_colors uint16 values # Layout: [R0..R_{n-1}, G0..G_{n-1}, B0..B_{n-1}] diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py index 11b1d3cd..a99f1e63 100644 --- a/xrspatial/geotiff/_reader.py +++ b/xrspatial/geotiff/_reader.py @@ -18,7 +18,7 @@ predictor_decode, unpack_bits, ) -from ._dtypes import SUB_BYTE_BPS, tiff_dtype_to_numpy +from ._dtypes import SUB_BYTE_BPS, resolve_bits_per_sample, tiff_dtype_to_numpy from ._geotags import GeoInfo, GeoTransform, extract_geo_info from ._header import ( IFD, @@ -663,9 +663,7 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader, offsets = ifd.strip_offsets byte_counts = ifd.strip_byte_counts pred = ifd.predictor - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) bytes_per_sample = bps // 8 is_sub_byte = bps in SUB_BYTE_BPS jpeg_tables = ifd.jpeg_tables @@ -802,9 +800,7 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader, samples = ifd.samples_per_pixel compression = ifd.compression pred = ifd.predictor - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) bytes_per_sample = bps // 8 is_sub_byte = bps in SUB_BYTE_BPS jpeg_tables = ifd.jpeg_tables @@ -994,9 +990,7 @@ def _read_cog_http(url: str, overview_level: int | None = None, # Select IFD based on overview level, skipping any mask IFDs ifd = select_overview_ifd(ifds, overview_level) - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) geo_info = extract_geo_info(ifd, header_bytes, header.byte_order) @@ -1211,9 +1205,7 @@ def read_to_array(source, *, window=None, overview_level: int | None = None, # Select IFD, skipping any mask IFDs ifd = select_overview_ifd(ifds, overview_level) - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) geo_info = extract_geo_info(ifd, data, header.byte_order) diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 4bf31062..a3fe4f31 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -394,6 +394,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *, from ._header import parse_header, parse_all_ifds from ._geotags import extract_geo_info from ._reader import _FileSource + from ._dtypes import resolve_bits_per_sample if not source_files: raise ValueError("source_files must not be empty") @@ -409,9 +410,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *, geo = extract_geo_info(ifd, data, header.byte_order) src.close() - bps = ifd.bits_per_sample - if isinstance(bps, tuple): - bps = bps[0] + bps = resolve_bits_per_sample(ifd.bits_per_sample) sources_meta.append({ 'path': src_path, diff --git a/xrspatial/geotiff/tests/test_mixed_bps.py b/xrspatial/geotiff/tests/test_mixed_bps.py new file mode 100644 index 00000000..529a03e8 --- /dev/null +++ b/xrspatial/geotiff/tests/test_mixed_bps.py @@ -0,0 +1,203 @@ +"""Tests for non-uniform BitsPerSample handling (issue #1505). + +A TIFF whose BitsPerSample tag carries different values per band +(e.g. ``(16, 16, 16, 8)`` for RGB plus an 8-bit alpha) cannot be +decoded into a single numpy dtype. xarray-spatial rejects such files +with a clear error so the user can convert them with GDAL/rasterio +instead of silently getting garbage in the mismatched bands. +""" +from __future__ import annotations + +import struct + +import numpy as np +import pytest + +from xrspatial.geotiff import open_geotiff +from xrspatial.geotiff._dtypes import resolve_bits_per_sample + + +def _build_multi_band_tiff( + width: int, + height: int, + samples: int, + bits_per_sample, + pixel_dtype: np.dtype = np.dtype('uint16'), +) -> bytes: + """Build a minimal stripped uncompressed multi-band TIFF. + + ``bits_per_sample`` is written as-is into tag 258 — pass a list/tuple + to exercise the per-band code path. + """ + bo = '<' + pixel_data = np.zeros((height, width, samples), dtype=pixel_dtype) + pixel_bytes = pixel_data.tobytes() + + tags: list[tuple[int, int, int, bytes]] = [] + + def add_short(tag, val): + tags.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + + def add_long(tag, val): + tags.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + + def add_shorts(tag, vals): + tags.append((tag, 3, len(vals), + struct.pack(f'{bo}{len(vals)}H', *vals))) + + add_short(256, width) # ImageWidth + add_short(257, height) # ImageLength + if isinstance(bits_per_sample, (list, tuple)): + add_shorts(258, list(bits_per_sample)) # BitsPerSample (per band) + else: + add_short(258, bits_per_sample) + add_short(259, 1) # Compression = none + add_short(262, 2 if samples >= 3 else 1) # PhotometricInterpretation + add_long(273, 0) # StripOffsets (patched) + add_short(277, samples) # SamplesPerPixel + add_short(278, height) # RowsPerStrip + add_long(279, len(pixel_bytes)) # StripByteCounts + add_short(284, 1) # PlanarConfiguration = chunky + add_shorts(339, [1] * samples) # SampleFormat = uint + + tags.sort(key=lambda t: t[0]) + + num_entries = len(tags) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + overflow_start = ifd_start + ifd_size + + overflow_buf = bytearray() + tag_overflow_offsets: dict[int, int | None] = {} + for tag, typ, count, raw in tags: + if len(raw) > 4: + tag_overflow_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_overflow_offsets[tag] = None + + pixel_start = overflow_start + len(overflow_buf) + + # Patch StripOffsets to point at the pixel block + patched = [] + for tag, typ, count, raw in tags: + if tag == 273: + patched.append((tag, 4, 1, struct.pack(f'{bo}I', pixel_start))) + else: + patched.append((tag, typ, count, raw)) + tags = patched + + # Re-serialize the IFD with final layout + out = bytearray() + out.extend(b'II') # little-endian + out.extend(struct.pack(f'{bo}H', 42)) # magic + out.extend(struct.pack(f'{bo}I', ifd_start)) # offset of first IFD + out.extend(struct.pack(f'{bo}H', num_entries)) + for tag, typ, count, raw in tags: + out.extend(struct.pack(f'{bo}H', tag)) + out.extend(struct.pack(f'{bo}H', typ)) + out.extend(struct.pack(f'{bo}I', count)) + if len(raw) <= 4: + payload = raw + b'\x00' * (4 - len(raw)) + out.extend(payload) + else: + out.extend(struct.pack(f'{bo}I', + overflow_start + tag_overflow_offsets[tag])) + out.extend(struct.pack(f'{bo}I', 0)) # next IFD = 0 + out.extend(overflow_buf) + out.extend(pixel_bytes) + return bytes(out) + + +class TestResolveBitsPerSample: + """Unit tests for the helper itself.""" + + def test_scalar(self): + assert resolve_bits_per_sample(16) == 16 + + def test_one_element_tuple(self): + assert resolve_bits_per_sample((8,)) == 8 + + def test_uniform_tuple(self): + assert resolve_bits_per_sample((16, 16, 16)) == 16 + + def test_uniform_list(self): + assert resolve_bits_per_sample([32, 32, 32, 32]) == 32 + + def test_mixed_tuple_raises(self): + with pytest.raises(ValueError, match=r"Mixed BitsPerSample"): + resolve_bits_per_sample((16, 16, 16, 8)) + + def test_error_message_contains_values(self): + with pytest.raises(ValueError) as exc: + resolve_bits_per_sample((16, 16, 16, 8)) + msg = str(exc.value) + assert "(16, 16, 16, 8)" in msg + assert "gdal_translate" in msg + + def test_error_message_ot_matches_widest_bps(self): + """gdal_translate hint should suggest a type wide enough for input.""" + # 32-bit + 8-bit alpha -> widest is 32, default sample format = uint + with pytest.raises(ValueError) as exc: + resolve_bits_per_sample((32, 32, 32, 8)) + assert "-ot UInt32" in str(exc.value) + + # Widest is 16-bit -> UInt16 (the original hard-coded suggestion). + with pytest.raises(ValueError) as exc: + resolve_bits_per_sample((16, 16, 16, 8)) + assert "-ot UInt16" in str(exc.value) + + def test_error_message_ot_uses_sample_format_hint(self): + """sample_format=3 (float) at 32-bit -> Float32 instead of UInt32.""" + with pytest.raises(ValueError) as exc: + resolve_bits_per_sample((32, 32, 32, 8), sample_format=3) + assert "-ot Float32" in str(exc.value) + + # sample_format=2 (int) at 16-bit -> Int16 instead of UInt16. + with pytest.raises(ValueError) as exc: + resolve_bits_per_sample((16, 16, 8), sample_format=2) + assert "-ot Int16" in str(exc.value) + + def test_empty_tuple_raises(self): + with pytest.raises(ValueError): + resolve_bits_per_sample(()) + + +class TestMixedBitsPerSampleTiff: + """End-to-end tests against open_geotiff.""" + + def test_uniform_bps_reads_fine(self, tmp_path): + path = tmp_path / "uniform_rgba.tif" + path.write_bytes( + _build_multi_band_tiff( + width=4, height=3, samples=4, + bits_per_sample=(16, 16, 16, 16), + pixel_dtype=np.dtype('uint16'), + ) + ) + da = open_geotiff(str(path)) + assert da.dtype == np.uint16 + # Multi-band TIFFs come back as (y, x, band) + assert da.sizes['y'] == 3 + assert da.sizes['x'] == 4 + assert da.sizes['band'] == 4 + + def test_mixed_bps_rgb_plus_8bit_alpha_rejected(self, tmp_path): + """RGB+8-bit-alpha is the canonical case from issue #1505.""" + path = tmp_path / "mixed_rgba.tif" + # NB: the pixel block here is uint16 throughout; the test only + # exercises the dispatch, not the (impossible) decode path. + path.write_bytes( + _build_multi_band_tiff( + width=4, height=3, samples=4, + bits_per_sample=(16, 16, 16, 8), + pixel_dtype=np.dtype('uint16'), + ) + ) + with pytest.raises(ValueError) as exc: + open_geotiff(str(path)) + msg = str(exc.value) + assert "(16, 16, 16, 8)" in msg + assert "Mixed BitsPerSample" in msg