diff --git a/CHANGELOG.md b/CHANGELOG.md index fc10c0c2..219f1e2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bug Fixes +- preserve picklescan stack state across reused scanner runs - mark partial streaming scans inconclusive when large-file streaming coverage is incomplete - harden native code detection in model scanners ([#897](https://github.com/promptfoo/modelaudit/issues/897)) ([f4f661a](https://github.com/promptfoo/modelaudit/commit/f4f661a09be0032e15aa8895864413e3878233f8)) diff --git a/packages/modelaudit-picklescan/src/modelaudit_picklescan/engine/scanner.py b/packages/modelaudit-picklescan/src/modelaudit_picklescan/engine/scanner.py index 25f26cd1..f296bbe2 100644 --- a/packages/modelaudit-picklescan/src/modelaudit_picklescan/engine/scanner.py +++ b/packages/modelaudit-picklescan/src/modelaudit_picklescan/engine/scanner.py @@ -123,7 +123,6 @@ def __init__( self.deadline = deadline if deadline is not None else time.monotonic() + options.timeout_s self.stack: list[Any] = [] self.memo: dict[int | str, Any] = {} - self.next_memo_index = 0 self.findings: list[Finding] = [] self.notices: list[Notice] = [] self.errors: list[ScanError] = [] @@ -172,7 +171,6 @@ def run(self) -> None: self.first_pickle_end_pos = self.position_offset + self.stream.tell() self.stack.clear() self.memo.clear() - self.next_memo_index = 0 break if not parsed_opcode: @@ -382,6 +380,11 @@ def _handle_opcode(self, op_name: str, arg: Any, position: int) -> None: self.stack.pop() return + if op_name == "DUP": + if self.stack: + self.stack.append(self.stack[-1]) + return + if op_name == "POP_MARK": self._pop_to_mark() return @@ -406,15 +409,25 @@ def _handle_opcode(self, op_name: str, arg: Any, position: int) -> None: self._collapse_top_n(3) return - if op_name in _MEMO_WRITE_OPCODES: + if op_name in {"APPEND", "SETITEM"}: if self.stack: + self.stack.pop() + if op_name == "SETITEM" and self.stack: + self.stack.pop() + return + + if op_name in {"APPENDS", "SETITEMS", "ADDITEMS"}: + self._pop_to_mark() + return + + if op_name in _MEMO_WRITE_OPCODES: + if self.stack and isinstance(arg, int): self.memo[arg] = self.stack[-1] return if op_name == "MEMOIZE": if self.stack: - self.memo[self.next_memo_index] = self.stack[-1] - self.next_memo_index += 1 + self.memo[len(self.memo)] = self.stack[-1] return if op_name in _MEMO_READ_OPCODES: diff --git a/packages/modelaudit-picklescan/tests/test_api.py b/packages/modelaudit-picklescan/tests/test_api.py index f9afa4d0..123270c1 100644 --- a/packages/modelaudit-picklescan/tests/test_api.py +++ b/packages/modelaudit-picklescan/tests/test_api.py @@ -101,6 +101,49 @@ def test_scan_bytes_attributes_reduce_calls_to_the_callable_operand_not_nested_a ) +@pytest.mark.parametrize( + ("payload", "source"), + [ + (b"cbuiltins\nlen\n}cos\nsystem\nK\x01s\x85R.", "setitem-args.pkl"), + (b"cbuiltins\nlen\n}(cos\nsystem\nK\x01u\x85R.", "setitems-args.pkl"), + ], +) +def test_scan_bytes_dict_mutation_operands_do_not_become_reduce_call_targets(payload: bytes, source: str) -> None: + report = scan_bytes(payload, source=source) + + assert report.status == ScanStatus.COMPLETE + assert report.verdict == SafetyVerdict.MALICIOUS + assert any( + finding.rule_code == "DANGEROUS_GLOBAL" and finding.details.get("import_reference") in SYSTEM_GLOBALS + for finding in report.findings + ) + assert not any( + finding.rule_code == "DANGEROUS_CALL" and finding.details.get("import_reference") in SYSTEM_GLOBALS + for finding in report.findings + ) + + +@pytest.mark.parametrize( + ("payload", "source"), + [ + (b"\x80\x04cbuiltins\nlen\nq\x00cos\nsystem\n\x94h\x01\x8c\x04echo\x85R.", "memoize-after-put.pkl"), + (b"\x80\x04cbuiltins\nlen\nqdcos\nsystem\n\x94h\x01\x8c\x04echo\x85R.", "memoize-after-sparse-put.pkl"), + ], +) +def test_scan_bytes_memoize_index_uses_runtime_memo_size_after_explicit_memo_write( + payload: bytes, + source: str, +) -> None: + report = scan_bytes(payload, source=source) + + assert report.status == ScanStatus.COMPLETE + assert report.verdict == SafetyVerdict.MALICIOUS + assert any( + finding.rule_code == "DANGEROUS_CALL" and finding.details.get("import_reference") in SYSTEM_GLOBALS + for finding in report.findings + ) + + def test_scan_stream_uses_explicit_source_and_does_not_leak_prior_scan_state() -> None: scanner = PickleScanner()