From 8b365f3bed96ed202818bcb5dd3b9dcbeda3f60b Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Tue, 10 Mar 2026 20:18:46 +0530 Subject: [PATCH] feat: Expose parameter_nodes property on TaxBenefitModelVersion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add ParameterNode class and expose parameter_nodes list from TaxBenefitModelVersion. This allows the API to seed folder/category labels (e.g., "HMRC" instead of "hmrc") for the parameter tree. Changes: - Add src/policyengine/core/parameter_node.py - Update TaxBenefitModelVersion with parameter_nodes list and lookup - Update UK and US models to populate parameter_nodes from CoreParameterNode 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/policyengine/core/__init__.py | 2 ++ src/policyengine/core/parameter_node.py | 29 +++++++++++++++++++ .../core/tax_benefit_model_version.py | 22 ++++++++++++-- .../tax_benefit_models/uk/model.py | 11 +++++++ .../tax_benefit_models/us/model.py | 11 +++++++ 5 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 src/policyengine/core/parameter_node.py diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index fdd250ea..b2e70b3f 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -6,6 +6,7 @@ from .output import Output as Output from .output import OutputCollection as OutputCollection from .parameter import Parameter as Parameter +from .parameter_node import ParameterNode as ParameterNode from .parameter_value import ParameterValue as ParameterValue from .policy import Policy as Policy from .region import Region as Region @@ -23,4 +24,5 @@ TaxBenefitModelVersion.model_rebuild() Variable.model_rebuild() Parameter.model_rebuild() +ParameterNode.model_rebuild() ParameterValue.model_rebuild() diff --git a/src/policyengine/core/parameter_node.py b/src/policyengine/core/parameter_node.py new file mode 100644 index 00000000..9a3e25a0 --- /dev/null +++ b/src/policyengine/core/parameter_node.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING +from uuid import uuid4 + +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from .tax_benefit_model_version import TaxBenefitModelVersion + + +class ParameterNode(BaseModel): + """Represents a folder/category node in the parameter hierarchy. + + Parameter nodes are intermediate nodes in the parameter tree (e.g., "gov", + "gov.hmrc", "gov.hmrc.income_tax"). They provide structure and human-readable + labels for navigating the parameter tree, but don't have values themselves. + + Unlike Parameter objects (which are leaf nodes with actual values), + ParameterNode objects are purely organizational. + """ + + model_config = {"arbitrary_types_allowed": True} + + id: str = Field(default_factory=lambda: str(uuid4())) + name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')") + label: str | None = Field( + default=None, description="Human-readable label (e.g., 'HMRC')" + ) + description: str | None = Field(default=None, description="Node description") + tax_benefit_model_version: "TaxBenefitModelVersion" diff --git a/src/policyengine/core/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py index af37172c..a926e203 100644 --- a/src/policyengine/core/tax_benefit_model_version.py +++ b/src/policyengine/core/tax_benefit_model_version.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from .parameter import Parameter + from .parameter_node import ParameterNode from .parameter_value import ParameterValue from .region import Region, RegionRegistry from .simulation import Simulation @@ -25,6 +26,7 @@ class TaxBenefitModelVersion(BaseModel): variables: list["Variable"] = Field(default_factory=list) parameters: list["Parameter"] = Field(default_factory=list) + parameter_nodes: list["ParameterNode"] = Field(default_factory=list) # Region registry for geographic simulations region_registry: "RegionRegistry | None" = Field( @@ -43,6 +45,9 @@ def parameter_values(self) -> list["ParameterValue"]: parameters_by_name: dict[str, "Parameter"] = Field( default_factory=dict, exclude=True ) + parameter_nodes_by_name: dict[str, "ParameterNode"] = Field( + default_factory=dict, exclude=True + ) def run(self, simulation: "Simulation") -> "Simulation": raise NotImplementedError( @@ -69,6 +74,11 @@ def add_variable(self, var: "Variable") -> None: self.variables.append(var) self.variables_by_name[var.name] = var + def add_parameter_node(self, node: "ParameterNode") -> None: + """Add a parameter node and index it for fast lookup.""" + self.parameter_nodes.append(node) + self.parameter_nodes_by_name[node.name] = node + def get_parameter(self, name: str) -> "Parameter": """Get a parameter by name (O(1) lookup).""" if name in self.parameters_by_name: @@ -85,6 +95,14 @@ def get_variable(self, name: str) -> "Variable": f"Variable '{name}' not found in {self.model.id} version {self.version}" ) + def get_parameter_node(self, name: str) -> "ParameterNode": + """Get a parameter node by name (O(1) lookup).""" + if name in self.parameter_nodes_by_name: + return self.parameter_nodes_by_name[name] + raise ValueError( + f"ParameterNode '{name}' not found in {self.model.id} version {self.version}" + ) + def get_region(self, code: str) -> "Region | None": """Get a region by its code. @@ -99,5 +117,5 @@ def get_region(self, code: str) -> "Region | None": return self.region_registry.get(code) def __repr__(self) -> str: - # Give the id and version, and the number of variables, parameters, parameter values - return f"" + # Give the id and version, and the number of variables, parameters, parameter nodes, parameter values + return f"" diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 231f6c73..4bca0b22 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -10,6 +10,7 @@ from policyengine.core import ( Parameter, + ParameterNode, TaxBenefitModel, TaxBenefitModelVersion, Variable, @@ -189,6 +190,7 @@ def __init__(self, **kwargs: dict): self.add_variable(variable) from policyengine_core.parameters import Parameter as CoreParameter + from policyengine_core.parameters import ParameterNode as CoreParameterNode scale_lookup = build_scale_lookup(system) @@ -207,6 +209,15 @@ def __init__(self, **kwargs: dict): _core_param=param_node, ) self.add_parameter(parameter) + elif isinstance(param_node, CoreParameterNode): + node = ParameterNode( + id=self.id + "-" + param_node.name, + name=param_node.name, + label=param_node.metadata.get("label"), + description=param_node.description, + tax_benefit_model_version=self, + ) + self.add_parameter_node(node) def _build_entity_relationships( self, dataset: PolicyEngineUKDataset diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index f13cdb9b..bc53b30c 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -9,6 +9,7 @@ from policyengine.core import ( Parameter, + ParameterNode, TaxBenefitModel, TaxBenefitModelVersion, Variable, @@ -166,6 +167,7 @@ def __init__(self, **kwargs: dict): self.add_variable(variable) from policyengine_core.parameters import Parameter as CoreParameter + from policyengine_core.parameters import ParameterNode as CoreParameterNode scale_lookup = build_scale_lookup(system) @@ -184,6 +186,15 @@ def __init__(self, **kwargs: dict): _core_param=param_node, ) self.add_parameter(parameter) + elif isinstance(param_node, CoreParameterNode): + node = ParameterNode( + id=self.id + "-" + param_node.name, + name=param_node.name, + label=param_node.metadata.get("label"), + description=param_node.description, + tax_benefit_model_version=self, + ) + self.add_parameter_node(node) def _build_entity_relationships( self, dataset: PolicyEngineUSDataset