diff --git a/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py new file mode 100644 index 0000000..7e84d31 --- /dev/null +++ b/alembic/versions/20260310_67608331ee8a_add_parameter_nodes_table.py @@ -0,0 +1,43 @@ +"""Add parameter_nodes table + +Revision ID: 67608331ee8a +Revises: add_modelled_policies +Create Date: 2026-03-10 18:29:54.555074 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "67608331ee8a" +down_revision: Union[str, Sequence[str], None] = "886921687770" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table( + "parameter_nodes", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], + ["tax_benefit_model_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table("parameter_nodes") diff --git a/changelog.d/124.added.md b/changelog.d/124.added.md new file mode 100644 index 0000000..e495c95 --- /dev/null +++ b/changelog.d/124.added.md @@ -0,0 +1 @@ +Add parameter_nodes table to store folder/category labels for parameter tree navigation diff --git a/scripts/seed_models.py b/scripts/seed_models.py index 9313faa..37a7dbf 100644 --- a/scripts/seed_models.py +++ b/scripts/seed_models.py @@ -1,7 +1,7 @@ -"""Seed tax-benefit models with variables and parameters. +"""Seed tax-benefit models with variables, parameters, and parameter nodes. This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables, -Parameters, and ParameterValues from policyengine.py. +Parameters, ParameterValues, and ParameterNodes from policyengine.py. Usage: python scripts/seed_models.py # Seed UK and US models @@ -55,10 +55,10 @@ def seed_model( variable_whitelist: set[str] | None = None, parameter_prefixes: set[str] | None = None, ) -> TaxBenefitModelVersion: - """Seed a tax-benefit model with its variables and parameters. + """Seed a tax-benefit model with its variables, parameters, and parameter nodes. Args: - model_version: The policyengine.py model version object + model_version: The policyengine.py model version object (with parameter_nodes) session: Database session skip_state_params: Skip US state-level parameters (gov.states.*) variable_whitelist: If provided, only seed variables whose name is in this set @@ -336,6 +336,72 @@ def seed_model( + (f" (skipped {skipped} invalid)" if skipped else "") ) + # Add parameter nodes (folder/category structure) + # Uses model_version.parameter_nodes exposed by policyengine.py + parameter_nodes = model_version.parameter_nodes + + # Filter by prefix if specified (same as parameters) + if parameter_prefixes is not None: + parameter_nodes = [ + n + for n in parameter_nodes + if any(n.name.startswith(prefix) for prefix in parameter_prefixes) + ] + + # Deduplicate by name + seen_node_names = set() + nodes_to_add = [] + for node in parameter_nodes: + if node.name not in seen_node_names: + nodes_to_add.append(node) + seen_node_names.add(node.name) + + console.print(f" Found {len(nodes_to_add)} parameter nodes (folder structure)") + + with logfire.span("add_parameter_nodes", count=len(nodes_to_add)): + node_rows = [] + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(nodes_to_add)} parameter nodes", + total=len(nodes_to_add), + ) + for node in nodes_to_add: + node_rows.append( + { + "id": uuid4(), + "name": node.name, + "label": node.label, + "description": node.description or "", + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(node_rows)} parameter nodes...") + bulk_insert( + session, + "parameter_nodes", + [ + "id", + "name", + "label", + "description", + "tax_benefit_model_version_id", + "created_at", + ], + node_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(nodes_to_add)} parameter nodes" + ) + return db_version diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 72b64ef..6a807e7 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -17,6 +17,7 @@ from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId from policyengine_api.models import ( Parameter, + ParameterNode, ParameterRead, TaxBenefitModel, TaxBenefitModelVersion, @@ -144,14 +145,27 @@ def get_parameter_children( prefix = f"{parent_path}." if parent_path else "" # Fetch all parameters under this path - query = ( + param_query = ( select(Parameter) .join(TaxBenefitModelVersion) .join(TaxBenefitModel) .where(TaxBenefitModel.name == model_name) .where(Parameter.name.startswith(prefix)) ) - descendants = session.exec(query).all() + descendants = session.exec(param_query).all() + + # Fetch all parameter nodes under this path for labels + node_query = ( + select(ParameterNode) + .join(TaxBenefitModelVersion) + .join(TaxBenefitModel) + .where(TaxBenefitModel.name == model_name) + .where(ParameterNode.name.startswith(prefix)) + ) + nodes = session.exec(node_query).all() + + # Build a map of node path -> label for quick lookup + node_labels: dict[str, str | None] = {node.name: node.label for node in nodes} # Group by direct child path children_map: dict[str, dict] = {} @@ -187,12 +201,13 @@ def get_parameter_children( info = children_map[path] if info["descendant_count"] > 0: # Node: has children below it + # Priority: 1) parameter_nodes label, 2) direct_param label, 3) path segment direct_param = info["direct_param"] - label = ( - direct_param.label - if direct_param and direct_param.label - else path.rsplit(".", 1)[-1] - ) + label = node_labels.get(path) + if not label and direct_param and direct_param.label: + label = direct_param.label + if not label: + label = path.rsplit(".", 1)[-1] children.append( ParameterChild( path=path, diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 838bfeb..e6bd19e 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -54,6 +54,7 @@ AggregateType, ) from .parameter import Parameter, ParameterCreate, ParameterRead +from .parameter_node import ParameterNode, ParameterNodeCreate, ParameterNodeRead from .parameter_value import ( ParameterValue, ParameterValueCreate, @@ -166,6 +167,9 @@ "IntraDecileImpactRead", "Parameter", "ParameterCreate", + "ParameterNode", + "ParameterNodeCreate", + "ParameterNodeRead", "ParameterRead", "ParameterValue", "ParameterValueCreate", diff --git a/src/policyengine_api/models/parameter_node.py b/src/policyengine_api/models/parameter_node.py new file mode 100644 index 0000000..a1a0f8e --- /dev/null +++ b/src/policyengine_api/models/parameter_node.py @@ -0,0 +1,54 @@ +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 + +from sqlmodel import Field, Relationship, SQLModel + +if TYPE_CHECKING: + from .tax_benefit_model_version import TaxBenefitModelVersion + + +class ParameterNodeBase(SQLModel): + """Base parameter node fields. + + Parameter nodes represent folder/category nodes in the parameter hierarchy + (e.g., "gov", "gov.hmrc", "gov.hmrc.income_tax"). They provide structure + and human-readable labels for navigating the parameter tree, but don't + have values themselves. + """ + + name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')") + label: str | None = Field( + default=None, description="Human-readable label (e.g., 'HMRC')" + ) + description: str | None = Field(default=None, description="Node description") + tax_benefit_model_version_id: UUID = Field( + foreign_key="tax_benefit_model_versions.id" + ) + + +class ParameterNode(ParameterNodeBase, table=True): + """Parameter node database model.""" + + __tablename__ = "parameter_nodes" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # Relationships + tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship( + back_populates="parameter_nodes" + ) + + +class ParameterNodeCreate(ParameterNodeBase): + """Schema for creating parameter nodes.""" + + pass + + +class ParameterNodeRead(ParameterNodeBase): + """Schema for reading parameter nodes.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/tax_benefit_model_version.py b/src/policyengine_api/models/tax_benefit_model_version.py index 7f066ed..4b36c8f 100644 --- a/src/policyengine_api/models/tax_benefit_model_version.py +++ b/src/policyengine_api/models/tax_benefit_model_version.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from .parameter import Parameter + from .parameter_node import ParameterNode from .tax_benefit_model import TaxBenefitModel from .variable import Variable @@ -34,6 +35,9 @@ class TaxBenefitModelVersion(TaxBenefitModelVersionBase, table=True): parameters: list["Parameter"] = Relationship( back_populates="tax_benefit_model_version" ) + parameter_nodes: list["ParameterNode"] = Relationship( + back_populates="tax_benefit_model_version" + ) class TaxBenefitModelVersionCreate(TaxBenefitModelVersionBase):