diff --git a/changelog.d/secure-yaml-and-breakdowns.fixed.md b/changelog.d/secure-yaml-and-breakdowns.fixed.md new file mode 100644 index 000000000..586241d33 --- /dev/null +++ b/changelog.d/secure-yaml-and-breakdowns.fixed.md @@ -0,0 +1 @@ +Use safe YAML loaders and remove dynamic eval from parameter breakdown handling. diff --git a/policyengine_core/data_structures/parameter_node_metadata.py b/policyengine_core/data_structures/parameter_node_metadata.py index 0483d9cb0..e91ba2600 100644 --- a/policyengine_core/data_structures/parameter_node_metadata.py +++ b/policyengine_core/data_structures/parameter_node_metadata.py @@ -37,8 +37,8 @@ class ParameterNodeMetadata: metadata: breakdown: - region # If `region` is an Enum-type variable with possible values `[ENGLAND, WALES]` then these children will be added. - - range(1, 7) # This code is `eval`uated to produce the list `[1, 2, 3, 4, 5, 6]` which are then added as grandchildren. - - [True, False] # This list is added as great-grandchildren (using the same `eval` method as above). + - range(1, 7) # Safe dynamic form producing `[1, 2, 3, 4, 5, 6]`, which are then added as grandchildren. + - [True, False] # Literal collection form added as great-grandchildren. """ diff --git a/policyengine_core/parameters/config.py b/policyengine_core/parameters/config.py index 73982d8a3..7adec9d25 100644 --- a/policyengine_core/parameters/config.py +++ b/policyengine_core/parameters/config.py @@ -7,7 +7,7 @@ from policyengine_core.warnings import LibYAMLWarning try: - from yaml import CLoader as Loader + from yaml import CSafeLoader as Loader except ImportError: message = [ "libyaml is not installed in your environment.", @@ -17,7 +17,7 @@ ] warnings.warn(" ".join(message), LibYAMLWarning) from yaml import ( - Loader, + SafeLoader as Loader, ) # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) @@ -33,15 +33,30 @@ def date_constructor(_loader, node): def dict_no_duplicate_constructor(loader, node, deep=False): - keys = [key.value for key, value in node.value] + explicit_keys = {} + for key_node, _value_node in node.value: + if key_node.tag == "tag:yaml.org,2002:merge": + continue + key = loader.construct_object(key_node, deep=deep) + try: + if key in explicit_keys: + raise yaml.parser.ParserError( + "", node.start_mark, f"Found duplicate key '{key}'" + ) + except TypeError as exc: + raise yaml.constructor.ConstructorError( + "", node.start_mark, f"Found unhashable key '{key}'" + ) from exc + explicit_keys[key] = True - if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) - raise yaml.parser.ParserError( - "", node.start_mark, f"Found duplicate key '{duplicate}'" - ) + loader.flatten_mapping(node) + pairs = loader.construct_pairs(node, deep=deep) + mapping = {} - return loader.construct_mapping(node, deep) + for key, value in pairs: + mapping[key] = value + + return mapping yaml.add_constructor( diff --git a/policyengine_core/parameters/operations/homogenize_parameters.py b/policyengine_core/parameters/operations/homogenize_parameters.py index 04888300c..39164b721 100644 --- a/policyengine_core/parameters/operations/homogenize_parameters.py +++ b/policyengine_core/parameters/operations/homogenize_parameters.py @@ -1,3 +1,4 @@ +import ast import logging from typing import Any, Dict, List, Type @@ -6,6 +7,8 @@ from policyengine_core.parameters.parameter_node import ParameterNode from policyengine_core.variables import Variable +MAX_DYNAMIC_BREAKDOWN_VALUES = 10_000 + def homogenize_parameter_structures( root: ParameterNode, variables: Dict[str, Variable], default_value: Any = 0 @@ -43,6 +46,11 @@ def get_breakdown_variables(node: ParameterNode) -> List[str]: f"Invalid breakdown metadata for parameter {node.name}: {type(breakdown)}" ) return None + if len(breakdown) == 0: + logging.warning( + f"Invalid breakdown metadata for parameter {node.name}: empty list" + ) + return None return breakdown else: return None @@ -71,8 +79,7 @@ def homogenize_parameter_node( elif dtype == bool: possible_values = [True, False] else: - # Try to execute the breakdown as Python code - possible_values = list(eval(first_breakdown)) + possible_values = evaluate_dynamic_breakdown(first_breakdown) if not hasattr(node, "children"): node = ParameterNode( node.name, @@ -119,3 +126,84 @@ def homogenize_parameter_node( node.children[child], breakdown[1:], variables, default_value ) return node + + +def evaluate_dynamic_breakdown(expression: str) -> List[Any]: + """Safely evaluate a dynamic breakdown expression. + + The parameter metadata only needs literal collections and the documented + ``range(...)`` / ``list(range(...))`` forms. Anything else is rejected. + """ + + parsed = ast.parse(expression, mode="eval") + evaluated = evaluate_dynamic_breakdown_node(parsed.body) + if isinstance(evaluated, range): + validate_dynamic_breakdown_range_cardinality(evaluated, expression) + return list(evaluated) + if isinstance(evaluated, (list, tuple)): + validate_dynamic_breakdown_cardinality(len(evaluated), expression) + return list(evaluated) + if isinstance(evaluated, set): + validate_dynamic_breakdown_cardinality(len(evaluated), expression) + return list(evaluated) + raise ValueError( + f"Invalid dynamic breakdown expression '{expression}'. " + "Only literal collections and range() calls are allowed." + ) + + +def validate_dynamic_breakdown_cardinality(count: int, expression: str) -> None: + if count > MAX_DYNAMIC_BREAKDOWN_VALUES: + raise ValueError( + f"Dynamic breakdown expression '{expression}' produces {count} values, " + f"which exceeds the maximum of {MAX_DYNAMIC_BREAKDOWN_VALUES}." + ) + + +def validate_dynamic_breakdown_range_cardinality( + values: range, expression: str +) -> None: + try: + count = len(values) + except OverflowError as exc: + raise ValueError( + f"Dynamic breakdown expression '{expression}' produces too many values." + ) from exc + validate_dynamic_breakdown_cardinality(count, expression) + + +def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any: + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) + return [evaluate_dynamic_breakdown_node(element) for element in node.elts] + if isinstance(node, ast.Tuple): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) + return tuple(evaluate_dynamic_breakdown_node(element) for element in node.elts) + if isinstance(node, ast.Set): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) + return {evaluate_dynamic_breakdown_node(element) for element in node.elts} + if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)): + operand = evaluate_dynamic_breakdown_node(node.operand) + return operand if isinstance(node.op, ast.UAdd) else -operand + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + if node.func.id == "range": + args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args] + if node.keywords: + raise ValueError("range() keyword arguments are not allowed") + result = range(*args) + validate_dynamic_breakdown_range_cardinality(result, ast.unparse(node)) + return result + if node.func.id == "list": + if len(node.args) != 1 or node.keywords: + raise ValueError("list() must contain a single positional argument") + evaluated = evaluate_dynamic_breakdown_node(node.args[0]) + if isinstance(evaluated, (range, list, tuple, set)): + return evaluated + raise ValueError( + "list() only supports range() and literal collection expressions" + ) + raise ValueError( + f"Unsupported dynamic breakdown expression: {ast.unparse(node) if hasattr(ast, 'unparse') else type(node).__name__}" + ) diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 03276236b..6b8d4aefc 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -31,7 +31,7 @@ def import_yaml(): import yaml try: - from yaml import CLoader as Loader + from yaml import CSafeLoader as Loader except ImportError: log.warning( " " @@ -126,7 +126,7 @@ def __init__(self, *, tax_benefit_system, options, **kwargs): def collect(self): try: tests = yaml.load(self.path.open(), Loader=Loader) - except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): + except (yaml.YAMLError, TypeError): message = os.linesep.join( [ traceback.format_exc(), @@ -139,6 +139,11 @@ def collect(self): tests: List[Dict] = [tests] for test in tests: + if not isinstance(test, dict): + raise ValueError( + f"'{self.path}' is not a valid YAML test file. " + "Expected a mapping or a list of mappings." + ) if not self.should_ignore(test): yield YamlItem.from_parent( self, @@ -150,11 +155,19 @@ def collect(self): def should_ignore(self, test): name_filter = self.options.get("name_filter") + keywords = test.get("keywords", []) + if keywords is None: + keywords = [] + if not isinstance(keywords, list): + raise ValueError( + f"'{self.path}' is not a valid YAML test file. " + "'keywords' must be a list." + ) return ( name_filter is not None and name_filter not in os.path.splitext(self.fspath.basename)[0] and name_filter not in test.get("name", "") - and name_filter not in test.get("keywords", []) + and name_filter not in keywords ) diff --git a/tests/core/test_parameter_security.py b/tests/core/test_parameter_security.py new file mode 100644 index 000000000..a503ead9a --- /dev/null +++ b/tests/core/test_parameter_security.py @@ -0,0 +1,130 @@ +import pytest + +from policyengine_core.errors import ParameterParsingError +from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures +from policyengine_core.parameters.helpers import _load_yaml_file +from policyengine_core.parameters.operations.homogenize_parameters import ( + MAX_DYNAMIC_BREAKDOWN_VALUES, + evaluate_dynamic_breakdown, +) + + +def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch): + calls = [] + + monkeypatch.setattr( + "os.system", + lambda command: calls.append(command) or 0, + ) + + yaml_path = tmp_path / "malicious.yaml" + yaml_path.write_text( + '!!python/object/apply:os.system ["echo pwned"]\n', + encoding="utf-8", + ) + + with pytest.raises(ParameterParsingError): + _load_yaml_file(str(yaml_path)) + + assert calls == [] + + +def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code( + monkeypatch, +): + eval_calls = [] + + monkeypatch.setattr( + "builtins.eval", + lambda expression, globals=None, locals=None: ( + eval_calls.append(expression) or range(1, 4) + ), + ) + + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": ['__import__("os").system("echo pwned")'], + }, + } + } + ) + + with pytest.raises(ValueError, match="breakdown"): + homogenize_parameter_structures(root, {}, default_value=0) + + assert eval_calls == [] + + +def test_homogenize_parameter_structures_rejects_oversized_dynamic_breakdown(): + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [f"list(range({MAX_DYNAMIC_BREAKDOWN_VALUES + 1}))"], + }, + } + } + ) + + with pytest.raises(ValueError, match="exceeds the maximum"): + homogenize_parameter_structures(root, {}, default_value=0) + + +def test_homogenize_parameter_structures_rejects_overflowing_dynamic_breakdown(): + huge_stop = "1" + ("0" * 100) + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [f"range(0, {huge_stop})"], + }, + } + } + ) + + with pytest.raises(ValueError, match="too many values"): + homogenize_parameter_structures(root, {}, default_value=0) + + +def test_parameter_yaml_loader_rejects_implicit_duplicate_keys(tmp_path): + yaml_path = tmp_path / "duplicate-bools.yaml" + yaml_path.write_text("true: 1\nTrue: 2\n", encoding="utf-8") + + with pytest.raises(ParameterParsingError, match="duplicate key"): + _load_yaml_file(str(yaml_path)) + + +def test_parameter_yaml_loader_allows_merge_key_overrides(tmp_path): + yaml_path = tmp_path / "merge-override.yaml" + yaml_path.write_text( + "defaults: &defaults\n value: 1\nmerged:\n <<: *defaults\n value: 2\n", + encoding="utf-8", + ) + + result = _load_yaml_file(str(yaml_path)) + + assert result["merged"]["value"] == 2 + + +def test_evaluate_dynamic_breakdown_allows_documented_safe_forms(): + assert evaluate_dynamic_breakdown("list(range(1, 4))") == [1, 2, 3] + assert evaluate_dynamic_breakdown("[1, 2, -3]") == [1, 2, -3] + assert evaluate_dynamic_breakdown('("a", "b")') == ["a", "b"] + + +def test_homogenize_parameter_structures_ignores_empty_breakdown_lists(): + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [], + }, + } + } + ) + + result = homogenize_parameter_structures(root, {}, default_value=0) + + assert result is root diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index 72e3faadb..9bd6f36f6 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -203,6 +203,113 @@ def test_performance_tables_option_output(): clean_performance_files(paths) +def test_yaml_runner_rejects_python_object_tags(tmp_path, monkeypatch): + calls = [] + yaml_path = tmp_path / "malicious.yaml" + yaml_path.write_text( + '!!python/object/apply:os.system ["echo pwned"]\n', + encoding="utf-8", + ) + + monkeypatch.setattr( + "os.system", + lambda command: calls.append(command) or 0, + ) + + malicious_yaml_file = object.__new__(YamlFile) + malicious_yaml_file.path = yaml_path + malicious_yaml_file.options = {} + malicious_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError): + list(malicious_yaml_file.collect()) + + assert calls == [] + + +def test_yaml_runner_wraps_composer_errors(tmp_path): + yaml_path = tmp_path / "invalid-anchor.yaml" + yaml_path.write_text("value: *missing_anchor\n", encoding="utf-8") + + invalid_yaml_file = object.__new__(YamlFile) + invalid_yaml_file.path = yaml_path + invalid_yaml_file.options = {} + invalid_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="not a valid YAML file"): + list(invalid_yaml_file.collect()) + + +def test_yaml_runner_rejects_scalar_roots(tmp_path): + yaml_path = tmp_path / "scalar.yaml" + yaml_path.write_text("foo\n", encoding="utf-8") + + scalar_yaml_file = object.__new__(YamlFile) + scalar_yaml_file.path = yaml_path + scalar_yaml_file.options = {} + scalar_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="list of mappings"): + list(scalar_yaml_file.collect()) + + +def test_yaml_runner_rejects_scalar_keywords(tmp_path): + yaml_path = tmp_path / "invalid-keywords.yaml" + yaml_path.write_text( + "name: Example\nkeywords: 0\noutput: {}\n", + encoding="utf-8", + ) + + invalid_yaml_file = object.__new__(YamlFile) + invalid_yaml_file.path = yaml_path + invalid_yaml_file.options = {"name_filter": "missing"} + invalid_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="'keywords' must be a list"): + list(invalid_yaml_file.collect()) + + +def test_yaml_runner_allows_yaml_merge_anchors(tmp_path): + yaml_path = tmp_path / "anchors.yaml" + yaml_path.write_text( + """ +- name: define anchor + input: + persons: &persons + Alicia: + salary: 4000 + households: + household: + parents: [Alicia] + output: + salary: 4000 + +- name: merge anchor + input: + persons: + <<: *persons + households: + household: + parents: [Alicia] + output: + salary: 4000 +""".strip(), + encoding="utf-8", + ) + + yaml_file = object.__new__(YamlFile) + yaml_file.config = None + yaml_file.session = None + yaml_file._nodeid = "anchors" + yaml_file.path = yaml_path + yaml_file.options = {} + yaml_file.tax_benefit_system = TaxBenefitSystem() + + collected = list(yaml_file.collect()) + + assert len(collected) == 2 + + def clean_performance_files(paths: List[str]): for path in paths: if os.path.isfile(path):