Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Bug Fixes

- mark malformed SafeTensors scans inconclusive instead of clean
- preserve picklescan stack state across reused scanner runs
- mark partial streaming scans inconclusive when large-file streaming coverage is incomplete
- harden native code detection in model scanners ([#897](https://github.com/promptfoo/modelaudit/issues/897)) ([f4f661a](https://github.com/promptfoo/modelaudit/commit/f4f661a09be0032e15aa8895864413e3878233f8))
Expand Down
180 changes: 147 additions & 33 deletions modelaudit/scanners/safetensors_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from modelaudit.detectors.suspicious_symbols import SUSPICIOUS_METADATA_PATTERNS

from .base import BaseScanner, IssueSeverity, ScanResult
from .base import INCONCLUSIVE_SCAN_OUTCOME, BaseScanner, IssueSeverity, ScanResult

# Map SafeTensors dtypes to byte sizes for integrity checking
_DTYPE_SIZES = {
Expand All @@ -31,6 +31,9 @@
"U64": 8,
}
# Header size threshold: 16 MiB (16 * 1024 * 1024 bytes).
# NOTE(review): the comparison that enforces this limit is not in the visible hunks - confirm usage.
MAX_HEADER_BYTES = 16 * 1024 * 1024
# Reason codes handed to SafeTensorsScanner._mark_inconclusive(), which records them in
# result.metadata["scan_outcome_reasons"].
# Header framing problems: short length prefix, invalid header length, short header read,
# or a header that fails UTF-8/JSON decoding.
SAFETENSORS_HEADER_INCONCLUSIVE_REASON = "safetensors_header_validation_failed"
# Per-tensor or layout problems: bad entry type, offsets, dtype, shape, size computation,
# or offset continuity/coverage failures.
SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON = "safetensors_structure_validation_failed"
# Used on the branch that stops early because of the header size limit.
SAFETENSORS_HEADER_LIMIT_INCONCLUSIVE_REASON = "safetensors_header_size_limit_exceeded"


class SafeTensorsScanner(BaseScanner):
Expand All @@ -40,6 +43,26 @@ class SafeTensorsScanner(BaseScanner):
description = "Scans SafeTensors model files for integrity issues"
supported_extensions: ClassVar[list[str]] = [".safetensors"]

@staticmethod
def _mark_inconclusive(result: ScanResult, reason: str) -> None:
    """Flag ``result`` as an inconclusive scan and record why.

    Sets ``analysis_incomplete`` and ``scan_outcome`` in ``result.metadata`` and
    appends ``reason`` to the ``scan_outcome_reasons`` list, skipping duplicates.
    A missing or non-list ``scan_outcome_reasons`` value is replaced by a new list.
    """
    meta = result.metadata
    meta["analysis_incomplete"] = True
    meta["scan_outcome"] = INCONCLUSIVE_SCAN_OUTCOME

    recorded = meta.get("scan_outcome_reasons")
    if not isinstance(recorded, list):
        # Start a fresh list and store it so the append below is visible in metadata.
        recorded = meta["scan_outcome_reasons"] = []
    if reason not in recorded:
        recorded.append(reason)

@staticmethod
def _is_valid_shape(shape: Any) -> bool:
"""Return True when shape is a safetensors-compatible list of non-negative ints."""
return isinstance(shape, list) and all(
isinstance(dim, int) and not isinstance(dim, bool) and dim >= 0 for dim in shape
)

@classmethod
def can_handle(cls, path: str) -> bool:
"""Check if this scanner can handle the given path."""
Expand Down Expand Up @@ -70,6 +93,7 @@ def scan(self, path: str) -> ScanResult:
result = self._create_result()
file_size = self.get_file_size(path)
result.metadata["file_size"] = file_size
structural_validation_failed = False

# Add file integrity check for compliance
self.add_file_integrity_check(path, result)
Expand All @@ -87,6 +111,7 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"bytes_read": len(header_len_bytes), "required": 8},
)
self._mark_inconclusive(result, SAFETENSORS_HEADER_INCONCLUSIVE_REASON)
result.finish(success=False)
return result

Expand All @@ -101,6 +126,7 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"header_len": header_len, "max_allowed": file_size - 8},
)
self._mark_inconclusive(result, SAFETENSORS_HEADER_INCONCLUSIVE_REASON)
result.finish(success=False)
return result
else:
Expand Down Expand Up @@ -128,6 +154,7 @@ def scan(self, path: str) -> ScanResult:
),
)
result.metadata["analysis_incomplete"] = True
self._mark_inconclusive(result, SAFETENSORS_HEADER_LIMIT_INCONCLUSIVE_REASON)
result.bytes_scanned = file_size
result.finish(success=True)
return result
Expand All @@ -150,6 +177,7 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"bytes_read": len(header_bytes), "expected": header_len},
)
self._mark_inconclusive(result, SAFETENSORS_HEADER_INCONCLUSIVE_REASON)
result.finish(success=False)
return result

Expand All @@ -161,6 +189,7 @@ def scan(self, path: str) -> ScanResult:
severity=IssueSeverity.INFO,
location=path,
)
self._mark_inconclusive(result, SAFETENSORS_HEADER_INCONCLUSIVE_REASON)
result.finish(success=False)
return result
else:
Expand All @@ -173,7 +202,7 @@ def scan(self, path: str) -> ScanResult:

try:
header = json.loads(header_bytes.decode("utf-8"))
except json.JSONDecodeError as e:
except (UnicodeDecodeError, json.JSONDecodeError) as e:
result.add_check(
name="SafeTensors JSON Parse",
passed=False,
Expand All @@ -183,9 +212,9 @@ def scan(self, path: str) -> ScanResult:
details={"exception": str(e), "exception_type": type(e).__name__},
why="SafeTensors header contained invalid JSON.",
)
self._mark_inconclusive(result, SAFETENSORS_HEADER_INCONCLUSIVE_REASON)
result.finish(success=False)
return result

tensor_names = [k for k in header if k != "__metadata__"]
result.metadata["tensor_count"] = len(tensor_names)
result.metadata["tensors"] = tensor_names
Expand Down Expand Up @@ -213,13 +242,40 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"tensor": name, "actual_type": type(info).__name__, "expected_type": "dict"},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

begin, end = info.get("data_offsets", [0, 0])
raw_offsets = info.get("data_offsets")
dtype = info.get("dtype")
shape = info.get("shape", [])

if not isinstance(begin, int) or not isinstance(end, int):
if not isinstance(raw_offsets, list) or len(raw_offsets) != 2:
result.add_check(
name="Tensor Offset Structure Validation",
passed=False,
message=f"Invalid data_offsets structure for {name}",
severity=IssueSeverity.INFO,
location=path,
details={
"tensor": name,
"actual_type": type(raw_offsets).__name__,
"expected_type": "list",
"expected_length": 2,
},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

begin, end = raw_offsets

if (
not isinstance(begin, int)
or isinstance(begin, bool)
or not isinstance(end, int)
or isinstance(end, bool)
):
result.add_check(
name="Tensor Offset Type Validation",
passed=False,
Expand All @@ -232,6 +288,8 @@ def scan(self, path: str) -> ScanResult:
"end_type": type(end).__name__,
},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

if begin < 0 or end <= begin or end > data_size:
Expand All @@ -256,32 +314,82 @@ def scan(self, path: str) -> ScanResult:
offsets.append((begin, end))

# Validate dtype/shape size
if not isinstance(dtype, str) or dtype not in _DTYPE_SIZES:
result.add_check(
name="Tensor Dtype Validation",
passed=False,
message=f"Invalid dtype for tensor {name}",
severity=IssueSeverity.INFO,
location=path,
details={
"tensor": name,
"dtype": dtype,
"actual_type": type(dtype).__name__,
},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

if not self._is_valid_shape(shape):
result.add_check(
name="Tensor Shape Validation",
passed=False,
message=f"Invalid shape for tensor {name}",
severity=IssueSeverity.INFO,
location=path,
details={
"tensor": name,
"shape": shape,
"actual_type": type(shape).__name__,
},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

expected_size = self._expected_size(dtype, shape)
if expected_size is not None:
if expected_size != end - begin:
result.add_check(
name="Tensor Size Consistency Check",
passed=False,
message=f"Size mismatch for tensor {name}",
severity=IssueSeverity.CRITICAL,
location=path,
details={
"tensor": name,
"expected_size": expected_size,
"actual_size": end - begin,
},
)
else:
result.add_check(
name="Tensor Size Consistency Check",
passed=True,
message=f"Tensor {name} size matches dtype/shape",
location=path,
details={
"tensor": name,
"size": expected_size,
},
)
if expected_size is None:
result.add_check(
name="Tensor Size Computation Check",
passed=False,
message=f"Unable to compute expected size for tensor {name}",
severity=IssueSeverity.INFO,
location=path,
details={
"tensor": name,
"dtype": dtype,
"shape": shape,
},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
continue

if expected_size != end - begin:
result.add_check(
name="Tensor Size Consistency Check",
passed=False,
message=f"Size mismatch for tensor {name}",
severity=IssueSeverity.CRITICAL,
location=path,
details={
"tensor": name,
"expected_size": expected_size,
"actual_size": end - begin,
},
)
else:
result.add_check(
name="Tensor Size Consistency Check",
passed=True,
message=f"Tensor {name} size matches dtype/shape",
location=path,
details={
"tensor": name,
"size": expected_size,
},
)

# Check offset continuity
offsets.sort(key=lambda x: x[0])
Expand All @@ -298,6 +406,8 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"gap_at": begin, "expected": last_end},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True
break
last_end = end

Expand All @@ -320,6 +430,8 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"last_offset": last_end, "data_size": data_size},
)
self._mark_inconclusive(result, SAFETENSORS_STRUCTURE_INCONCLUSIVE_REASON)
structural_validation_failed = True

# Check metadata
metadata = header.get("__metadata__", {})
Expand Down Expand Up @@ -388,18 +500,20 @@ def scan(self, path: str) -> ScanResult:
result.finish(success=False)
return result

result.finish(success=not result.has_errors)
result.finish(success=not result.has_errors and not structural_validation_failed)
return result

@staticmethod
def _expected_size(dtype: str | None, shape: Any) -> int | None:
    """Return expected tensor byte size from dtype and shape.

    The result is the product of all dimensions multiplied by the per-element
    byte size looked up in ``_DTYPE_SIZES``. An empty shape yields the size of
    a single element.

    Returns ``None`` when the size cannot be computed: ``dtype`` is not a
    string key of ``_DTYPE_SIZES``, ``shape`` is not a list, or any dimension
    is not a non-negative integer (booleans are rejected).
    """
    # Check the type first: an unhashable dtype (e.g. a list or dict taken from
    # a malformed header) would raise TypeError in the membership test below.
    if not isinstance(dtype, str) or dtype not in _DTYPE_SIZES:
        return None
    if not isinstance(shape, list):
        return None
    size = _DTYPE_SIZES[dtype]
    total = 1
    for dim in shape:
        # bool is a subclass of int, so it must be excluded explicitly.
        if not isinstance(dim, int) or isinstance(dim, bool) or dim < 0:
            return None
        total *= dim
    return total * size
Expand Down
Loading
Loading