From 48ec41bb9356f5b2ed66538bdf04f6968691503f Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 17:52:55 +0530 Subject: [PATCH 1/6] feat: Add parameter_nodes table for folder/category labels Add ParameterNode model and migration to store folder structure labels for the parameter tree (e.g., "HMRC" instead of "hmrc"). Changes: - Add ParameterNode SQLModel with name, label, description fields - Add Alembic migration for parameter_nodes table - Update seed_models.py to seed nodes from policyengine.py - Update /parameters/children to use node labels from DB Requires: PolicyEngine/policyengine.py#254 --- ..._67608331ee8a_add_parameter_nodes_table.py | 38 ++++++++++ scripts/seed_models.py | 76 ++++++++++++++++++- src/policyengine_api/api/parameters.py | 29 +++++-- src/policyengine_api/models/__init__.py | 4 + src/policyengine_api/models/parameter_node.py | 54 +++++++++++++ .../models/tax_benefit_model_version.py | 4 + 6 files changed, 194 insertions(+), 11 deletions(-) create mode 100644 alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py create mode 100644 src/policyengine_api/models/parameter_node.py diff --git a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py new file mode 100644 index 0000000..52859bb --- /dev/null +++ b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py @@ -0,0 +1,38 @@ +"""Add parameter_nodes table + +Revision ID: 67608331ee8a +Revises: add_modelled_policies +Create Date: 2026-03-10 18:29:54.555074 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '67608331ee8a' +down_revision: Union[str, Sequence[str], None] = 'add_modelled_policies' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table('parameter_nodes', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table('parameter_nodes') diff --git a/scripts/seed_models.py b/scripts/seed_models.py index 9313faa..36fe943 100644 --- a/scripts/seed_models.py +++ b/scripts/seed_models.py @@ -1,7 +1,7 @@ -"""Seed tax-benefit models with variables and parameters. +"""Seed tax-benefit models with variables, parameters, and parameter nodes. This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables, -Parameters, and ParameterValues from policyengine.py. +Parameters, ParameterValues, and ParameterNodes from policyengine.py. Usage: python scripts/seed_models.py # Seed UK and US models @@ -55,10 +55,10 @@ def seed_model( variable_whitelist: set[str] | None = None, parameter_prefixes: set[str] | None = None, ) -> TaxBenefitModelVersion: - """Seed a tax-benefit model with its variables and parameters. + """Seed a tax-benefit model with its variables, parameters, and parameter nodes. Args: - model_version: The policyengine.py model version object + model_version: The policyengine.py model version object (with parameter_nodes) session: Database session skip_state_params: Skip US state-level parameters (gov.states.*) variable_whitelist: If provided, only seed variables whose name is in this set @@ -336,6 +336,74 @@ def seed_model( + (f" (skipped {skipped} invalid)" if skipped else "") ) + # Add parameter nodes (folder/category structure) + # Uses model_version.parameter_nodes exposed by policyengine.py + parameter_nodes = model_version.parameter_nodes + + # Filter by prefix if specified (same as parameters) + if parameter_prefixes is not None: + parameter_nodes = [ + n + for n in parameter_nodes + if any(n.name.startswith(prefix) for prefix in parameter_prefixes) + ] + + # Deduplicate by name + seen_node_names = set() + nodes_to_add = [] + for node in parameter_nodes: + if node.name not in seen_node_names: + nodes_to_add.append(node) + seen_node_names.add(node.name) + + console.print( + f" Found {len(nodes_to_add)} parameter nodes (folder structure)" + ) + + with logfire.span("add_parameter_nodes", count=len(nodes_to_add)): + node_rows = [] + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(nodes_to_add)} parameter nodes", + total=len(nodes_to_add), + ) + for node in nodes_to_add: + node_rows.append( + { + "id": uuid4(), + "name": node.name, + "label": node.label, + "description": node.description or "", + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(node_rows)} parameter nodes...") + bulk_insert( + session, + "parameter_nodes", + [ + "id", + "name", + "label", + "description", + "tax_benefit_model_version_id", + "created_at", + ], + node_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(nodes_to_add)} parameter nodes" + ) + return db_version diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 72b64ef..6a807e7 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -17,6 +17,7 @@ from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, + ParameterNode, ParameterRead, TaxBenefitModel, TaxBenefitModelVersion, @@ -144,14 +145,27 @@ def get_parameter_children( prefix = f"{parent_path}." if parent_path else "" # Fetch all parameters under this path - query = ( + param_query = ( select(Parameter) .join(TaxBenefitModelVersion) .join(TaxBenefitModel) .where(TaxBenefitModel.name == model_name) .where(Parameter.name.startswith(prefix)) ) - descendants = session.exec(query).all() + descendants = session.exec(param_query).all() + + # Fetch all parameter nodes under this path for labels + node_query = ( + select(ParameterNode) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(ParameterNode.name.startswith(prefix)) + ) + nodes = session.exec(node_query).all() + + # Build a map of node path -> label for quick lookup + node_labels: dict[str, str | None] = {node.name: node.label for node in nodes} # Group by direct child path children_map: dict[str, dict] = {} @@ -187,12 +201,13 @@ def get_parameter_children( info = children_map[path] if info["descendant_count"] > 0: # Node: has children below it + # Priority: 1) parameter_nodes label, 2) direct_param label, 3) path segment direct_param = info["direct_param"] - label = ( - direct_param.label - if direct_param and direct_param.label - else path.rsplit(".", 1)[-1] - ) + label = node_labels.get(path) + if not label and direct_param and direct_param.label: + label = direct_param.label + if not label: + label = path.rsplit(".", 1)[-1] children.append( ParameterChild( path=path, diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 838bfeb..e6bd19e 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -54,6 +54,7 @@ AggregateType, ) from .parameter import Parameter, ParameterCreate, ParameterRead +from .parameter_node import ParameterNode, ParameterNodeCreate, ParameterNodeRead from .parameter_value import ( ParameterValue, ParameterValueCreate, @@ -166,6 +167,9 @@ "IntraDecileImpactRead", "Parameter", "ParameterCreate", + "ParameterNode", + "ParameterNodeCreate", + "ParameterNodeRead", "ParameterRead", "ParameterValue", "ParameterValueCreate", diff --git a/src/policyengine_api/models/parameter_node.py b/src/policyengine_api/models/parameter_node.py new file mode 100644 index 0000000..a1a0f8e --- /dev/null +++ b/src/policyengine_api/models/parameter_node.py @@ -0,0 +1,54 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Field, Relationship, SQLModel + +if TYPE_CHECKING: + from .tax_benefit_model_version import TaxBenefitModelVersion + + +class ParameterNodeBase(SQLModel): + """Base parameter node fields. + + Parameter nodes represent folder/category nodes in the parameter hierarchy + (e.g., "gov", "gov.hmrc", "gov.hmrc.income_tax"). They provide structure + and human-readable labels for navigating the parameter tree, but don't + have values themselves. + """ + + name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')") + label: str | None = Field( + default=None, description="Human-readable label (e.g., 'HMRC')" + ) + description: str | None = Field(default=None, description="Node description") + tax_benefit_model_version_id: UUID = Field( + foreign_key="tax_benefit_model_versions.id" + ) + + +class ParameterNode(ParameterNodeBase, table=True): + """Parameter node database model.""" + + __tablename__ = "parameter_nodes" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship( + back_populates="parameter_nodes" + ) + + +class ParameterNodeCreate(ParameterNodeBase): + """Schema for creating parameter nodes.""" + + pass + + +class ParameterNodeRead(ParameterNodeBase): + """Schema for reading parameter nodes.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/tax_benefit_model_version.py b/src/policyengine_api/models/tax_benefit_model_version.py index 7f066ed..4b36c8f 100644 --- a/src/policyengine_api/models/tax_benefit_model_version.py +++ b/src/policyengine_api/models/tax_benefit_model_version.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from .parameter import Parameter + from .parameter_node import ParameterNode from .tax_benefit_model import TaxBenefitModel from .variable import Variable @@ -34,6 +35,9 @@ class TaxBenefitModelVersion(TaxBenefitModelVersionBase, table=True): parameters: list["Parameter"] = Relationship( back_populates="tax_benefit_model_version" ) + parameter_nodes: list["ParameterNode"] = Relationship( + back_populates="tax_benefit_model_version" + ) class TaxBenefitModelVersionCreate(TaxBenefitModelVersionBase): From 1f88c233565e5bf4a30cd2c8def386f136feb482 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 17:57:34 +0530 Subject: [PATCH 2/6] fix: Sort imports in migration file --- .../versions/20260310_67608331ee8a_add_parameter_nodes_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py index 52859bb..a207a30 100644 --- a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py +++ b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py @@ -7,10 +7,10 @@ """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes +from alembic import op # revision identifiers, used by Alembic. revision: str = '67608331ee8a' From fca97cb69e093b905b41f85806f4ce888122f2c6 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 18:07:29 +0530 Subject: [PATCH 3/6] fix: Apply ruff formatting --- ..._67608331ee8a_add_parameter_nodes_table.py | 29 +++++++++++-------- scripts/seed_models.py | 4 +-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py index a207a30..a7d03f1 100644 --- a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py +++ b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py @@ -5,6 +5,7 @@ Create Date: 2026-03-10 18:29:54.555074 """ + from typing import Sequence, Union import sqlalchemy as sa @@ -13,26 +14,30 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = '67608331ee8a' -down_revision: Union[str, Sequence[str], None] = 'add_modelled_policies' +revision: str = "67608331ee8a" +down_revision: Union[str, Sequence[str], None] = "add_modelled_policies" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - op.create_table('parameter_nodes', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "parameter_nodes", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], + ["tax_benefit_model_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: """Downgrade schema.""" - op.drop_table('parameter_nodes') + op.drop_table("parameter_nodes") diff --git a/scripts/seed_models.py b/scripts/seed_models.py index 36fe943..37a7dbf 100644 --- a/scripts/seed_models.py +++ b/scripts/seed_models.py @@ -356,9 +356,7 @@ def seed_model( nodes_to_add.append(node) seen_node_names.add(node.name) - console.print( - f" Found {len(nodes_to_add)} parameter nodes (folder structure)" - ) + console.print(f" Found {len(nodes_to_add)} parameter nodes (folder structure)") with logfire.span("add_parameter_nodes", count=len(nodes_to_add)): node_rows = [] From 4c3840ab9ef0239c4fb2a5f2bc29f80dc39241b4 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 18:30:07 +0530 Subject: [PATCH 4/6] chore: Add changelog fragment for parameter_nodes --- changelog.d/added/parameter-nodes.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/added/parameter-nodes.md diff --git a/changelog.d/added/parameter-nodes.md b/changelog.d/added/parameter-nodes.md new file mode 100644 index 0000000..e495c95 --- /dev/null +++ b/changelog.d/added/parameter-nodes.md @@ -0,0 +1 @@ +Add parameter_nodes table to store folder/category labels for parameter tree navigation From 3cee553c93ef3bc1af6810f24fb21df9e6267b08 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 18:35:31 +0530 Subject: [PATCH 5/6] fix: Correct changelog fragment naming --- changelog.d/{added/parameter-nodes.md => 124.added.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changelog.d/{added/parameter-nodes.md => 124.added.md} (100%) diff --git a/changelog.d/added/parameter-nodes.md b/changelog.d/124.added.md similarity index 100% rename from changelog.d/added/parameter-nodes.md rename to changelog.d/124.added.md From 3e01c82ffeb9d1b41418d11cdc6ecc427dbd3a07 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 11 Mar 2026 19:35:19 +0530 Subject: [PATCH 6/6] fix: Correct migration down_revision to main branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change down_revision from 'add_modelled_policies' (only exists on feat/modelled-policies branch) to '886921687770' (last migration on main). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../versions/20260310_67608331ee8a_add_parameter_nodes_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py index a7d03f1..7e84d31 100644 --- a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py +++ b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py @@ -15,7 +15,7 @@ # revision identifiers, used by Alembic. revision: str = "67608331ee8a" -down_revision: Union[str, Sequence[str], None] = "add_modelled_policies" +down_revision: Union[str, Sequence[str], None] = "886921687770" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None