Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/policyengine/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .output import Output as Output
from .output import OutputCollection as OutputCollection
from .parameter import Parameter as Parameter
from .parameter_node import ParameterNode as ParameterNode
from .parameter_value import ParameterValue as ParameterValue
from .policy import Policy as Policy
from .region import Region as Region
Expand All @@ -23,4 +24,5 @@
TaxBenefitModelVersion.model_rebuild()
Variable.model_rebuild()
Parameter.model_rebuild()
ParameterNode.model_rebuild()
ParameterValue.model_rebuild()
29 changes: 29 additions & 0 deletions src/policyengine/core/parameter_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
from uuid import uuid4

from pydantic import BaseModel, Field

if TYPE_CHECKING:
from .tax_benefit_model_version import TaxBenefitModelVersion


class ParameterNode(BaseModel):
"""Represents a folder/category node in the parameter hierarchy.

Parameter nodes are intermediate nodes in the parameter tree (e.g., "gov",
"gov.hmrc", "gov.hmrc.income_tax"). They provide structure and human-readable
labels for navigating the parameter tree, but don't have values themselves.

Unlike Parameter objects (which are leaf nodes with actual values),
ParameterNode objects are purely organizational.
"""

model_config = {"arbitrary_types_allowed": True}

id: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')")
label: str | None = Field(
default=None, description="Human-readable label (e.g., 'HMRC')"
)
description: str | None = Field(default=None, description="Node description")
tax_benefit_model_version: "TaxBenefitModelVersion"
22 changes: 20 additions & 2 deletions src/policyengine/core/tax_benefit_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if TYPE_CHECKING:
from .parameter import Parameter
from .parameter_node import ParameterNode
from .parameter_value import ParameterValue
from .region import Region, RegionRegistry
from .simulation import Simulation
Expand All @@ -25,6 +26,7 @@ class TaxBenefitModelVersion(BaseModel):

variables: list["Variable"] = Field(default_factory=list)
parameters: list["Parameter"] = Field(default_factory=list)
parameter_nodes: list["ParameterNode"] = Field(default_factory=list)

# Region registry for geographic simulations
region_registry: "RegionRegistry | None" = Field(
Expand All @@ -43,6 +45,9 @@ def parameter_values(self) -> list["ParameterValue"]:
parameters_by_name: dict[str, "Parameter"] = Field(
default_factory=dict, exclude=True
)
parameter_nodes_by_name: dict[str, "ParameterNode"] = Field(
default_factory=dict, exclude=True
)

def run(self, simulation: "Simulation") -> "Simulation":
raise NotImplementedError(
Expand All @@ -69,6 +74,11 @@ def add_variable(self, var: "Variable") -> None:
self.variables.append(var)
self.variables_by_name[var.name] = var

def add_parameter_node(self, node: "ParameterNode") -> None:
"""Add a parameter node and index it for fast lookup."""
self.parameter_nodes.append(node)
self.parameter_nodes_by_name[node.name] = node

def get_parameter(self, name: str) -> "Parameter":
"""Get a parameter by name (O(1) lookup)."""
if name in self.parameters_by_name:
Expand All @@ -85,6 +95,14 @@ def get_variable(self, name: str) -> "Variable":
f"Variable '{name}' not found in {self.model.id} version {self.version}"
)

def get_parameter_node(self, name: str) -> "ParameterNode":
"""Get a parameter node by name (O(1) lookup)."""
if name in self.parameter_nodes_by_name:
return self.parameter_nodes_by_name[name]
raise ValueError(
f"ParameterNode '{name}' not found in {self.model.id} version {self.version}"
)

def get_region(self, code: str) -> "Region | None":
"""Get a region by its code.

Expand All @@ -99,5 +117,5 @@ def get_region(self, code: str) -> "Region | None":
return self.region_registry.get(code)

def __repr__(self) -> str:
# Give the id and version, and the number of variables, parameters, parameter values
return f"<TaxBenefitModelVersion id={self.id} variables={len(self.variables)} parameters={len(self.parameters)} parameter_values={len(self.parameter_values)}>"
# Give the id and version, and the number of variables, parameters, parameter nodes, parameter values
return f"<TaxBenefitModelVersion id={self.id} variables={len(self.variables)} parameters={len(self.parameters)} parameter_nodes={len(self.parameter_nodes)} parameter_values={len(self.parameter_values)}>"
11 changes: 11 additions & 0 deletions src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from policyengine.core import (
Parameter,
ParameterNode,
TaxBenefitModel,
TaxBenefitModelVersion,
Variable,
Expand Down Expand Up @@ -189,6 +190,7 @@ def __init__(self, **kwargs: dict):
self.add_variable(variable)

from policyengine_core.parameters import Parameter as CoreParameter
from policyengine_core.parameters import ParameterNode as CoreParameterNode

scale_lookup = build_scale_lookup(system)

Expand All @@ -207,6 +209,15 @@ def __init__(self, **kwargs: dict):
_core_param=param_node,
)
self.add_parameter(parameter)
elif isinstance(param_node, CoreParameterNode):
node = ParameterNode(
id=self.id + "-" + param_node.name,
name=param_node.name,
label=param_node.metadata.get("label"),
description=param_node.description,
tax_benefit_model_version=self,
)
self.add_parameter_node(node)

def _build_entity_relationships(
self, dataset: PolicyEngineUKDataset
Expand Down
11 changes: 11 additions & 0 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from policyengine.core import (
Parameter,
ParameterNode,
TaxBenefitModel,
TaxBenefitModelVersion,
Variable,
Expand Down Expand Up @@ -166,6 +167,7 @@ def __init__(self, **kwargs: dict):
self.add_variable(variable)

from policyengine_core.parameters import Parameter as CoreParameter
from policyengine_core.parameters import ParameterNode as CoreParameterNode

scale_lookup = build_scale_lookup(system)

Expand All @@ -184,6 +186,15 @@ def __init__(self, **kwargs: dict):
_core_param=param_node,
)
self.add_parameter(parameter)
elif isinstance(param_node, CoreParameterNode):
node = ParameterNode(
id=self.id + "-" + param_node.name,
name=param_node.name,
label=param_node.metadata.get("label"),
description=param_node.description,
tax_benefit_model_version=self,
)
self.add_parameter_node(node)

def _build_entity_relationships(
self, dataset: PolicyEngineUSDataset
Expand Down
Loading