Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 161 additions & 14 deletions xrspatial/geotiff/_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,88 @@

from xrspatial.utils import ngjit

# -- Decompression-bomb defenses ---------------------------------------------
#
# A malicious TIFF can declare a small strip/tile compressed payload that
# expands to multiple gigabytes when decoded. Without a cap the reader is
# OOM-killed before the post-decode size check (in ``_decode_strip_or_tile``)
# ever runs, since by then the bomb has already been allocated. Each codec
# below takes an ``expected_size`` (the byte length the caller computed from
# the IFD's declared dimensions) and refuses to produce more than
# ``expected_size * _DECOMPRESS_MARGIN`` bytes. The margin allows for the
# small amount of legitimate codec metadata that some encoders emit, while
# still rejecting the 1000:1 ratios characteristic of bomb attacks.
_DECOMPRESS_MARGIN = 1.05


def _max_output_with_margin(expected_size: int) -> int:
"""Return the cap (in bytes) for a codec given the caller's expected size.

Adds at least one byte of slack so that callers passing 0 (meaning
"unknown") still get a usable buffer for tests, while a single byte of
overflow is detected.
"""
if expected_size <= 0:
# No cap available: fall back to a generous default to preserve
# backward compatibility with callers that don't supply a size.
# The reader always supplies a size, so this branch is mainly for
# direct callers and round-trip tests.
return 0
return int(expected_size * _DECOMPRESS_MARGIN) + 1


# -- Deflate (zlib wrapper) --------------------------------------------------


def deflate_decompress(data: bytes) -> bytes:
"""Decompress deflate/zlib data."""
return zlib.decompress(data)
def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes:
    """Decompress deflate/zlib data with an optional output-size cap.

    Parameters
    ----------
    data : bytes
        Deflate/zlib compressed payload.
    expected_size : int, optional
        Caller's expected uncompressed byte count. When > 0, the decoder
        refuses to produce more than ``expected_size * 1.05 + 1`` bytes
        and raises ``ValueError`` on overflow (decompression-bomb guard).

    Returns
    -------
    bytes
        Uncompressed data.

    Raises
    ------
    ValueError
        If the output exceeds the cap (likely a decompression bomb).
    zlib.error
        If the stream is corrupt or truncated — matching the behavior of
        the uncapped ``zlib.decompress`` path.
    """
    if expected_size <= 0:
        # Backward-compat path: caller hasn't supplied an expected size.
        return zlib.decompress(data)
    cap = _max_output_with_margin(expected_size)

    def _bomb(produced: int) -> ValueError:
        # Single source for the overflow message used at both check sites.
        return ValueError(
            f"deflate decode exceeded expected size: {produced} bytes "
            f"produced, cap is {cap} (expected {expected_size}). Likely "
            f"a decompression bomb."
        )

    decompressor = zlib.decompressobj()
    # Ask for one byte beyond the cap so an overflowing stream is detected
    # rather than silently truncated.
    out = decompressor.decompress(data, max_length=cap + 1)
    # Drain buffered output / unconsumed input until EOF, overflow, or the
    # decompressor stops producing bytes.
    while not decompressor.eof and len(out) <= cap:
        more = decompressor.decompress(
            decompressor.unconsumed_tail, max_length=cap + 1 - len(out))
        if not more:
            break
        out += more
    if len(out) > cap:
        raise _bomb(len(out))
    # Flush any remaining state to surface tail bytes. This can also set
    # ``eof`` if the stream end was pending in the decompressor's internal
    # buffers, so the EOF check below must come after it.
    out += decompressor.flush()
    if len(out) > cap:
        raise _bomb(len(out))
    if not decompressor.eof:
        # The uncapped path (``zlib.decompress``) raises ``zlib.error`` for
        # a truncated stream; keep the capped path consistent instead of
        # silently returning partial data.
        raise zlib.error(
            "incomplete or truncated deflate stream")
    return out


def deflate_compress(data: bytes, level: int = 6) -> bytes:
Expand Down Expand Up @@ -794,15 +870,20 @@ def unpack_bits(data: np.ndarray, bps: int, pixel_count: int) -> np.ndarray:

# -- PackBits (simple RLE) ----------------------------------------------------

def packbits_decompress(data: bytes) -> bytes:
def packbits_decompress(data: bytes, expected_size: int = 0) -> bytes:
"""Decompress PackBits (TIFF compression tag 32773).

Simple RLE: read a header byte n.
- 0 <= n <= 127: copy the next n+1 bytes literally.
- -127 <= n <= -1: repeat the next byte 1-n times.
- n == -128: no-op.

When ``expected_size`` > 0, the decoder refuses to emit more than
``expected_size * 1.05 + 1`` bytes and raises ``ValueError`` on overflow
(decompression-bomb guard).
"""
src = data if isinstance(data, (bytes, bytearray)) else bytes(data)
cap = _max_output_with_margin(expected_size)
out = bytearray()
i = 0
length = len(src)
Expand All @@ -820,6 +901,12 @@ def packbits_decompress(data: bytes) -> bytes:
out.extend(bytes([src[i]]) * (1 - n))
i += 1
# n == -128: skip
if cap and len(out) > cap:
raise ValueError(
f"packbits decode exceeded expected size: {len(out)} bytes "
f"produced, cap is {cap} (expected {expected_size}). "
f"Likely a decompression bomb."
)
return bytes(out)


Expand Down Expand Up @@ -953,13 +1040,44 @@ def jpeg_compress(data: bytes, width: int, height: int,
_zstd = None


def zstd_decompress(data: bytes) -> bytes:
"""Decompress Zstandard data. Requires the ``zstandard`` package."""
def zstd_decompress(data: bytes, expected_size: int = 0) -> bytes:
    """Decompress Zstandard data, optionally capping the output size.

    Requires the ``zstandard`` package.

    Parameters
    ----------
    data : bytes
        Zstandard-compressed payload.
    expected_size : int, optional
        Caller's expected uncompressed byte count. When > 0 the decoder
        refuses to emit more than ``expected_size * 1.05 + 1`` bytes and
        raises ``ValueError`` on overflow (decompression-bomb guard).

    Returns
    -------
    bytes
        Uncompressed data.
    """
    if not ZSTD_AVAILABLE:
        raise ImportError(
            "zstandard is required to read ZSTD-compressed TIFFs. "
            "Install it with: pip install zstandard")
    if expected_size <= 0:
        # No cap requested: decode the frame in one shot.
        return _zstd.ZstdDecompressor().decompress(data)

    cap = _max_output_with_margin(expected_size)
    overflow = (
        f"zstd decode exceeded expected size: produced more than "
        f"{cap} bytes (expected {expected_size}). Likely a "
        f"decompression bomb."
    )
    # ``decompress(data, max_output_size=...)`` is not enforced when the
    # frame embeds the content size, so stream the output instead: read at
    # most ``cap + 1`` bytes, then probe once more for EOF.
    stream = _zstd.ZstdDecompressor().stream_reader(data)
    try:
        payload = stream.read(cap + 1)
        if len(payload) > cap:
            raise ValueError(overflow)
        # A non-empty probe read means the frame holds bytes beyond the
        # cap; a zero-length read confirms EOF.
        if stream.read(1):
            raise ValueError(overflow)
    finally:
        stream.close()
    return payload


def zstd_compress(data: bytes, level: int = 3) -> bytes:
Expand Down Expand Up @@ -1099,13 +1217,38 @@ def lerc_compress(data: bytes, width: int, height: int,
_lz4 = None


def lz4_decompress(data: bytes) -> bytes:
"""Decompress LZ4 frame data. Requires the ``lz4`` package."""
def lz4_decompress(data: bytes, expected_size: int = 0) -> bytes:
    """Decompress LZ4 frame data, optionally capping the output size.

    Requires the ``lz4`` package.

    Parameters
    ----------
    data : bytes
        LZ4 frame-format payload.
    expected_size : int, optional
        Caller's expected uncompressed byte count. When > 0 the decoder
        refuses to emit more than ``expected_size * 1.05 + 1`` bytes and
        raises ``ValueError`` on overflow (decompression-bomb guard).

    Returns
    -------
    bytes
        Uncompressed data.
    """
    if not LZ4_AVAILABLE:
        raise ImportError(
            "lz4 is required to read LZ4-compressed TIFFs. "
            "Install it with: pip install lz4")
    if expected_size <= 0:
        # No cap supplied — decode the frame in one shot.
        return _lz4.decompress(data)

    cap = _max_output_with_margin(expected_size)
    frame = _lz4.LZ4FrameDecompressor()
    # Ask for one byte beyond the cap so an oversized frame is observed
    # rather than silently truncated.
    produced = frame.decompress(data, max_length=cap + 1)
    if len(produced) > cap:
        raise ValueError(
            f"lz4 decode exceeded expected size: {len(produced)} bytes produced, "
            f"cap is {cap} (expected {expected_size}). Likely a "
            f"decompression bomb."
        )
    # ``needs_input == False`` means the decoder is holding back output it
    # couldn't deliver within ``max_length`` — i.e. the frame is larger
    # than the cap, so the input is a bomb.
    if not frame.needs_input:
        raise ValueError(
            f"lz4 decode exceeded expected size: produced more than {cap} "
            f"bytes (expected {expected_size}). Likely a decompression "
            f"bomb."
        )
    return produced


def lz4_compress(data: bytes, level: int = 0) -> bytes:
Expand Down Expand Up @@ -1155,23 +1298,27 @@ def decompress(data, compression: int, expected_size: int = 0,
if compression == COMPRESSION_NONE:
return np.frombuffer(data, dtype=np.uint8)
elif compression in (COMPRESSION_DEFLATE, COMPRESSION_ADOBE_DEFLATE):
return np.frombuffer(deflate_decompress(data), dtype=np.uint8)
return np.frombuffer(
deflate_decompress(data, expected_size), dtype=np.uint8)
elif compression == COMPRESSION_LZW:
return lzw_decompress(data, expected_size)
elif compression == COMPRESSION_PACKBITS:
return np.frombuffer(packbits_decompress(data), dtype=np.uint8)
return np.frombuffer(
packbits_decompress(data, expected_size), dtype=np.uint8)
elif compression == COMPRESSION_JPEG:
return np.frombuffer(
jpeg_decompress(data, width, height, samples,
jpeg_tables=jpeg_tables),
dtype=np.uint8)
elif compression == COMPRESSION_ZSTD:
return np.frombuffer(zstd_decompress(data), dtype=np.uint8)
return np.frombuffer(
zstd_decompress(data, expected_size), dtype=np.uint8)
elif compression == COMPRESSION_JPEG2000:
return np.frombuffer(
jpeg2000_decompress(data, width, height, samples), dtype=np.uint8)
elif compression == COMPRESSION_LZ4:
return np.frombuffer(lz4_decompress(data), dtype=np.uint8)
return np.frombuffer(
lz4_decompress(data, expected_size), dtype=np.uint8)
elif compression == COMPRESSION_LERC:
return np.frombuffer(
lerc_decompress(data, width, height, samples), dtype=np.uint8)
Expand Down
Loading
Loading