From 1887743723da6bf966250015f9da9db0f8f76c6f Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 14 May 2026 19:03:17 -0400 Subject: [PATCH] Fix uprating of cloned parameter subtrees --- changelog.d/uprating-cloned-subtrees.fixed.md | 1 + .../operations/uprate_parameters.py | 206 ++++++++++++++++-- .../parameters/operations/test_uprating.py | 161 ++++++++++++++ 3 files changed, 353 insertions(+), 15 deletions(-) create mode 100644 changelog.d/uprating-cloned-subtrees.fixed.md diff --git a/changelog.d/uprating-cloned-subtrees.fixed.md b/changelog.d/uprating-cloned-subtrees.fixed.md new file mode 100644 index 00000000..4733a701 --- /dev/null +++ b/changelog.d/uprating-cloned-subtrees.fixed.md @@ -0,0 +1 @@ +Fixed uprating dependency sorting for cloned parameter subtrees with duplicate parameter names. diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py index e9c10ecc..ee756cc0 100644 --- a/policyengine_core/parameters/operations/uprate_parameters.py +++ b/policyengine_core/parameters/operations/uprate_parameters.py @@ -32,9 +32,13 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode: for parameter in root.get_descendants() if isinstance(parameter, Parameter) ] + parameter_paths = get_parameter_paths(root) - for parameter in sort_parameters_by_uprating_dependencies(parameters): - uprate_parameter(parameter, root) + for parameter in sort_parameters_by_uprating_dependencies( + parameters, + parameter_paths, + ): + uprate_parameter(parameter, root, parameter_paths) return root @@ -57,35 +61,196 @@ def get_uprating_dependency_name(parameter: Parameter) -> Optional[str]: return dependency_name +def get_parameter_paths(root: ParameterNode) -> dict[int, str]: + paths = {} + + def visit_parameter_node(node: ParameterNode, path: str) -> None: + for child_name, child in node.children.items(): + child_path = f"{path}.{child_name}" if path else child_name + visit_child(child, child_path) + + def visit_child(child, path: str) -> None: + if isinstance(child, Parameter): + paths[id(child)] = path + elif isinstance(child, ParameterNode): + visit_parameter_node(child, path) + else: + brackets = getattr(child, "__dict__", {}).get("brackets") + if brackets is not None: + for index, bracket in enumerate(brackets): + visit_parameter_node(bracket, f"{path}[{index}]") + + visit_parameter_node(root, "") + return paths + + +def join_parameter_path(prefix: str, suffix: str) -> str: + if not prefix: + return suffix + if not suffix: + return prefix + return f"{prefix}.{suffix}" + + +def get_parameter_scope_prefixes( + parameter: Parameter, + parameter_paths: dict[int, str], +) -> Optional[tuple[str, str]]: + parameter_path = parameter_paths.get(id(parameter)) + if parameter_path is None or parameter_path == parameter.name: + return None + parameter_name_parts = parameter.name.split(".") + parameter_path_parts = parameter_path.split(".") + common_suffix_length = 0 + max_common_suffix_length = min( + len(parameter_name_parts), + len(parameter_path_parts), + ) + while ( + common_suffix_length < max_common_suffix_length + and parameter_name_parts[-common_suffix_length - 1] + == parameter_path_parts[-common_suffix_length - 1] + ): + common_suffix_length += 1 + if common_suffix_length == 0: + return None + original_prefix = ".".join(parameter_name_parts[:-common_suffix_length]) + current_prefix = ".".join(parameter_path_parts[:-common_suffix_length]) + return original_prefix, current_prefix + + +def map_original_parameter_path( + parameter_path: str, + original_prefix: str, + current_prefix: str, +) -> Optional[str]: + if not original_prefix: + return join_parameter_path(current_prefix, parameter_path) + if parameter_path == original_prefix: + return current_prefix + original_prefix_with_separator = f"{original_prefix}." + if parameter_path.startswith(original_prefix_with_separator): + return join_parameter_path( + current_prefix, + parameter_path[len(original_prefix_with_separator) :], + ) + return None + + +def get_scoped_uprating_dependency_names( + parameter: Parameter, + dependency_name: str, + parameter_paths: dict[int, str], +) -> list[str]: + parameter_path = parameter_paths.get(id(parameter)) + if parameter_path is None or parameter_path == parameter.name: + return [dependency_name] + + dependency_names = [] + + def add_dependency_name(candidate: Optional[str]) -> None: + if candidate and candidate not in dependency_names: + dependency_names.append(candidate) + + scope_prefixes = get_parameter_scope_prefixes(parameter, parameter_paths) + if scope_prefixes is not None: + original_prefix, current_prefix = scope_prefixes + add_dependency_name( + map_original_parameter_path( + dependency_name, + original_prefix, + current_prefix, + ) + ) + + add_dependency_name(dependency_name) + return dependency_names + + +def get_parameter_lookup_names( + parameter: Parameter, + parameter_paths: dict[int, str], +) -> set[str]: + parameter_path = parameter_paths.get(id(parameter)) + if parameter_path is None: + return {parameter.name} + return {parameter_path} + + +def get_uprating_parameter( + root: ParameterNode, + parameter: Parameter, + dependency_name: str, + parameter_paths: dict[int, str], +) -> Parameter: + for scoped_dependency_name in get_scoped_uprating_dependency_names( + parameter, + dependency_name, + parameter_paths, + ): + try: + return get_parameter(root, scoped_dependency_name) + except ValueError: + continue + return get_parameter(root, dependency_name) + + def sort_parameters_by_uprating_dependencies( parameters: list[Parameter], + parameter_paths: Optional[dict[int, str]] = None, ) -> list[Parameter]: + if parameter_paths is None: + parameter_paths = {} parameters_to_uprate = [ parameter for parameter in parameters if parameter.metadata.get("uprating") is not None ] - parameter_by_name = { - parameter.name: parameter for parameter in parameters_to_uprate - } + parameters_by_name = {} + for parameter in parameters_to_uprate: + for name in get_parameter_lookup_names(parameter, parameter_paths): + parameters_by_name.setdefault(name, []).append(parameter) ordered_parameters = [] visited = set() visiting = [] + visiting_ids = set() def visit(parameter: Parameter): - if parameter.name in visited: + parameter_id = id(parameter) + if parameter_id in visited: return - if parameter.name in visiting: - cycle = visiting[visiting.index(parameter.name) :] + [parameter.name] + if parameter_id in visiting_ids: + cycle_start = next( + index + for index, visiting_parameter in enumerate(visiting) + if id(visiting_parameter) == parameter_id + ) + cycle = visiting[cycle_start:] + [parameter] raise ValueError( - "Cyclic uprating dependency detected: " + " -> ".join(cycle) + "Cyclic uprating dependency detected: " + + " -> ".join(parameter.name for parameter in cycle) ) - visiting.append(parameter.name) + visiting.append(parameter) + visiting_ids.add(parameter_id) dependency_name = get_uprating_dependency_name(parameter) - if dependency_name in parameter_by_name: - visit(parameter_by_name[dependency_name]) + dependency_parameters = [] + if dependency_name is not None: + for scoped_dependency_name in get_scoped_uprating_dependency_names( + parameter, + dependency_name, + parameter_paths, + ): + dependency_parameters = parameters_by_name.get( + scoped_dependency_name, + [], + ) + if dependency_parameters: + break + for dependency in dependency_parameters: + visit(dependency) visiting.pop() - visited.add(parameter.name) + visiting_ids.remove(parameter_id) + visited.add(parameter_id) ordered_parameters.append(parameter) for parameter in sorted(parameters_to_uprate, key=lambda p: p.name): @@ -94,7 +259,13 @@ def visit(parameter: Parameter): return ordered_parameters -def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None: +def uprate_parameter( + parameter: Parameter, + root: ParameterNode, + parameter_paths: Optional[dict[int, str]] = None, +) -> None: + if parameter_paths is None: + parameter_paths = {} # Pull the uprating definition dict meta = normalize_uprating_metadata(parameter.metadata["uprating"]) @@ -106,7 +277,12 @@ def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None: ) # Otherwise, pull uprating table from YAML else: - uprating_parameter = get_parameter(root, meta["parameter"]) + uprating_parameter = get_uprating_parameter( + root, + parameter, + meta["parameter"], + parameter_paths, + ) # If uprating with a set candence, ensure that all # required values are present diff --git a/tests/core/parameters/operations/test_uprating.py b/tests/core/parameters/operations/test_uprating.py index af99eac3..dcd4eaf7 100644 --- a/tests/core/parameters/operations/test_uprating.py +++ b/tests/core/parameters/operations/test_uprating.py @@ -26,6 +26,167 @@ def test_parameter_uprating_processes_dependencies_before_dependents(): assert uprated.target("2026-01-01") == pytest.approx(110) +def test_parameter_uprating_processes_cloned_subtrees_with_duplicate_names(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "target": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "middle"}, + }, + "middle": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "base"}, + }, + "base": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + } + ) + root.add_child("baseline", root.clone()) + + uprated = uprate_parameters(root) + + assert uprated.middle("2026-01-01") == pytest.approx(110) + assert uprated.target("2026-01-01") == pytest.approx(110) + assert uprated.baseline.middle("2026-01-01") == pytest.approx(110) + assert uprated.baseline.target("2026-01-01") == pytest.approx(110) + + +def test_parameter_uprating_uses_cloned_subtree_upraters(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "target": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "middle"}, + }, + "middle": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "base"}, + }, + "base": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + } + ) + root.add_child("baseline", root.clone()) + root.base.update(start="2026-01-01", value=120) + + uprated = uprate_parameters(root) + + assert uprated.middle("2026-01-01") == pytest.approx(120) + assert uprated.target("2026-01-01") == pytest.approx(120) + assert uprated.baseline.middle("2026-01-01") == pytest.approx(110) + assert uprated.baseline.target("2026-01-01") == pytest.approx(110) + + +def test_parameter_uprating_uses_reattached_cloned_subtree_upraters(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "gov": { + "target": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "gov.middle"}, + }, + "middle": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "gov.base"}, + }, + "base": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + }, + } + ) + root.add_child("baseline", root.gov.clone()) + root.gov.base.update(start="2026-01-01", value=120) + + uprated = uprate_parameters(root) + + assert uprated.gov.middle("2026-01-01") == pytest.approx(120) + assert uprated.gov.target("2026-01-01") == pytest.approx(120) + assert uprated.baseline.middle("2026-01-01") == pytest.approx(110) + assert uprated.baseline.target("2026-01-01") == pytest.approx(110) + + +def test_parameter_uprating_keeps_external_upraters_root_scoped_in_clones(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "external": { + "uprater": { + "values": {"2025-01-01": 100, "2026-01-01": 120}, + }, + }, + "gov": { + "target": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "external.uprater"}, + }, + "external": { + "uprater": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + }, + }, + } + ) + root.add_child("baseline", root.gov.clone()) + + uprated = uprate_parameters(root) + + assert uprated.gov.target("2026-01-01") == pytest.approx(120) + assert uprated.baseline.target("2026-01-01") == pytest.approx(120) + + +def test_parameter_uprating_handles_nodes_with_brackets_child_name(): + from policyengine_core.parameters import ( + ParameterNode, + ParameterScale, + uprate_parameters, + ) + + root = ParameterNode( + data={ + "container": { + "uprater": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + }, + } + ) + root.container.add_child( + "brackets", + ParameterScale( + "container.brackets", + data={ + "brackets": [ + { + "threshold": {"values": {"2025-01-01": 0}}, + "amount": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "container.uprater"}, + }, + }, + ], + }, + file_path=None, + ), + ) + + uprated = uprate_parameters(root) + + assert uprated.container.brackets.brackets[0].amount("2026-01-01") == pytest.approx( + 110 + ) + + def test_scale_bracket_uprating_processes_dependencies_before_dependents(): from policyengine_core.parameters import ParameterNode, uprate_parameters