diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py
index a9558b5a..b7080e84 100644
--- a/xrspatial/geotiff/_compression.py
+++ b/xrspatial/geotiff/_compression.py
@@ -7,12 +7,88 @@
 from xrspatial.utils import ngjit
 
 
+# -- Decompression-bomb defenses ---------------------------------------------
+#
+# A malicious TIFF can declare a small strip/tile compressed payload that
+# expands to multiple gigabytes when decoded. Without a cap the reader is
+# OOM-killed before the post-decode size check (in ``_decode_strip_or_tile``)
+# ever runs, since by then the bomb has already been allocated. Each codec
+# below takes an ``expected_size`` (the byte length the caller computed from
+# the IFD's declared dimensions) and refuses to produce more than
+# ``expected_size * _DECOMPRESS_MARGIN`` bytes. The margin allows for the
+# small amount of legitimate codec metadata that some encoders emit, while
+# still rejecting the 1000:1 ratios characteristic of bomb attacks.
+_DECOMPRESS_MARGIN = 1.05
+
+
+def _max_output_with_margin(expected_size: int) -> int:
+    """Return the cap (in bytes) for a codec given the caller's expected size.
+
+    Returns 0 (meaning "no cap") when ``expected_size`` is unknown (<= 0).
+    Otherwise adds one byte of slack beyond the margin so that a single
+    byte of overflow is detected rather than silently truncated.
+    """
+    if expected_size <= 0:
+        # No cap available: fall back to uncapped decoding to preserve
+        # backward compatibility with callers that don't supply a size.
+        # The reader always supplies a size, so this branch is mainly for
+        # direct callers and round-trip tests.
+        return 0
+    return int(expected_size * _DECOMPRESS_MARGIN) + 1
+
+
 # -- Deflate (zlib wrapper) --------------------------------------------------
 
 
-def deflate_decompress(data: bytes) -> bytes:
-    """Decompress deflate/zlib data."""
-    return zlib.decompress(data)
+def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes:
+    """Decompress deflate/zlib data with an optional output-size cap.
+
+    Parameters
+    ----------
+    data : bytes
+        Deflate/zlib compressed payload.
+    expected_size : int, optional
+        Caller's expected uncompressed byte count. When > 0, the decoder
+        refuses to produce more than ``expected_size * 1.05 + 1`` bytes
+        and raises ``ValueError`` on overflow (decompression-bomb guard).
+
+    Returns
+    -------
+    bytes
+        Uncompressed data.
+    """
+    if expected_size <= 0:
+        # Backward-compat path: caller hasn't supplied an expected size.
+        return zlib.decompress(data)
+    cap = _max_output_with_margin(expected_size)
+    decompressor = zlib.decompressobj()
+    # Read one byte beyond the cap so that an overflowing stream is detected
+    # rather than silently truncated.
+    out = decompressor.decompress(data, max_length=cap + 1)
+    # Drain any output the decompressor has buffered (including unconsumed
+    # input). We stop as soon as output exceeds the cap, or when the
+    # decompressor declares EOF, or when no further bytes are produced.
+    while not decompressor.eof and len(out) <= cap:
+        feed = decompressor.unconsumed_tail
+        more = decompressor.decompress(feed, max_length=cap + 1 - len(out))
+        if not more:
+            break
+        out += more
+    if len(out) > cap:
+        raise ValueError(
+            f"deflate decode exceeded expected size: {len(out)} bytes "
+            f"produced, cap is {cap} (expected {expected_size}). Likely "
+            f"a decompression bomb."
+        )
+    # Flush any remaining state to surface tail bytes / errors.
+    out += decompressor.flush()
+    if len(out) > cap:
+        raise ValueError(
+            f"deflate decode exceeded expected size: {len(out)} bytes "
+            f"produced, cap is {cap} (expected {expected_size}). Likely "
+            f"a decompression bomb."
+        )
+    return out
 
 
 def deflate_compress(data: bytes, level: int = 6) -> bytes:
@@ -794,15 +870,20 @@ def unpack_bits(data: np.ndarray, bps: int, pixel_count: int) -> np.ndarray:
 # -- PackBits (simple RLE) ----------------------------------------------------
 
 
-def packbits_decompress(data: bytes) -> bytes:
+def packbits_decompress(data: bytes, expected_size: int = 0) -> bytes:
     """Decompress PackBits (TIFF compression tag 32773).
 
     Simple RLE: read a header byte n.
     - 0 <= n <= 127: copy the next n+1 bytes literally.
     - -127 <= n <= -1: repeat the next byte 1-n times.
     - n == -128: no-op.
+
+    When ``expected_size`` > 0, the decoder refuses to emit more than
+    ``expected_size * 1.05 + 1`` bytes and raises ``ValueError`` on overflow
+    (decompression-bomb guard).
     """
     src = data if isinstance(data, (bytes, bytearray)) else bytes(data)
+    cap = _max_output_with_margin(expected_size)
     out = bytearray()
     i = 0
     length = len(src)
@@ -820,6 +901,12 @@
             out.extend(bytes([src[i]]) * (1 - n))
             i += 1
         # n == -128: skip
+        if cap and len(out) > cap:
+            raise ValueError(
+                f"packbits decode exceeded expected size: {len(out)} bytes "
+                f"produced, cap is {cap} (expected {expected_size}). "
+                f"Likely a decompression bomb."
+            )
     return bytes(out)
 
 
@@ -953,13 +1040,44 @@ def jpeg_compress(data: bytes, width: int, height: int,
 _zstd = None
 
 
-def zstd_decompress(data: bytes) -> bytes:
-    """Decompress Zstandard data. Requires the ``zstandard`` package."""
+def zstd_decompress(data: bytes, expected_size: int = 0) -> bytes:
+    """Decompress Zstandard data with an optional output-size cap.
+
+    Requires the ``zstandard`` package. When ``expected_size`` > 0 the
+    decoder refuses to emit more than ``expected_size * 1.05 + 1`` bytes
+    and raises ``ValueError`` on overflow (decompression-bomb guard).
+    """
     if not ZSTD_AVAILABLE:
         raise ImportError(
             "zstandard is required to read ZSTD-compressed TIFFs. "
             "Install it with: pip install zstandard")
-    return _zstd.ZstdDecompressor().decompress(data)
+    if expected_size <= 0:
+        return _zstd.ZstdDecompressor().decompress(data)
+    cap = _max_output_with_margin(expected_size)
+    # ``decompress(data, max_output_size=...)`` is not enforced when the
+    # frame embeds the content size, so use ``stream_reader`` and read at
+    # most ``cap + 1`` bytes. If the decoder still has data after that, the
+    # frame was bigger than the cap.
+    reader = _zstd.ZstdDecompressor().stream_reader(data)
+    try:
+        out = reader.read(cap + 1)
+        if len(out) > cap:
+            raise ValueError(
+                f"zstd decode exceeded expected size: produced more than "
+                f"{cap} bytes (expected {expected_size}). Likely a "
+                f"decompression bomb."
+            )
+        # Probe for additional bytes — a zero-length read confirms EOF.
+        extra = reader.read(1)
+        if extra:
+            raise ValueError(
+                f"zstd decode exceeded expected size: produced more than "
+                f"{cap} bytes (expected {expected_size}). Likely a "
+                f"decompression bomb."
+            )
+    finally:
+        reader.close()
+    return out
 
 
 def zstd_compress(data: bytes, level: int = 3) -> bytes:
@@ -1099,13 +1217,38 @@ def lerc_compress(data: bytes, width: int, height: int,
 _lz4 = None
 
 
-def lz4_decompress(data: bytes) -> bytes:
-    """Decompress LZ4 frame data. Requires the ``lz4`` package."""
+def lz4_decompress(data: bytes, expected_size: int = 0) -> bytes:
+    """Decompress LZ4 frame data with an optional output-size cap.
+
+    Requires the ``lz4`` package. When ``expected_size`` > 0 the decoder
+    refuses to emit more than ``expected_size * 1.05 + 1`` bytes and raises
+    ``ValueError`` on overflow (decompression-bomb guard).
+    """
     if not LZ4_AVAILABLE:
         raise ImportError(
             "lz4 is required to read LZ4-compressed TIFFs. "
            "Install it with: pip install lz4")
-    return _lz4.decompress(data)
+    if expected_size <= 0:
+        return _lz4.decompress(data)
+    cap = _max_output_with_margin(expected_size)
+    decompressor = _lz4.LZ4FrameDecompressor()
+    out = decompressor.decompress(data, max_length=cap + 1)
+    if len(out) > cap:
+        raise ValueError(
+            f"lz4 decode exceeded expected size: {len(out)} bytes produced, "
+            f"cap is {cap} (expected {expected_size}). Likely a "
+            f"decompression bomb."
+        )
+    # ``needs_input == False`` means the decoder has buffered output it
+    # couldn't deliver because ``max_length`` was reached. That implies the
+    # frame is larger than the cap, so the input is a bomb.
+    if not decompressor.needs_input:
+        raise ValueError(
+            f"lz4 decode exceeded expected size: produced more than {cap} "
+            f"bytes (expected {expected_size}). Likely a decompression "
+            f"bomb."
+        )
+    return out
 
 
 def lz4_compress(data: bytes, level: int = 0) -> bytes:
@@ -1155,23 +1298,27 @@ def decompress(data, compression: int, expected_size: int = 0,
     if compression == COMPRESSION_NONE:
         return np.frombuffer(data, dtype=np.uint8)
     elif compression in (COMPRESSION_DEFLATE, COMPRESSION_ADOBE_DEFLATE):
-        return np.frombuffer(deflate_decompress(data), dtype=np.uint8)
+        return np.frombuffer(
+            deflate_decompress(data, expected_size), dtype=np.uint8)
     elif compression == COMPRESSION_LZW:
         return lzw_decompress(data, expected_size)
     elif compression == COMPRESSION_PACKBITS:
-        return np.frombuffer(packbits_decompress(data), dtype=np.uint8)
+        return np.frombuffer(
+            packbits_decompress(data, expected_size), dtype=np.uint8)
     elif compression == COMPRESSION_JPEG:
         return np.frombuffer(
             jpeg_decompress(data, width, height, samples,
                             jpeg_tables=jpeg_tables), dtype=np.uint8)
     elif compression == COMPRESSION_ZSTD:
-        return np.frombuffer(zstd_decompress(data), dtype=np.uint8)
+        return np.frombuffer(
+            zstd_decompress(data, expected_size), dtype=np.uint8)
     elif compression == COMPRESSION_JPEG2000:
         return np.frombuffer(
             jpeg2000_decompress(data, width, height, samples), dtype=np.uint8)
     elif compression == COMPRESSION_LZ4:
-        return np.frombuffer(lz4_decompress(data), dtype=np.uint8)
+        return np.frombuffer(
+            lz4_decompress(data, expected_size), dtype=np.uint8)
     elif compression == COMPRESSION_LERC:
         return np.frombuffer(
             lerc_decompress(data, width, height, samples), dtype=np.uint8)
diff --git a/xrspatial/geotiff/tests/test_decompression_caps.py b/xrspatial/geotiff/tests/test_decompression_caps.py
new file mode 100644
index 00000000..8d08a377
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_decompression_caps.py
@@ -0,0 +1,271 @@
+"""Tests for decompression-bomb defenses (security finding S1).
+
+Each codec used by the TIFF reader (deflate, zstd, lz4, packbits) accepts an
+``expected_size`` argument and refuses to produce more than ~5% above that
+size before raising ``ValueError``. Without these caps a small malicious
+TIFF (a few MB compressed) could expand to many GB during decode and OOM
+the reader before the post-decode size check ran.
+
+Each test here builds a minimal TIFF whose strip payload, when decoded,
+would balloon to ~1 GiB while declaring image dimensions implying ~1 KiB
+of pixel data. The reader must raise cleanly rather than allocate the
+bomb.
+"""
+from __future__ import annotations
+
+import struct
+
+import numpy as np
+import pytest
+import zlib
+
+from xrspatial.geotiff._compression import (
+    deflate_decompress,
+    lz4_decompress,
+    packbits_decompress,
+    zstd_decompress,
+)
+from xrspatial.geotiff._reader import read_to_array
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _build_tiff_with_strip(strip_bytes: bytes, *, compression: int,
+                           width: int, height: int) -> bytes:
+    """Build a minimal little-endian uint8 TIFF with one strip of opaque bytes.
+
+    The strip is written verbatim (no further encoding); the caller has
+    already compressed it. ``width * height`` is the declared logical
+    image size — the reader uses that to compute the expected decompressed
+    size and apply the bomb cap.
+    """
+    bo = '<'
+
+    # Tags: (tag_id, type_id, count, value_or_offset_bytes_4)
+    # We keep every tag's value inline so the TIFF body order is:
+    #   header(8) | IFD | strip
+    tags = []
+
+    def add_short(tag, val):
+        tags.append((tag, 3, 1, struct.pack(f'{bo}H', val)))
+
+    def add_long(tag, val):
+        tags.append((tag, 4, 1, struct.pack(f'{bo}I', val)))
+
+    add_short(256, width)            # ImageWidth (assume <= 65535)
+    add_short(257, height)           # ImageLength
+    add_short(258, 8)                # BitsPerSample = 8 (uint8)
+    add_short(259, compression)     # Compression
+    add_short(262, 1)                # PhotometricInterpretation = BlackIsZero
+    add_short(277, 1)                # SamplesPerPixel
+    add_short(278, height)           # RowsPerStrip
+    add_long(273, 0)                 # StripOffsets (placeholder)
+    add_long(279, len(strip_bytes))  # StripByteCounts
+    add_short(339, 1)                # SampleFormat = unsigned int
+
+    tags.sort(key=lambda t: t[0])
+    num_entries = len(tags)
+    ifd_size = 2 + 12 * num_entries + 4  # count + entries + next-IFD
+    ifd_start = 8
+    strip_offset = ifd_start + ifd_size
+
+    # Patch StripOffsets (tag 273) with the real strip location.
+    patched = []
+    for tag_id, typ, count, raw in tags:
+        if tag_id == 273:
+            raw = struct.pack(f'{bo}I', strip_offset)
+        patched.append((tag_id, typ, count, raw))
+    tags = patched
+
+    out = bytearray()
+    out += b'II'                               # little-endian
+    out += struct.pack(f'{bo}H', 42)           # magic
+    out += struct.pack(f'{bo}I', ifd_start)    # offset to IFD0
+    out += struct.pack(f'{bo}H', num_entries)  # IFD entry count
+    for tag_id, typ, count, raw in tags:
+        out += struct.pack(f'{bo}HHI', tag_id, typ, count)
+        # Pad raw to 4 bytes (all our tags fit inline).
+        out += raw.ljust(4, b'\x00')
+    out += struct.pack(f'{bo}I', 0)            # no next IFD
+    out += strip_bytes
+    return bytes(out)
+
+
+# ---------------------------------------------------------------------------
+# Codec-level direct tests
+# ---------------------------------------------------------------------------
+
+class TestCodecDirect:
+    """Exercise the codec functions directly with bomb payloads."""
+
+    def test_deflate_bomb_raises(self):
+        # 100 MiB of zeros compresses to ~100 KiB; the cap is 1 KiB.
+        big = b'\x00' * (100 * 1024 * 1024)
+        comp = zlib.compress(big, 9)
+        with pytest.raises(ValueError, match="exceed"):
+            deflate_decompress(comp, expected_size=1024)
+
+    def test_deflate_legitimate_passes(self):
+        data = b'A' * 4096
+        comp = zlib.compress(data, 9)
+        out = deflate_decompress(comp, expected_size=len(data))
+        assert out == data
+
+    def test_packbits_bomb_raises(self):
+        # 0x81 0x00 = "repeat next byte 128 times". Repeated 1M times this
+        # decodes to 128 MiB of zeros from a 2 MiB input.
+        bomb = b'\x81\x00' * (1024 * 1024)
+        with pytest.raises(ValueError, match="exceed"):
+            packbits_decompress(bomb, expected_size=1024)
+
+    def test_packbits_legitimate_passes(self):
+        # Literal run: header byte n=3 means "copy next 4 bytes literally".
+        data = b'\x03ABCD'
+        out = packbits_decompress(data, expected_size=4)
+        assert out == b'ABCD'
+
+
+@pytest.mark.skipif(
+    pytest.importorskip("zstandard", reason="zstandard not installed") is None,
+    reason="zstandard not installed",
+)
+class TestZstdDirect:
+    def test_zstd_bomb_raises(self):
+        import zstandard
+        big = b'\x00' * (100 * 1024 * 1024)
+        comp = zstandard.ZstdCompressor().compress(big)
+        with pytest.raises(ValueError, match="exceed"):
+            zstd_decompress(comp, expected_size=1024)
+
+    def test_zstd_legitimate_passes(self):
+        import zstandard
+        data = b'A' * 4096
+        comp = zstandard.ZstdCompressor().compress(data)
+        out = zstd_decompress(comp, expected_size=len(data))
+        assert out == data
+
+
+@pytest.mark.skipif(
+    pytest.importorskip("lz4.frame", reason="lz4 not installed") is None,
+    reason="lz4 not installed",
+)
+class TestLz4Direct:
+    def test_lz4_bomb_raises(self):
+        import lz4.frame
+        big = b'\x00' * (100 * 1024 * 1024)
+        comp = lz4.frame.compress(big)
+        with pytest.raises(ValueError, match="exceed"):
+            lz4_decompress(comp, expected_size=1024)
+
+    def test_lz4_legitimate_passes(self):
+        import lz4.frame
+        data = b'A' * 4096
+        comp = lz4.frame.compress(data)
+        out = lz4_decompress(comp, expected_size=len(data))
+        assert out == data
+
+
+# ---------------------------------------------------------------------------
+# End-to-end TIFF tests (audit reproducer shape)
+# ---------------------------------------------------------------------------
+
+# 1024 x 1024 uint8 = 1 MiB declared image. We feed a strip whose decoded
+# size is 1 GiB, so the reader sees a 1024:1 ratio bomb. The strip header
+# we patch in claims the compressed length truthfully (a few KB) so the
+# raw I/O step succeeds and the codec is the layer that has to refuse.
+_DECLARED_W = 1024
+_DECLARED_H = 1024
+_DECLARED_BYTES = _DECLARED_W * _DECLARED_H  # 1 MiB
+_BOMB_BYTES = 1 << 30  # 1 GiB
+
+
+def test_deflate_bomb_rejected(tmp_path):
+    """1 MiB declared, 1 GiB decoded — must raise rather than OOM."""
+    payload = b'\x00' * _BOMB_BYTES
+    strip = zlib.compress(payload, 9)
+    # Sanity: the TIFF on disk should be small.
+    assert len(strip) < 5 * 1024 * 1024
+    tiff = _build_tiff_with_strip(strip, compression=8,
+                                  width=_DECLARED_W, height=_DECLARED_H)
+    path = tmp_path / "deflate_bomb.tif"
+    path.write_bytes(tiff)
+    with pytest.raises(ValueError, match="exceed"):
+        read_to_array(str(path))
+
+
+def test_zstd_bomb_rejected(tmp_path):
+    zstandard = pytest.importorskip("zstandard")
+    payload = b'\x00' * _BOMB_BYTES
+    strip = zstandard.ZstdCompressor().compress(payload)
+    assert len(strip) < 5 * 1024 * 1024
+    tiff = _build_tiff_with_strip(strip, compression=50000,
+                                  width=_DECLARED_W, height=_DECLARED_H)
+    path = tmp_path / "zstd_bomb.tif"
+    path.write_bytes(tiff)
+    with pytest.raises(ValueError, match="exceed"):
+        read_to_array(str(path))
+
+
+def test_lz4_bomb_rejected(tmp_path):
+    lz4_frame = pytest.importorskip("lz4.frame")
+    payload = b'\x00' * _BOMB_BYTES
+    strip = lz4_frame.compress(payload)
+    # LZ4 has a higher floor than deflate/zstd for runs of zeros, but should
+    # still be a fraction of a percent of the payload.
+    assert len(strip) < 32 * 1024 * 1024
+    tiff = _build_tiff_with_strip(strip, compression=50004,
+                                  width=_DECLARED_W, height=_DECLARED_H)
+    path = tmp_path / "lz4_bomb.tif"
+    path.write_bytes(tiff)
+    with pytest.raises(ValueError, match="exceed"):
+        read_to_array(str(path))
+
+
+def test_packbits_bomb_rejected(tmp_path):
+    # Packbits "repeat next byte 128 times" header is 0x81 0x00 (2 bytes).
+    # We declare 1024x1024=1 MiB image but supply a 2 MiB strip that
+    # decodes to 128 MiB. The cap should fire long before allocation.
+    strip = b'\x81\x00' * (1024 * 1024)
+    tiff = _build_tiff_with_strip(strip, compression=32773,
+                                  width=_DECLARED_W, height=_DECLARED_H)
+    path = tmp_path / "packbits_bomb.tif"
+    path.write_bytes(tiff)
+    with pytest.raises(ValueError, match="exceed"):
+        read_to_array(str(path))
+
+
+# ---------------------------------------------------------------------------
+# Negative tests: legitimate high-ratio compression must still pass
+# ---------------------------------------------------------------------------
+
+def test_legitimate_high_compression_passes(tmp_path):
+    """All-zero array compresses to a fraction of declared size — must pass."""
+    arr = np.zeros((_DECLARED_H, _DECLARED_W), dtype=np.uint8)
+    strip = zlib.compress(arr.tobytes(), 9)
+    # Confirm we actually have a high ratio (not a degenerate test).
+    assert len(strip) < _DECLARED_BYTES // 50
+    tiff = _build_tiff_with_strip(strip, compression=8,
+                                  width=_DECLARED_W, height=_DECLARED_H)
+    path = tmp_path / "legit.tif"
+    path.write_bytes(tiff)
+    out, _ = read_to_array(str(path))
+    assert out.shape == (_DECLARED_H, _DECLARED_W)
+    assert out.dtype == np.uint8
+    assert (out == 0).all()
+
+
+def test_cap_includes_metadata_margin():
+    """The cap allows ~5% of legitimate codec metadata above expected size.
+
+    Some encoders emit small framing or trailing bytes; the cap must not
+    reject them. We feed a payload exactly at expected_size + a few bytes
+    and confirm it decodes.
+    """
+    expected = 1000
+    # Decompressed size: expected + 30 (3% over). Within the 5% margin.
+    data = b'A' * (expected + 30)
+    comp = zlib.compress(data, 9)
+    out = deflate_decompress(comp, expected_size=expected)
+    assert out == data