Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Bug Fixes

- preserve picklescan stack state across reused scanner runs
- mark partial streaming scans inconclusive when large-file streaming coverage is incomplete
- harden native code detection in model scanners ([#897](https://github.com/promptfoo/modelaudit/issues/897)) ([f4f661a](https://github.com/promptfoo/modelaudit/commit/f4f661a09be0032e15aa8895864413e3878233f8))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(
self.deadline = deadline if deadline is not None else time.monotonic() + options.timeout_s
self.stack: list[Any] = []
self.memo: dict[int | str, Any] = {}
self.next_memo_index = 0
self.findings: list[Finding] = []
self.notices: list[Notice] = []
self.errors: list[ScanError] = []
Expand Down Expand Up @@ -172,7 +171,6 @@ def run(self) -> None:
self.first_pickle_end_pos = self.position_offset + self.stream.tell()
self.stack.clear()
self.memo.clear()
self.next_memo_index = 0
break

if not parsed_opcode:
Expand Down Expand Up @@ -382,6 +380,11 @@ def _handle_opcode(self, op_name: str, arg: Any, position: int) -> None:
self.stack.pop()
return

if op_name == "DUP":
if self.stack:
self.stack.append(self.stack[-1])
return

if op_name == "POP_MARK":
self._pop_to_mark()
return
Expand All @@ -406,15 +409,25 @@ def _handle_opcode(self, op_name: str, arg: Any, position: int) -> None:
self._collapse_top_n(3)
return

if op_name in _MEMO_WRITE_OPCODES:
if op_name in {"APPEND", "SETITEM"}:
if self.stack:
self.stack.pop()
if op_name == "SETITEM" and self.stack:
self.stack.pop()
return

if op_name in {"APPENDS", "SETITEMS", "ADDITEMS"}:
self._pop_to_mark()
return

if op_name in _MEMO_WRITE_OPCODES:
if self.stack and isinstance(arg, int):
self.memo[arg] = self.stack[-1]
return

if op_name == "MEMOIZE":
if self.stack:
self.memo[self.next_memo_index] = self.stack[-1]
self.next_memo_index += 1
self.memo[len(self.memo)] = self.stack[-1]
return

if op_name in _MEMO_READ_OPCODES:
Expand Down
43 changes: 43 additions & 0 deletions packages/modelaudit-picklescan/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,49 @@ def test_scan_bytes_attributes_reduce_calls_to_the_callable_operand_not_nested_a
)


@pytest.mark.parametrize(
    ("payload", "source"),
    [
        (b"cbuiltins\nlen\n}cos\nsystem\nK\x01s\x85R.", "setitem-args.pkl"),
        (b"cbuiltins\nlen\n}(cos\nsystem\nK\x01u\x85R.", "setitems-args.pkl"),
    ],
)
def test_scan_bytes_dict_mutation_operands_do_not_become_reduce_call_targets(payload: bytes, source: str) -> None:
    """Check that dict key/value operands are not mistaken for the REDUCE callable.

    Both payloads push ``builtins.len``, then an empty dict, then store
    ``os.system`` into that dict (via SETITEM in the first case and
    MARK ... SETITEMS in the second) before TUPLE1 + REDUCE.  The object
    actually being called is therefore ``builtins.len``; ``os.system`` is only
    imported and used as a dict key.  The scan must still flag the import as a
    dangerous global, but must not report it as a dangerous call.
    """
    report = scan_bytes(payload, source=source)

    def has_system_finding(rule_code: str) -> bool:
        # True when any finding carries this rule code and points at one of the
        # SYSTEM_GLOBALS import references.
        return any(
            item.rule_code == rule_code and item.details.get("import_reference") in SYSTEM_GLOBALS
            for item in report.findings
        )

    assert report.status == ScanStatus.COMPLETE
    assert report.verdict == SafetyVerdict.MALICIOUS
    # The import itself is reported...
    assert has_system_finding("DANGEROUS_GLOBAL")
    # ...but it is never treated as the target of the REDUCE call.
    assert not has_system_finding("DANGEROUS_CALL")


@pytest.mark.parametrize(
    ("payload", "source"),
    [
        (b"\x80\x04cbuiltins\nlen\nq\x00cos\nsystem\n\x94h\x01\x8c\x04echo\x85R.", "memoize-after-put.pkl"),
        (b"\x80\x04cbuiltins\nlen\nqdcos\nsystem\n\x94h\x01\x8c\x04echo\x85R.", "memoize-after-sparse-put.pkl"),
    ],
)
def test_scan_bytes_memoize_index_uses_runtime_memo_size_after_explicit_memo_write(
    payload: bytes,
    source: str,
) -> None:
    """Check that MEMOIZE picks its slot from the current memo size.

    Each payload first stores ``builtins.len`` with an explicit BINPUT (index 0
    in the first case, index 100 in the second), then pushes ``os.system`` and
    MEMOIZEs it.  With one entry already in the memo, MEMOIZE lands on index 1,
    so the following ``BINGET 1`` fetches ``os.system`` and REDUCE calls it
    with ``("echo",)``.  The scanner has to resolve that lookup the same way
    and report the dangerous call.
    """
    report = scan_bytes(payload, source=source)

    assert report.status == ScanStatus.COMPLETE
    assert report.verdict == SafetyVerdict.MALICIOUS

    # Lazily narrow to call findings, then look for one aimed at a system global.
    call_findings = (item for item in report.findings if item.rule_code == "DANGEROUS_CALL")
    assert any(item.details.get("import_reference") in SYSTEM_GLOBALS for item in call_findings)


def test_scan_stream_uses_explicit_source_and_does_not_leak_prior_scan_state() -> None:
scanner = PickleScanner()

Expand Down
Loading