diff --git a/fickling/analysis.py b/fickling/analysis.py index 7903f18..840c9b8 100644 --- a/fickling/analysis.py +++ b/fickling/analysis.py @@ -2,7 +2,7 @@ import json from abc import ABC, abstractmethod -from ast import unparse +from ast import Import, ImportFrom, unparse from collections import defaultdict from collections.abc import Iterable, Iterator from enum import Enum @@ -415,6 +415,42 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: ) +class ScannerDeactivation(Analysis): + """Detects pickles that attempt to import pickle security scanning libraries.""" + + SCANNER_MODULES: frozenset[str] = frozenset( + { + "fickling", + "picklescan", + "modelscan", + "model_unpickler", + "saferpickle", + "modelaudit", + } + ) + + def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: + for node in context.pickled.properties.imports: + module_name = self._get_top_level_module(node) + if module_name and module_name in self.SCANNER_MODULES: + shortened, _ = context.shorten_code(node) + yield AnalysisResult( + Severity.OVERTLY_MALICIOUS, + f"`{shortened}` imports a pickle security scanning library; " + "this is an attempt to deactivate or interfere with security analysis", + "ScannerDeactivation", + trigger=shortened, + ) + + @staticmethod + def _get_top_level_module(node: Import | ImportFrom) -> str | None: + if isinstance(node, ImportFrom) and node.module: + return node.module.split(".")[0] + if isinstance(node, Import) and node.names: + return node.names[0].name.split(".")[0] + return None + + class AnalysisResults: def __init__(self, pickled: Pickled, results: Iterable[AnalysisResult]): self.pickled: Pickled = pickled diff --git a/test/test_scanner_deactivation_analysis.py b/test/test_scanner_deactivation_analysis.py new file mode 100644 index 0000000..9f79e90 --- /dev/null +++ b/test/test_scanner_deactivation_analysis.py @@ -0,0 +1,97 @@ +from unittest import TestCase + +import fickling.fickle as op +from fickling.analysis import ScannerDeactivation, Severity, check_safety +from fickling.fickle import Pickled + + +class TestScannerDeactivationAnalysis(TestCase): + def test_fickling_remove_hook(self): + """Pickle calling fickling.hook.remove_hook to strip safety hooks.""" + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("fickling.hook"), + op.ShortBinUnicode("remove_hook"), + op.StackGlobal(), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + res = check_safety(pickled) + self.assertEqual(res.severity, Severity.OVERTLY_MALICIOUS) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNotNone(detailed.get("ScannerDeactivation")) + + def test_fickling_deactivate_safe_ml(self): + """Pickle calling fickling.hook.deactivate_safe_ml_environment.""" + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("fickling.hook"), + op.ShortBinUnicode("deactivate_safe_ml_environment"), + op.StackGlobal(), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + res = check_safety(pickled) + self.assertEqual(res.severity, Severity.OVERTLY_MALICIOUS) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNotNone(detailed.get("ScannerDeactivation")) + + def test_fickling_remove_hook_via_global_opcode(self): + """Older Global opcode path should also trigger detection.""" + pickled = Pickled( + [ + op.Global.create("fickling.hook", "remove_hook"), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + res = check_safety(pickled) + self.assertEqual(res.severity, Severity.OVERTLY_MALICIOUS) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNotNone(detailed.get("ScannerDeactivation")) + + def test_benign_import_not_flagged(self): + """A benign stdlib import should not trigger ScannerDeactivation.""" + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("collections"), + op.ShortBinUnicode("OrderedDict"), + op.StackGlobal(), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + res = check_safety(pickled) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNone(detailed.get("ScannerDeactivation")) + + def test_all_scanner_modules_covered(self): + """Every module in SCANNER_MODULES should be detected by ScannerDeactivation.""" + for module in ScannerDeactivation.SCANNER_MODULES: + with self.subTest(module=module): + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode(module), + op.ShortBinUnicode("some_func"), + op.StackGlobal(), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + res = check_safety(pickled) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNotNone( + detailed.get("ScannerDeactivation"), + f"{module} was not detected by ScannerDeactivation", + )