Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Bug Fixes

- mark incomplete MXNet scans inconclusive instead of clean
- harden manifest parse boundaries around malformed metadata
- recover malformed Jinja template configs as inconclusive scan outcomes
- traverse NeMo YAML list configs when checking suspicious targets
Expand Down
59 changes: 40 additions & 19 deletions modelaudit/scanners/mxnet_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from modelaudit.detectors.suspicious_symbols import EXECUTABLE_SIGNATURES

from .base import BaseScanner, IssueSeverity, ScanResult
from .base import INCONCLUSIVE_SCAN_OUTCOME, BaseScanner, IssueSeverity, ScanResult

MAX_SYMBOL_READ_BYTES = 10 * 1024 * 1024
MAX_PARAMS_READ_BYTES = 10 * 1024 * 1024
Expand Down Expand Up @@ -123,6 +123,11 @@
)


def _scan_result_has_security_findings(result: ScanResult) -> bool:
    """Report whether *result* holds at least one WARNING or CRITICAL issue."""
    security_levels = {IssueSeverity.WARNING, IssueSeverity.CRITICAL}
    for finding in result.issues:
        if finding.severity in security_levels:
            return True
    return False


class MXNetScanner(BaseScanner):
"""Scanner for MXNet symbol graph and params artifacts."""

Expand All @@ -140,28 +145,14 @@ def can_handle(cls, path: str) -> bool:
if suffix == ".params":
return cls._is_mxnet_params_filename(path_obj.name)

if suffix == ".json" and path_obj.name.lower().endswith("-symbol.json"):
return cls._is_mxnet_symbol_graph(path_obj)

return False
# Route MXNet symbol artifacts by their framework filename convention so
# malformed graphs reach scan() and fail closed as inconclusive.
return suffix == ".json" and path_obj.name.lower().endswith("-symbol.json")

@classmethod
def _is_mxnet_params_filename(cls, filename: str) -> bool:
    """Return True when ``PARAMS_NAME_RE.match`` accepts *filename*."""
    name_match = PARAMS_NAME_RE.match(filename)
    return bool(name_match)

@classmethod
def _is_mxnet_symbol_graph(cls, path: Path) -> bool:
    """Check whether the file at *path* parses as a structurally valid symbol graph.

    Reads at most ``MAX_SYMBOL_READ_BYTES`` bytes. Returns False when the
    read is truncated, when reading/decoding/JSON parsing raises one of the
    handled exceptions, or when ``_has_valid_symbol_structure`` rejects the
    parsed payload.
    """
    handled_errors = (OSError, UnicodeDecodeError, json.JSONDecodeError, ValueError, TypeError)

    try:
        data, was_truncated = cls._read_bounded_bytes(path, MAX_SYMBOL_READ_BYTES)
    except handled_errors:
        return False

    # A truncated read cannot be parsed reliably, so it is treated as "not a graph".
    if was_truncated:
        return False

    try:
        document = json.loads(data.decode("utf-8"))
    except handled_errors:
        return False

    return cls._has_valid_symbol_structure(document)

@classmethod
def _has_valid_symbol_structure(cls, payload: Any) -> bool:
if not isinstance(payload, dict):
Expand Down Expand Up @@ -210,11 +201,32 @@ def scan(self, path: str) -> ScanResult:
location=path,
details={"extension": suffix},
)
self._mark_inconclusive_scan_result(result, "mxnet_unsupported_extension")
analysis_complete = False

result.finish(success=(not result.has_errors) and analysis_complete)
self._finish_mxnet_result(result, analysis_complete=analysis_complete)
return result

def _mark_inconclusive_scan_result(self, result: ScanResult, reason: str) -> None:
    """Flag *result* as an inconclusive scan and record *reason* once.

    Writes three metadata keys: ``scan_outcome`` (set to
    ``INCONCLUSIVE_SCAN_OUTCOME``), ``scan_outcome_reasons`` (a list that
    gains *reason* if not already present) and ``analysis_incomplete``
    (set to True). An existing reasons list is extended in place; any
    non-list value under that key is replaced by a fresh list.
    """
    recorded = result.metadata.get("scan_outcome_reasons")
    if not isinstance(recorded, list):
        recorded = []
    if reason not in recorded:
        recorded.append(reason)

    result.metadata["scan_outcome"] = INCONCLUSIVE_SCAN_OUTCOME
    result.metadata["scan_outcome_reasons"] = recorded
    result.metadata["analysis_incomplete"] = True

def _finish_mxnet_result(self, result: ScanResult, *, analysis_complete: bool) -> None:
    """Finish *result*, failing closed when the scan was inconclusive.

    A result whose ``scan_outcome`` metadata equals
    ``INCONCLUSIVE_SCAN_OUTCOME`` and that carries no WARNING/CRITICAL issue
    is finished as unsuccessful. Otherwise success requires that the result
    has no errors and that either the analysis completed or security
    findings were recorded.
    """
    found_security_issue = _scan_result_has_security_findings(result)
    marked_inconclusive = result.metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME

    if marked_inconclusive and not found_security_issue:
        succeeded = False
    else:
        succeeded = (not result.has_errors) and (analysis_complete or found_security_issue)

    result.finish(success=succeeded)

def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
path_obj = Path(path)
try:
Expand All @@ -228,6 +240,7 @@ def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
location=path,
details={"exception": str(exc), "exception_type": type(exc).__name__},
)
self._mark_inconclusive_scan_result(result, "mxnet_symbol_read_failed")
return False

result.bytes_scanned += len(raw_bytes)
Expand All @@ -240,6 +253,7 @@ def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
location=path,
details={"max_bytes": MAX_SYMBOL_READ_BYTES},
)
self._mark_inconclusive_scan_result(result, "mxnet_symbol_truncated")

if not raw_bytes:
result.add_check(
Expand All @@ -249,6 +263,7 @@ def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
severity=IssueSeverity.INFO,
location=path,
)
self._mark_inconclusive_scan_result(result, "mxnet_symbol_empty")
return False

try:
Expand All @@ -262,6 +277,7 @@ def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
location=path,
details={"exception": str(exc), "exception_type": type(exc).__name__},
)
self._mark_inconclusive_scan_result(result, "mxnet_symbol_parse_failed")
return False

if not self._has_valid_symbol_structure(payload):
Expand All @@ -272,6 +288,7 @@ def _scan_symbol_graph(self, path: str, result: ScanResult) -> bool:
severity=IssueSeverity.INFO,
location=path,
)
self._mark_inconclusive_scan_result(result, "mxnet_symbol_invalid_structure")
return False

nodes = payload.get("nodes", [])
Expand Down Expand Up @@ -338,6 +355,7 @@ def _scan_params_blob(self, path: str, result: ScanResult) -> bool:
location=path,
details={"exception": str(exc), "exception_type": type(exc).__name__},
)
self._mark_inconclusive_scan_result(result, "mxnet_params_read_failed")
return False

result.bytes_scanned += len(raw_bytes)
Expand All @@ -351,6 +369,7 @@ def _scan_params_blob(self, path: str, result: ScanResult) -> bool:
location=path,
details={"max_bytes": MAX_PARAMS_READ_BYTES},
)
self._mark_inconclusive_scan_result(result, "mxnet_params_truncated")

if not raw_bytes:
result.add_check(
Expand All @@ -360,6 +379,7 @@ def _scan_params_blob(self, path: str, result: ScanResult) -> bool:
severity=IssueSeverity.INFO,
location=path,
)
self._mark_inconclusive_scan_result(result, "mxnet_params_empty")
return False

if len(raw_bytes) < MIN_PARAMS_SIZE_BYTES:
Expand All @@ -371,6 +391,7 @@ def _scan_params_blob(self, path: str, result: ScanResult) -> bool:
location=path,
details={"size_bytes": len(raw_bytes), "minimum_expected_bytes": MIN_PARAMS_SIZE_BYTES},
)
self._mark_inconclusive_scan_result(result, "mxnet_params_truncated")

self._scan_params_signatures(path, raw_bytes, result)
self._scan_params_text_payloads(path, raw_bytes, result)
Expand Down
142 changes: 139 additions & 3 deletions tests/scanners/test_mxnet_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import struct
from pathlib import Path

import pytest

import modelaudit.scanners.mxnet_scanner as mxnet_scanner
from modelaudit.core import determine_exit_code, scan_model_directory_or_file
from modelaudit.models import ModelAuditResultModel
from modelaudit.scanners import get_scanner_for_file
from modelaudit.scanners.base import IssueSeverity
from modelaudit.scanners.base import INCONCLUSIVE_SCAN_OUTCOME, IssueSeverity, ScanResult
from modelaudit.scanners.mxnet_scanner import MXNetScanner


Expand Down Expand Up @@ -40,6 +45,22 @@ def _write_params_file(path: Path, values: tuple[float, ...] | None = None) -> N
path.write_bytes(struct.pack(f"<{len(tensor_values)}f", *tensor_values))


def _assert_inconclusive_result(result: ScanResult, reason: str) -> None:
    """Assert that a direct scan result failed and is tagged inconclusive for *reason*."""
    assert result.success is False

    metadata = result.metadata
    assert metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME
    assert reason in metadata["scan_outcome_reasons"]
    assert metadata["analysis_incomplete"] is True


def _assert_aggregate_inconclusive(result: ModelAuditResultModel, path: Path, reason: str) -> None:
    """Assert that the aggregate result marks *path* inconclusive for *reason*.

    Also checks that the aggregate scan is unsuccessful and that
    ``determine_exit_code`` maps it to exit code 2.
    """
    file_entry = result.file_metadata[str(path)]
    assert file_entry["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME
    assert reason in file_entry["scan_outcome_reasons"]
    assert file_entry["analysis_incomplete"] is True

    assert result.success is False
    exit_code = determine_exit_code(result)
    assert exit_code == 2


def test_mxnet_scanner_can_handle_symbol_and_params(tmp_path: Path) -> None:
symbol_path = tmp_path / "model-symbol.json"
params_path = tmp_path / "model-0000.params"
Expand All @@ -51,7 +72,7 @@ def test_mxnet_scanner_can_handle_symbol_and_params(tmp_path: Path) -> None:


def test_mxnet_scanner_rejects_non_mxnet_files(tmp_path: Path) -> None:
fake_symbol = tmp_path / "fake-symbol.json"
fake_symbol = tmp_path / "fake.json"
fake_symbol.write_text('{"not": "mxnet"}', encoding="utf-8")
bad_params_name = tmp_path / "weights.params"
bad_params_name.write_bytes(b"raw bytes")
Expand All @@ -60,6 +81,13 @@ def test_mxnet_scanner_rejects_non_mxnet_files(tmp_path: Path) -> None:
assert not MXNetScanner.can_handle(str(bad_params_name))


def test_mxnet_scanner_routes_malformed_symbol_for_fail_closed_scan(tmp_path: Path) -> None:
    """A ``*-symbol.json`` file with broken JSON is still claimed by ``can_handle``."""
    malformed = tmp_path / "malformed-symbol.json"
    malformed.write_text('{"nodes": [', encoding="utf-8")

    claimed = MXNetScanner.can_handle(str(malformed))

    assert claimed


def test_mxnet_symbol_scan_with_valid_pair_has_no_security_findings(tmp_path: Path) -> None:
symbol_path = tmp_path / "resnet-symbol.json"
params_path = tmp_path / "resnet-0000.params"
Expand Down Expand Up @@ -149,10 +177,118 @@ def test_mxnet_scanner_handles_corrupt_params_file(tmp_path: Path) -> None:

result = MXNetScanner().scan(str(params_path))

assert not result.success
_assert_inconclusive_result(result, "mxnet_params_empty")
assert any("MXNet params blob is empty" in issue.message for issue in result.issues)


def test_mxnet_corrupt_params_aggregate_exit_code_is_inconclusive(tmp_path: Path) -> None:
    """An empty params file scanned via the aggregate entry point is inconclusive."""
    empty_params = tmp_path / "corrupt-0000.params"
    empty_params.write_bytes(b"")

    aggregate = scan_model_directory_or_file(str(empty_params), cache_scan_results=False)

    _assert_aggregate_inconclusive(aggregate, empty_params, "mxnet_params_empty")


def test_mxnet_truncated_params_aggregate_exit_code_is_inconclusive(tmp_path: Path) -> None:
    """A too-short params blob is inconclusive in both direct and aggregate scans."""
    expected_reason = "mxnet_params_truncated"
    short_params = tmp_path / "truncated-0000.params"
    short_params.write_bytes(b"short")
    target = str(short_params)

    from_scanner = MXNetScanner().scan(target)
    from_aggregate = scan_model_directory_or_file(target, cache_scan_results=False)

    _assert_inconclusive_result(from_scanner, expected_reason)
    _assert_aggregate_inconclusive(from_aggregate, short_params, expected_reason)


def test_mxnet_malformed_symbol_scan_is_inconclusive(tmp_path: Path) -> None:
    """Unparseable symbol JSON yields a parse-failure inconclusive outcome."""
    expected_reason = "mxnet_symbol_parse_failed"
    broken_symbol = tmp_path / "broken-symbol.json"
    broken_symbol.write_text('{"nodes": [', encoding="utf-8")
    target = str(broken_symbol)

    from_scanner = MXNetScanner().scan(target)
    from_aggregate = scan_model_directory_or_file(target, cache_scan_results=False)

    _assert_inconclusive_result(from_scanner, expected_reason)
    _assert_aggregate_inconclusive(from_aggregate, broken_symbol, expected_reason)
    # The direct scan must also have recorded the dedicated parse check.
    check_names = [check.name for check in from_scanner.checks]
    assert "MXNet Symbol Parse" in check_names


def test_mxnet_unsupported_extension_scan_is_inconclusive(tmp_path: Path) -> None:
    """Scanning a file with an extension the scanner does not support is inconclusive."""
    unsupported = tmp_path / "model.mxnet"
    unsupported.write_bytes(b"mxnet-ish content")

    outcome = MXNetScanner().scan(str(unsupported))

    _assert_inconclusive_result(outcome, "mxnet_unsupported_extension")


def test_mxnet_symbol_read_failure_scan_is_inconclusive(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An OSError while reading a symbol file marks the scan inconclusive."""
    unreadable = tmp_path / "unreadable-symbol.json"
    unreadable.write_text("{}", encoding="utf-8")

    def _failing_read(path: Path, max_bytes: int) -> tuple[bytes, bool]:
        # Stand-in for the bounded reader that always fails.
        raise OSError("symbol read failed")

    monkeypatch.setattr(MXNetScanner, "_read_bounded_bytes", staticmethod(_failing_read))

    outcome = MXNetScanner().scan(str(unreadable))

    _assert_inconclusive_result(outcome, "mxnet_symbol_read_failed")


def test_mxnet_params_read_failure_scan_is_inconclusive(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An OSError while reading a params file marks the scan inconclusive."""
    unreadable = tmp_path / "unreadable-0000.params"
    unreadable.write_bytes(b"placeholder params")

    def _failing_read(path: Path, max_bytes: int) -> tuple[bytes, bool]:
        # Stand-in for the bounded reader that always fails.
        raise OSError("params read failed")

    monkeypatch.setattr(MXNetScanner, "_read_bounded_bytes", staticmethod(_failing_read))

    outcome = MXNetScanner().scan(str(unreadable))

    _assert_inconclusive_result(outcome, "mxnet_params_read_failed")


def test_mxnet_truncated_symbol_scan_is_inconclusive(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A symbol file larger than the read limit is reported as truncated/inconclusive."""
    oversized_symbol = tmp_path / "truncated-symbol.json"
    _write_symbol_file(oversized_symbol)
    # Shrink the module-level read limit to 8 bytes for this test.
    monkeypatch.setattr(mxnet_scanner, "MAX_SYMBOL_READ_BYTES", 8)

    outcome = MXNetScanner().scan(str(oversized_symbol))

    _assert_inconclusive_result(outcome, "mxnet_symbol_truncated")


def test_mxnet_empty_symbol_scan_is_inconclusive(tmp_path: Path) -> None:
    """A zero-length symbol file is reported as empty/inconclusive."""
    empty_symbol = tmp_path / "empty-symbol.json"
    empty_symbol.write_text("", encoding="utf-8")

    outcome = MXNetScanner().scan(str(empty_symbol))

    _assert_inconclusive_result(outcome, "mxnet_symbol_empty")


def test_mxnet_invalid_symbol_structure_scan_is_inconclusive(tmp_path: Path) -> None:
    """Valid JSON with empty graph sections is inconclusive for invalid structure."""
    expected_reason = "mxnet_symbol_invalid_structure"
    hollow_graph = {"nodes": [], "arg_nodes": [], "heads": []}
    invalid_symbol = tmp_path / "invalid-symbol.json"
    invalid_symbol.write_text(json.dumps(hollow_graph), encoding="utf-8")
    target = str(invalid_symbol)

    from_scanner = MXNetScanner().scan(target)
    from_aggregate = scan_model_directory_or_file(target, cache_scan_results=False)

    _assert_inconclusive_result(from_scanner, expected_reason)
    _assert_aggregate_inconclusive(from_aggregate, invalid_symbol, expected_reason)


def test_mxnet_params_numeric_blob_does_not_trigger_false_positives(tmp_path: Path) -> None:
symbol_path = tmp_path / "clean-symbol.json"
params_path = tmp_path / "clean-0000.params"
Expand Down
Loading