Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/uprating-cloned-subtrees.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed uprating dependency sorting for cloned parameter subtrees with duplicate parameter names.
206 changes: 191 additions & 15 deletions policyengine_core/parameters/operations/uprate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
for parameter in root.get_descendants()
if isinstance(parameter, Parameter)
]
parameter_paths = get_parameter_paths(root)

for parameter in sort_parameters_by_uprating_dependencies(parameters):
uprate_parameter(parameter, root)
for parameter in sort_parameters_by_uprating_dependencies(
parameters,
parameter_paths,
):
uprate_parameter(parameter, root, parameter_paths)
return root


Expand All @@ -57,35 +61,196 @@ def get_uprating_dependency_name(parameter: Parameter) -> Optional[str]:
return dependency_name


def get_parameter_paths(root: ParameterNode) -> dict[int, str]:
paths = {}

def visit_parameter_node(node: ParameterNode, path: str) -> None:
for child_name, child in node.children.items():
child_path = f"{path}.{child_name}" if path else child_name
visit_child(child, child_path)

def visit_child(child, path: str) -> None:
if isinstance(child, Parameter):
paths[id(child)] = path
elif isinstance(child, ParameterNode):
visit_parameter_node(child, path)
else:
brackets = getattr(child, "__dict__", {}).get("brackets")
if brackets is not None:
for index, bracket in enumerate(brackets):
visit_parameter_node(bracket, f"{path}[{index}]")

visit_parameter_node(root, "")
return paths


def join_parameter_path(prefix: str, suffix: str) -> str:
if not prefix:
return suffix
if not suffix:
return prefix
return f"{prefix}.{suffix}"


def get_parameter_scope_prefixes(
parameter: Parameter,
parameter_paths: dict[int, str],
) -> Optional[tuple[str, str]]:
parameter_path = parameter_paths.get(id(parameter))
if parameter_path is None or parameter_path == parameter.name:
return None
parameter_name_parts = parameter.name.split(".")
parameter_path_parts = parameter_path.split(".")
common_suffix_length = 0
max_common_suffix_length = min(
len(parameter_name_parts),
len(parameter_path_parts),
)
while (
common_suffix_length < max_common_suffix_length
and parameter_name_parts[-common_suffix_length - 1]
== parameter_path_parts[-common_suffix_length - 1]
):
common_suffix_length += 1
if common_suffix_length == 0:
return None
original_prefix = ".".join(parameter_name_parts[:-common_suffix_length])
current_prefix = ".".join(parameter_path_parts[:-common_suffix_length])
return original_prefix, current_prefix


def map_original_parameter_path(
parameter_path: str,
original_prefix: str,
current_prefix: str,
) -> Optional[str]:
if not original_prefix:
return join_parameter_path(current_prefix, parameter_path)
if parameter_path == original_prefix:
return current_prefix
original_prefix_with_separator = f"{original_prefix}."
if parameter_path.startswith(original_prefix_with_separator):
return join_parameter_path(
current_prefix,
parameter_path[len(original_prefix_with_separator) :],
)
return None


def get_scoped_uprating_dependency_names(
parameter: Parameter,
dependency_name: str,
parameter_paths: dict[int, str],
) -> list[str]:
parameter_path = parameter_paths.get(id(parameter))
if parameter_path is None or parameter_path == parameter.name:
return [dependency_name]

dependency_names = []

def add_dependency_name(candidate: Optional[str]) -> None:
if candidate and candidate not in dependency_names:
dependency_names.append(candidate)

scope_prefixes = get_parameter_scope_prefixes(parameter, parameter_paths)
if scope_prefixes is not None:
original_prefix, current_prefix = scope_prefixes
add_dependency_name(
map_original_parameter_path(
dependency_name,
original_prefix,
current_prefix,
)
)

add_dependency_name(dependency_name)
return dependency_names


def get_parameter_lookup_names(
parameter: Parameter,
parameter_paths: dict[int, str],
) -> set[str]:
parameter_path = parameter_paths.get(id(parameter))
if parameter_path is None:
return {parameter.name}
return {parameter_path}


def get_uprating_parameter(
root: ParameterNode,
parameter: Parameter,
dependency_name: str,
parameter_paths: dict[int, str],
) -> Parameter:
for scoped_dependency_name in get_scoped_uprating_dependency_names(
parameter,
dependency_name,
parameter_paths,
):
try:
return get_parameter(root, scoped_dependency_name)
except ValueError:
continue
return get_parameter(root, dependency_name)


def sort_parameters_by_uprating_dependencies(
parameters: list[Parameter],
parameter_paths: Optional[dict[int, str]] = None,
) -> list[Parameter]:
if parameter_paths is None:
parameter_paths = {}
parameters_to_uprate = [
parameter
for parameter in parameters
if parameter.metadata.get("uprating") is not None
]
parameter_by_name = {
parameter.name: parameter for parameter in parameters_to_uprate
}
parameters_by_name = {}
for parameter in parameters_to_uprate:
for name in get_parameter_lookup_names(parameter, parameter_paths):
parameters_by_name.setdefault(name, []).append(parameter)
ordered_parameters = []
visited = set()
visiting = []
visiting_ids = set()

def visit(parameter: Parameter):
if parameter.name in visited:
parameter_id = id(parameter)
if parameter_id in visited:
return
if parameter.name in visiting:
cycle = visiting[visiting.index(parameter.name) :] + [parameter.name]
if parameter_id in visiting_ids:
cycle_start = next(
index
for index, visiting_parameter in enumerate(visiting)
if id(visiting_parameter) == parameter_id
)
cycle = visiting[cycle_start:] + [parameter]
raise ValueError(
"Cyclic uprating dependency detected: " + " -> ".join(cycle)
"Cyclic uprating dependency detected: "
+ " -> ".join(parameter.name for parameter in cycle)
)
visiting.append(parameter.name)
visiting.append(parameter)
visiting_ids.add(parameter_id)
dependency_name = get_uprating_dependency_name(parameter)
if dependency_name in parameter_by_name:
visit(parameter_by_name[dependency_name])
dependency_parameters = []
if dependency_name is not None:
for scoped_dependency_name in get_scoped_uprating_dependency_names(
parameter,
dependency_name,
parameter_paths,
):
dependency_parameters = parameters_by_name.get(
scoped_dependency_name,
[],
)
if dependency_parameters:
break
for dependency in dependency_parameters:
visit(dependency)
visiting.pop()
visited.add(parameter.name)
visiting_ids.remove(parameter_id)
visited.add(parameter_id)
ordered_parameters.append(parameter)

for parameter in sorted(parameters_to_uprate, key=lambda p: p.name):
Expand All @@ -94,7 +259,13 @@ def visit(parameter: Parameter):
return ordered_parameters


def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None:
def uprate_parameter(
parameter: Parameter,
root: ParameterNode,
parameter_paths: Optional[dict[int, str]] = None,
) -> None:
if parameter_paths is None:
parameter_paths = {}
# Pull the uprating definition dict
meta = normalize_uprating_metadata(parameter.metadata["uprating"])

Expand All @@ -106,7 +277,12 @@ def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None:
)
# Otherwise, pull uprating table from YAML
else:
uprating_parameter = get_parameter(root, meta["parameter"])
uprating_parameter = get_uprating_parameter(
root,
parameter,
meta["parameter"],
parameter_paths,
)

# If uprating with a set candence, ensure that all
# required values are present
Expand Down
Loading
Loading