Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/secure-yaml-and-breakdowns.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use safe YAML loaders and remove dynamic eval from parameter breakdown handling.
4 changes: 2 additions & 2 deletions policyengine_core/data_structures/parameter_node_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class ParameterNodeMetadata:
metadata:
breakdown:
- region # If `region` is an Enum-type variable with possible values `[ENGLAND, WALES]` then these children will be added.
- range(1, 7) # This code is `eval`uated to produce the list `[1, 2, 3, 4, 5, 6]` which are then added as grandchildren.
- [True, False] # This list is added as great-grandchildren (using the same `eval` method as above).
- range(1, 7) # Safe dynamic form producing `[1, 2, 3, 4, 5, 6]`, which are then added as grandchildren.
- [True, False] # Literal collection form added as great-grandchildren.

"""

Expand Down
33 changes: 24 additions & 9 deletions policyengine_core/parameters/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from policyengine_core.warnings import LibYAMLWarning

try:
from yaml import CLoader as Loader
from yaml import CSafeLoader as Loader
except ImportError:
message = [
"libyaml is not installed in your environment.",
Expand All @@ -17,7 +17,7 @@
]
warnings.warn(" ".join(message), LibYAMLWarning)
from yaml import (
Loader,
SafeLoader as Loader,
) # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270)

ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List)
Expand All @@ -33,15 +33,30 @@ def date_constructor(_loader, node):


def dict_no_duplicate_constructor(loader, node, deep=False):
    """Build a dict from a YAML mapping node, rejecting duplicate explicit keys.

    Keys are compared after construction rather than by their raw scalar
    text, so two spellings that construct to equal keys are reported as
    duplicates. Merge-key entries are skipped during the duplicate check,
    which lets an explicit key override a value pulled in through a merge.

    Args:
        loader: The YAML loader driving construction.
        node: The mapping node being constructed.
        deep: Passed through to the loader's construction helpers.

    Returns:
        The constructed mapping.

    Raises:
        yaml.parser.ParserError: If two explicit keys construct to equal values.
        yaml.constructor.ConstructorError: If an explicit key is unhashable.
    """
    seen_keys = set()
    for key_node, _value_node in node.value:
        # Merge entries are expanded by flatten_mapping() below and are not
        # explicit keys, so they must not take part in duplicate detection.
        if key_node.tag == "tag:yaml.org,2002:merge":
            continue
        key = loader.construct_object(key_node, deep=deep)
        try:
            already_seen = key in seen_keys
        except TypeError as exc:
            raise yaml.constructor.ConstructorError(
                "", node.start_mark, f"Found unhashable key '{key}'"
            ) from exc
        if already_seen:
            raise yaml.parser.ParserError(
                "", node.start_mark, f"Found duplicate key '{key}'"
            )
        seen_keys.add(key)

    loader.flatten_mapping(node)
    # Later pairs win, so explicit keys override merged-in values.
    return dict(loader.construct_pairs(node, deep=deep))


yaml.add_constructor(
Expand Down
92 changes: 90 additions & 2 deletions policyengine_core/parameters/operations/homogenize_parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import logging
from typing import Any, Dict, List, Type

Expand All @@ -6,6 +7,8 @@
from policyengine_core.parameters.parameter_node import ParameterNode
from policyengine_core.variables import Variable

MAX_DYNAMIC_BREAKDOWN_VALUES = 10_000


def homogenize_parameter_structures(
root: ParameterNode, variables: Dict[str, Variable], default_value: Any = 0
Expand Down Expand Up @@ -43,6 +46,11 @@ def get_breakdown_variables(node: ParameterNode) -> List[str]:
f"Invalid breakdown metadata for parameter {node.name}: {type(breakdown)}"
)
return None
if len(breakdown) == 0:
logging.warning(
f"Invalid breakdown metadata for parameter {node.name}: empty list"
)
return None
return breakdown
else:
return None
Expand Down Expand Up @@ -71,8 +79,7 @@ def homogenize_parameter_node(
elif dtype == bool:
possible_values = [True, False]
else:
# Try to execute the breakdown as Python code
possible_values = list(eval(first_breakdown))
possible_values = evaluate_dynamic_breakdown(first_breakdown)
if not hasattr(node, "children"):
node = ParameterNode(
node.name,
Expand Down Expand Up @@ -119,3 +126,84 @@ def homogenize_parameter_node(
node.children[child], breakdown[1:], variables, default_value
)
return node


def evaluate_dynamic_breakdown(expression: str) -> List[Any]:
    """Safely evaluate a dynamic breakdown expression.

    The parameter metadata only needs literal collections and the documented
    ``range(...)`` / ``list(range(...))`` forms. Anything else is rejected.

    Args:
        expression: Source text of the breakdown entry.

    Returns:
        The breakdown values as a list.

    Raises:
        ValueError: If ``expression`` is not a string, is not a valid Python
            expression, uses an unsupported construct, does not evaluate to
            a collection, or produces too many values.
    """
    rejection = (
        f"Invalid dynamic breakdown expression '{expression}'. "
        "Only literal collections and range() calls are allowed."
    )
    if not isinstance(expression, str):
        raise ValueError(rejection)
    try:
        # Leading whitespace would otherwise be reported as an indentation
        # error when parsing in "eval" mode.
        parsed = ast.parse(expression.strip(), mode="eval")
    except SyntaxError as exc:
        # Malformed input is rejected the same way as unsupported input.
        raise ValueError(rejection) from exc

    evaluated = evaluate_dynamic_breakdown_node(parsed.body)
    if isinstance(evaluated, range):
        # Check the size before materialising the range.
        validate_dynamic_breakdown_range_cardinality(evaluated, expression)
    elif isinstance(evaluated, (list, tuple, set)):
        validate_dynamic_breakdown_cardinality(len(evaluated), expression)
    else:
        raise ValueError(rejection)
    return list(evaluated)


def validate_dynamic_breakdown_cardinality(count: int, expression: str) -> None:
    """Raise ``ValueError`` when ``count`` exceeds ``MAX_DYNAMIC_BREAKDOWN_VALUES``.

    Args:
        count: Number of values the expression produces.
        expression: Expression text, quoted in the error message.
    """
    limit = MAX_DYNAMIC_BREAKDOWN_VALUES
    if count <= limit:
        return
    message = (
        f"Dynamic breakdown expression '{expression}' produces {count} values, "
        f"which exceeds the maximum of {limit}."
    )
    raise ValueError(message)


def validate_dynamic_breakdown_range_cardinality(
    values: range, expression: str
) -> None:
    """Check the size of a ``range`` against the breakdown value limit.

    A range whose length cannot be computed (``len`` raises
    ``OverflowError``) is reported as a ``ValueError``. Otherwise the length
    is passed to ``validate_dynamic_breakdown_cardinality``.

    Args:
        values: The range produced by the expression.
        expression: Expression text, quoted in the error message.
    """
    try:
        size = len(values)
    except OverflowError as overflow:
        message = (
            f"Dynamic breakdown expression '{expression}' produces too many values."
        )
        raise ValueError(message) from overflow
    else:
        validate_dynamic_breakdown_cardinality(size, expression)


def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any:
    """Evaluate one node of a parsed breakdown expression against an allow-list.

    Supported nodes are constants, list/tuple/set literals, unary ``+``/``-``
    applied to a number, ``range(...)`` with positional arguments, and
    ``list(...)`` wrapping a range or literal collection.

    Args:
        node: Expression node taken from ``ast.parse(..., mode="eval")``.

    Returns:
        The Python value the node denotes. A ``range(...)`` call yields a
        ``range`` object; ``list(...)`` yields a list.

    Raises:
        ValueError: For any unsupported node, for invalid operands or
            arguments, and when a collection or range is too large.
    """
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        return _evaluate_breakdown_collection(node)
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
        operand = evaluate_dynamic_breakdown_node(node.operand)
        # Only numbers may carry a sign; anything else would fail with a
        # TypeError (or be silently accepted for unary plus).
        if not isinstance(operand, (int, float, complex)):
            raise ValueError(
                "Unsupported dynamic breakdown expression: "
                f"{_describe_breakdown_node(node)}"
            )
        return operand if isinstance(node.op, ast.UAdd) else -operand
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        if node.func.id == "range":
            if node.keywords:
                raise ValueError("range() keyword arguments are not allowed")
            args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args]
            try:
                result = range(*args)
            except (TypeError, ValueError) as exc:
                # Wrong argument count/type or a zero step.
                raise ValueError(
                    "Invalid range() call in dynamic breakdown expression "
                    f"'{_describe_breakdown_node(node)}': {exc}"
                ) from exc
            validate_dynamic_breakdown_range_cardinality(
                result, _describe_breakdown_node(node)
            )
            return result
        if node.func.id == "list":
            if len(node.args) != 1 or node.keywords:
                raise ValueError("list() must contain a single positional argument")
            evaluated = evaluate_dynamic_breakdown_node(node.args[0])
            if isinstance(evaluated, (range, list, tuple, set)):
                # Any range reaching this point already passed the size
                # check in the range() branch, so materialising it is safe.
                return list(evaluated)
            raise ValueError(
                "list() only supports range() and literal collection expressions"
            )
    raise ValueError(
        f"Unsupported dynamic breakdown expression: {_describe_breakdown_node(node)}"
    )


def _evaluate_breakdown_collection(node: ast.AST) -> Any:
    """Evaluate a list, tuple or set literal after checking its element count."""
    count = len(node.elts)
    if count > MAX_DYNAMIC_BREAKDOWN_VALUES:
        # Render the source text only on the failure path.
        validate_dynamic_breakdown_cardinality(count, _describe_breakdown_node(node))
    elements = [evaluate_dynamic_breakdown_node(element) for element in node.elts]
    if isinstance(node, ast.List):
        return elements
    if isinstance(node, ast.Tuple):
        return tuple(elements)
    try:
        return set(elements)
    except TypeError as exc:
        # Set members must be hashable, e.g. ``{[1], [2]}`` is invalid.
        raise ValueError(
            "Unsupported dynamic breakdown expression: "
            f"{_describe_breakdown_node(node)}"
        ) from exc


def _describe_breakdown_node(node: ast.AST) -> str:
    """Return source text for ``node``, or its type name without ``ast.unparse``."""
    if hasattr(ast, "unparse"):
        return ast.unparse(node)
    return type(node).__name__
19 changes: 16 additions & 3 deletions policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def import_yaml():
import yaml

try:
from yaml import CLoader as Loader
from yaml import CSafeLoader as Loader
except ImportError:
log.warning(
" "
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self, *, tax_benefit_system, options, **kwargs):
def collect(self):
try:
tests = yaml.load(self.path.open(), Loader=Loader)
except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError):
except (yaml.YAMLError, TypeError):
message = os.linesep.join(
[
traceback.format_exc(),
Expand All @@ -139,6 +139,11 @@ def collect(self):
tests: List[Dict] = [tests]

for test in tests:
if not isinstance(test, dict):
raise ValueError(
f"'{self.path}' is not a valid YAML test file. "
"Expected a mapping or a list of mappings."
)
if not self.should_ignore(test):
yield YamlItem.from_parent(
self,
Expand All @@ -150,11 +155,19 @@ def collect(self):

def should_ignore(self, test):
    """Return True when the ``name_filter`` option excludes ``test``.

    A test is kept when the filter text occurs in the test file's base name
    (extension removed), occurs in the test's ``name``, or equals one of the
    test's ``keywords``. Without a filter nothing is ignored.

    Args:
        test: Mapping describing one test case.

    Raises:
        ValueError: If ``keywords`` is present but is neither null nor a list.
    """
    name_filter = self.options.get("name_filter")
    keywords = test.get("keywords", [])
    if keywords is None:
        keywords = []
    if not isinstance(keywords, list):
        raise ValueError(
            f"'{self.path}' is not a valid YAML test file. "
            "'keywords' must be a list."
        )
    if name_filter is None:
        return False
    # A ``name`` key that is present but null would break the ``in`` test.
    name = test.get("name") or ""
    file_stem = os.path.splitext(self.fspath.basename)[0]
    # Use the normalised ``keywords`` only; re-reading the raw value would
    # fail again on a null entry.
    return (
        name_filter not in file_stem
        and name_filter not in name
        and name_filter not in keywords
    )


Expand Down
130 changes: 130 additions & 0 deletions tests/core/test_parameter_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import pytest

from policyengine_core.errors import ParameterParsingError
from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures
from policyengine_core.parameters.helpers import _load_yaml_file
from policyengine_core.parameters.operations.homogenize_parameters import (
MAX_DYNAMIC_BREAKDOWN_VALUES,
evaluate_dynamic_breakdown,
)


def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch):
    """A ``!!python/object/apply`` payload is refused and ``os.system`` is never called."""
    executed_commands = []

    def record_command(command):
        # Stand-in for os.system: remember the command and report success.
        executed_commands.append(command)
        return 0

    monkeypatch.setattr("os.system", record_command)

    payload_file = tmp_path / "malicious.yaml"
    payload_file.write_text(
        '!!python/object/apply:os.system ["echo pwned"]\n',
        encoding="utf-8",
    )

    with pytest.raises(ParameterParsingError):
        _load_yaml_file(str(payload_file))

    assert executed_commands == []


def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code(
    monkeypatch,
):
    """Arbitrary code in a breakdown entry raises and ``eval`` is never invoked."""
    recorded_expressions = []

    def spy_eval(expression, globals=None, locals=None):
        # Replacement for builtins.eval that records any use.
        recorded_expressions.append(expression)
        return range(1, 4)

    monkeypatch.setattr("builtins.eval", spy_eval)

    hostile_expression = '__import__("os").system("echo pwned")'
    parameter_tree = {
        "value_by_category": {
            "metadata": {
                "breakdown": [hostile_expression],
            },
        }
    }
    root = ParameterNode(data=parameter_tree)

    with pytest.raises(ValueError, match="breakdown"):
        homogenize_parameter_structures(root, {}, default_value=0)

    assert recorded_expressions == []


def test_homogenize_parameter_structures_rejects_oversized_dynamic_breakdown():
    """A range one element over the limit is rejected with a size error."""
    one_too_many = MAX_DYNAMIC_BREAKDOWN_VALUES + 1
    parameter_tree = {
        "value_by_category": {
            "metadata": {
                "breakdown": [f"list(range({one_too_many}))"],
            },
        }
    }
    root = ParameterNode(data=parameter_tree)

    with pytest.raises(ValueError, match="exceeds the maximum"):
        homogenize_parameter_structures(root, {}, default_value=0)


def test_homogenize_parameter_structures_rejects_overflowing_dynamic_breakdown():
    """A range too long for ``len()`` is rejected with a clear error."""
    # A 1 followed by one hundred zeros.
    huge_stop = str(10**100)
    parameter_tree = {
        "value_by_category": {
            "metadata": {
                "breakdown": [f"range(0, {huge_stop})"],
            },
        }
    }
    root = ParameterNode(data=parameter_tree)

    with pytest.raises(ValueError, match="too many values"):
        homogenize_parameter_structures(root, {}, default_value=0)


def test_parameter_yaml_loader_rejects_implicit_duplicate_keys(tmp_path):
    """``true`` and ``True`` are reported as a duplicate key by the loader."""
    document = tmp_path / "duplicate-bools.yaml"
    document.write_text("true: 1\nTrue: 2\n", encoding="utf-8")
    document_path = str(document)

    with pytest.raises(ParameterParsingError, match="duplicate key"):
        _load_yaml_file(document_path)


def test_parameter_yaml_loader_allows_merge_key_overrides(tmp_path):
    """An explicit key may override a value brought in through a merge key."""
    document = tmp_path / "merge-override.yaml"
    yaml_text = (
        "defaults: &defaults\n"
        " value: 1\n"
        "merged:\n"
        " <<: *defaults\n"
        " value: 2\n"
    )
    document.write_text(yaml_text, encoding="utf-8")

    loaded = _load_yaml_file(str(document))

    assert loaded["merged"]["value"] == 2


def test_evaluate_dynamic_breakdown_allows_documented_safe_forms():
    """Each supported expression form evaluates to the expected list."""
    # Bare range(), the form shown in the breakdown metadata documentation.
    assert evaluate_dynamic_breakdown("range(1, 7)") == [1, 2, 3, 4, 5, 6]
    assert evaluate_dynamic_breakdown("list(range(1, 4))") == [1, 2, 3]
    assert evaluate_dynamic_breakdown("[1, 2, -3]") == [1, 2, -3]
    assert evaluate_dynamic_breakdown('("a", "b")') == ["a", "b"]
    # Set iteration order is unspecified, so compare after sorting.
    assert sorted(evaluate_dynamic_breakdown("{1, 2}")) == [1, 2]


def test_homogenize_parameter_structures_ignores_empty_breakdown_lists():
    """An empty ``breakdown`` list is tolerated and the same root is returned."""
    parameter_tree = {
        "value_by_category": {
            "metadata": {
                "breakdown": [],
            },
        }
    }
    root = ParameterNode(data=parameter_tree)

    homogenized = homogenize_parameter_structures(root, {}, default_value=0)

    assert homogenized is root
Loading
Loading