Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _read_geo_info(source, *, overview_level: int | None = None):
overview_level : int or None
Overview IFD index (0 = full resolution).
"""
from ._dtypes import tiff_dtype_to_numpy
from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import extract_geo_info
from ._header import parse_all_ifds, parse_header
from ._reader import _coerce_path, _is_file_like
Expand Down Expand Up @@ -231,9 +231,7 @@ def _read_geo_info(source, *, overview_level: int | None = None):
ifd_idx = min(overview_level, len(ifds) - 1)
ifd = ifds[ifd_idx]
geo_info = extract_geo_info(ifd, data, header.byte_order)
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
n_bands = ifd.samples_per_pixel if ifd.samples_per_pixel > 1 else 0
return geo_info, ifd.height, ifd.width, file_dtype, n_bands
Expand Down Expand Up @@ -1445,7 +1443,7 @@ def read_geotiff_gpu(source: str, *,
_FileSource, _check_dimensions, MAX_PIXELS_DEFAULT, _coerce_path,
)
from ._header import parse_header, parse_all_ifds, validate_tile_layout
from ._dtypes import tiff_dtype_to_numpy
from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import extract_geo_info
from ._gpu_decode import gpu_decode_tiles

Expand All @@ -1470,9 +1468,7 @@ def read_geotiff_gpu(source: str, *,
ifd_idx = min(overview_level, len(ifds) - 1)
ifd = ifds[ifd_idx]

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, data, header.byte_order)

Expand Down
41 changes: 41 additions & 0 deletions xrspatial/geotiff/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,47 @@ def tiff_dtype_to_numpy(bits_per_sample: int, sample_format: int = 1) -> np.dtyp
SUB_BYTE_BPS = {1, 2, 4, 12}


def resolve_bits_per_sample(bps) -> int:
    """Collapse a TIFF ``BitsPerSample`` tag value into one bit depth.

    ``BitsPerSample`` may arrive as a bare scalar or as a sequence holding
    one width per sample. Because xarray-spatial decodes every band of an
    IFD with a single numpy dtype, a sequence is only accepted when all of
    its entries are identical.

    Parameters
    ----------
    bps : int or tuple of int
        Raw value from ``IFD.bits_per_sample``. Lists are treated the
        same way as tuples.

    Returns
    -------
    int
        The bit depth shared by every sample.

    Raises
    ------
    ValueError
        If ``bps`` is an empty tuple/list, or if its entries differ
        (e.g. ``(16, 16, 16, 8)`` for RGB plus an 8-bit alpha). Such files
        are not supported; convert them to a uniform depth first, e.g.
        ``gdal_translate -ot UInt16 in.tif out.tif``.
    """
    # Scalar tag value: nothing to reconcile.
    if not isinstance(bps, (tuple, list)):
        return int(bps)

    if not bps:
        raise ValueError("BitsPerSample tuple is empty")

    # Every remaining width has to agree with the first one.
    shared, *others = bps
    if any(width != shared for width in others):
        raise ValueError(
            f"Mixed BitsPerSample per band is not supported: {tuple(bps)}. "
            "xarray-spatial decodes all bands with a single dtype. "
            "Convert the file to a uniform bit depth first, "
            "e.g. `gdal_translate -ot UInt16 in.tif out.tif`."
        )
    return int(shared)


def numpy_to_tiff_dtype(dt: np.dtype) -> tuple[int, int]:
"""Convert a numpy dtype to (bits_per_sample, sample_format).

Expand Down
5 changes: 2 additions & 3 deletions xrspatial/geotiff/_geotags.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,8 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
if ifd.photometric == 3:
raw_cmap = ifd.colormap
if raw_cmap is not None:
bps_val = ifd.bits_per_sample
if isinstance(bps_val, tuple):
bps_val = bps_val[0]
from ._dtypes import resolve_bits_per_sample
bps_val = resolve_bits_per_sample(ifd.bits_per_sample)
n_colors = 1 << bps_val # 2^BitsPerSample
# TIFF ColorMap: 3 * n_colors uint16 values
# Layout: [R0..R_{n-1}, G0..G_{n-1}, B0..B_{n-1}]
Expand Down
18 changes: 5 additions & 13 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
predictor_decode,
unpack_bits,
)
from ._dtypes import SUB_BYTE_BPS, tiff_dtype_to_numpy
from ._dtypes import SUB_BYTE_BPS, resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import GeoInfo, GeoTransform, extract_geo_info
from ._header import IFD, TIFFHeader, parse_all_ifds, parse_header, validate_tile_layout

Expand Down Expand Up @@ -649,9 +649,7 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader,
offsets = ifd.strip_offsets
byte_counts = ifd.strip_byte_counts
pred = ifd.predictor
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS

Expand Down Expand Up @@ -785,9 +783,7 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader,
samples = ifd.samples_per_pixel
compression = ifd.compression
pred = ifd.predictor
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS

Expand Down Expand Up @@ -978,9 +974,7 @@ def _read_cog_http(url: str, overview_level: int | None = None,
ifd_idx = min(overview_level, len(ifds) - 1)
ifd = ifds[ifd_idx]

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, header_bytes, header.byte_order)

Expand Down Expand Up @@ -1137,9 +1131,7 @@ def read_to_array(source, *, window=None, overview_level: int | None = None,
ifd_idx = min(overview_level, len(ifds) - 1)
ifd = ifds[ifd_idx]

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, data, header.byte_order)

Expand Down
5 changes: 2 additions & 3 deletions xrspatial/geotiff/_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
from ._header import parse_header, parse_all_ifds
from ._geotags import extract_geo_info
from ._reader import _FileSource
from ._dtypes import resolve_bits_per_sample

if not source_files:
raise ValueError("source_files must not be empty")
Expand All @@ -409,9 +410,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
geo = extract_geo_info(ifd, data, header.byte_order)
src.close()

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)

sources_meta.append({
'path': src_path,
Expand Down
180 changes: 180 additions & 0 deletions xrspatial/geotiff/tests/test_mixed_bps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""Tests for non-uniform BitsPerSample handling (issue #1505).

A TIFF whose BitsPerSample tag carries different values per band
(e.g. ``(16, 16, 16, 8)`` for RGB plus an 8-bit alpha) cannot be
decoded into a single numpy dtype. xarray-spatial rejects such files
with a clear error so the user can convert them with GDAL/rasterio
instead of silently getting garbage in the mismatched bands.
"""
from __future__ import annotations

import struct

import numpy as np
import pytest

from xrspatial.geotiff import open_geotiff
from xrspatial.geotiff._dtypes import resolve_bits_per_sample


def _build_multi_band_tiff(
    width: int,
    height: int,
    samples: int,
    bits_per_sample,
    pixel_dtype: np.dtype = np.dtype('uint16'),
) -> bytes:
    """Assemble a tiny little-endian, uncompressed, single-strip TIFF.

    The file is laid out as: 8-byte header, one IFD at offset 8, an
    overflow area for tag values wider than 4 bytes, then an all-zero
    pixel block of shape ``(height, width, samples)``.

    ``bits_per_sample`` goes into tag 258 verbatim. A list/tuple is
    written with one SHORT per entry (the per-band form); anything else
    is written as a single SHORT.
    """
    endian = '<'
    short_type, long_type = 3, 4

    raster = np.zeros((height, width, samples), dtype=pixel_dtype).tobytes()

    def shorts(tag, values):
        # One SHORT per value; more than two values end up in overflow.
        return (tag, short_type, len(values),
                struct.pack(f'{endian}{len(values)}H', *values))

    def short(tag, value):
        return (tag, short_type, 1, struct.pack(f'{endian}H', value))

    def long_(tag, value):
        return (tag, long_type, 1, struct.pack(f'{endian}I', value))

    if isinstance(bits_per_sample, (list, tuple)):
        bps_entry = shorts(258, list(bits_per_sample))  # per-band form
    else:
        bps_entry = short(258, bits_per_sample)

    entries = sorted(
        [
            short(256, width),                    # ImageWidth
            short(257, height),                   # ImageLength
            bps_entry,                            # BitsPerSample
            short(259, 1),                        # Compression = none
            short(262, 2 if samples >= 3 else 1),  # PhotometricInterpretation
            long_(273, 0),                        # StripOffsets (filled in below)
            short(277, samples),                  # SamplesPerPixel
            short(278, height),                   # RowsPerStrip
            long_(279, len(raster)),              # StripByteCounts
            short(284, 1),                        # PlanarConfiguration = chunky
            shorts(339, [1] * samples),           # SampleFormat = uint
        ],
        key=lambda entry: entry[0],
    )

    directory_offset = 8
    # Entry count (2) + 12 bytes per entry + next-IFD pointer (4).
    extra_offset = directory_offset + 2 + 12 * len(entries) + 4

    # Values that do not fit the 4-byte inline slot are moved behind the
    # IFD; remember the absolute file offset each of them lands at.
    extra = bytearray()
    relocated = {}
    for tag, _typ, _count, payload in entries:
        if len(payload) <= 4:
            continue
        relocated[tag] = extra_offset + len(extra)
        extra += payload
        if len(extra) % 2:
            extra.append(0)  # keep the next value word-aligned

    raster_offset = extra_offset + len(extra)

    chunks = [
        b'II',                                      # little-endian
        struct.pack(f'{endian}H', 42),              # magic
        struct.pack(f'{endian}I', directory_offset),  # offset of first IFD
        struct.pack(f'{endian}H', len(entries)),
    ]
    for tag, typ, count, payload in entries:
        if tag == 273:
            # Only now is the position of the pixel block known.
            typ, count = long_type, 1
            payload = struct.pack(f'{endian}I', raster_offset)
        chunks.append(struct.pack(f'{endian}HHI', tag, typ, count))
        if tag in relocated:
            chunks.append(struct.pack(f'{endian}I', relocated[tag]))
        else:
            chunks.append(payload.ljust(4, b'\x00'))
    chunks.append(struct.pack(f'{endian}I', 0))  # next IFD = 0
    chunks.append(bytes(extra))
    chunks.append(raster)
    return b''.join(chunks)


class TestResolveBitsPerSample:
    """Direct checks of ``resolve_bits_per_sample`` with in-memory values."""

    def test_scalar(self):
        # A bare integer is returned unchanged.
        resolved = resolve_bits_per_sample(16)
        assert resolved == 16

    def test_one_element_tuple(self):
        # A single-entry tuple has nothing to disagree with.
        resolved = resolve_bits_per_sample((8,))
        assert resolved == 8

    def test_uniform_tuple(self):
        widths = (16, 16, 16)
        assert resolve_bits_per_sample(widths) == 16

    def test_uniform_list(self):
        # Lists are accepted alongside tuples.
        widths = [32, 32, 32, 32]
        assert resolve_bits_per_sample(widths) == 32

    def test_mixed_tuple_raises(self):
        widths = (16, 16, 16, 8)
        with pytest.raises(ValueError, match=r"Mixed BitsPerSample"):
            resolve_bits_per_sample(widths)

    def test_error_message_contains_values(self):
        # The message must echo the offending widths and suggest a fix.
        with pytest.raises(ValueError) as excinfo:
            resolve_bits_per_sample((16, 16, 16, 8))
        text = str(excinfo.value)
        for fragment in ("(16, 16, 16, 8)", "gdal_translate"):
            assert fragment in text

    def test_empty_tuple_raises(self):
        empty = ()
        with pytest.raises(ValueError):
            resolve_bits_per_sample(empty)


class TestMixedBitsPerSampleTiff:
    """Checks that write a synthetic TIFF to disk and go through open_geotiff."""

    def test_uniform_bps_reads_fine(self, tmp_path):
        """Four bands that all declare 16 bits must load normally."""
        tiff_bytes = _build_multi_band_tiff(
            width=4, height=3, samples=4,
            bits_per_sample=(16, 16, 16, 16),
            pixel_dtype=np.dtype('uint16'),
        )
        target = tmp_path / "uniform_rgba.tif"
        target.write_bytes(tiff_bytes)

        loaded = open_geotiff(str(target))

        assert loaded.dtype == np.uint16
        # Multi-band TIFFs come back as (y, x, band)
        for dim, expected in (('y', 3), ('x', 4), ('band', 4)):
            assert loaded.sizes[dim] == expected

    def test_mixed_bps_rgb_plus_8bit_alpha_rejected(self, tmp_path):
        """RGB+8-bit-alpha is the canonical case from issue #1505."""
        # NB: every band in the pixel block is stored as uint16; only the
        # BitsPerSample tag is mixed. The test exercises the rejection
        # path, not the (impossible) decode path.
        tiff_bytes = _build_multi_band_tiff(
            width=4, height=3, samples=4,
            bits_per_sample=(16, 16, 16, 8),
            pixel_dtype=np.dtype('uint16'),
        )
        target = tmp_path / "mixed_rgba.tif"
        target.write_bytes(tiff_bytes)

        with pytest.raises(ValueError) as excinfo:
            open_geotiff(str(target))

        text = str(excinfo.value)
        for fragment in ("(16, 16, 16, 8)", "Mixed BitsPerSample"):
            assert fragment in text
Loading