Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Add parameter_nodes table

Revision ID: 67608331ee8a
Revises: 886921687770
Create Date: 2026-03-10 18:29:54.555074

"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` names the migration
# this one is applied on top of.
revision: str = "67608331ee8a"
down_revision: Union[str, Sequence[str], None] = "886921687770"
# Neither a branch label nor a depends_on target is set for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``parameter_nodes`` table.

    The table has three string columns (``name`` required, ``label`` and
    ``description`` optional), a required UUID reference to
    ``tax_benefit_model_versions.id``, a UUID primary key ``id`` and a
    required ``created_at`` timestamp.
    """
    auto_string = sqlmodel.sql.sqltypes.AutoString

    # Column order is kept exactly as generated so the resulting DDL is unchanged.
    columns = [
        sa.Column("name", auto_string(), nullable=False),
        sa.Column("label", auto_string(), nullable=True),
        sa.Column("description", auto_string(), nullable=True),
        sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False),
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
    ]

    # Foreign key first, then primary key, matching the original argument order.
    constraints = [
        sa.ForeignKeyConstraint(
            ["tax_benefit_model_version_id"],
            ["tax_benefit_model_versions.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    ]

    op.create_table("parameter_nodes", *columns, *constraints)


def downgrade() -> None:
    """Undo this revision by dropping the ``parameter_nodes`` table."""
    table_name = "parameter_nodes"
    op.drop_table(table_name)
1 change: 1 addition & 0 deletions changelog.d/124.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add parameter_nodes table to store folder/category labels for parameter tree navigation
74 changes: 70 additions & 4 deletions scripts/seed_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Seed tax-benefit models with variables and parameters.
"""Seed tax-benefit models with variables, parameters, and parameter nodes.

This script seeds TaxBenefitModel, TaxBenefitModelVersion, Variables,
Parameters, and ParameterValues from policyengine.py.
Parameters, ParameterValues, and ParameterNodes from policyengine.py.

Usage:
python scripts/seed_models.py # Seed UK and US models
Expand Down Expand Up @@ -55,10 +55,10 @@ def seed_model(
variable_whitelist: set[str] | None = None,
parameter_prefixes: set[str] | None = None,
) -> TaxBenefitModelVersion:
"""Seed a tax-benefit model with its variables and parameters.
"""Seed a tax-benefit model with its variables, parameters, and parameter nodes.

Args:
model_version: The policyengine.py model version object
model_version: The policyengine.py model version object (with parameter_nodes)
session: Database session
skip_state_params: Skip US state-level parameters (gov.states.*)
variable_whitelist: If provided, only seed variables whose name is in this set
Expand Down Expand Up @@ -336,6 +336,72 @@ def seed_model(
+ (f" (skipped {skipped} invalid)" if skipped else "")
)

# Add parameter nodes (folder/category structure)
# Uses model_version.parameter_nodes exposed by policyengine.py
parameter_nodes = model_version.parameter_nodes

# Filter by prefix if specified (same as parameters)
if parameter_prefixes is not None:
parameter_nodes = [
n
for n in parameter_nodes
if any(n.name.startswith(prefix) for prefix in parameter_prefixes)
]

# Deduplicate by name
seen_node_names = set()
nodes_to_add = []
for node in parameter_nodes:
if node.name not in seen_node_names:
nodes_to_add.append(node)
seen_node_names.add(node.name)

console.print(f" Found {len(nodes_to_add)} parameter nodes (folder structure)")

with logfire.span("add_parameter_nodes", count=len(nodes_to_add)):
node_rows = []

with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
task = progress.add_task(
f"Preparing {len(nodes_to_add)} parameter nodes",
total=len(nodes_to_add),
)
for node in nodes_to_add:
node_rows.append(
{
"id": uuid4(),
"name": node.name,
"label": node.label,
"description": node.description or "",
"tax_benefit_model_version_id": db_version.id,
"created_at": datetime.now(timezone.utc),
}
)
progress.advance(task)

console.print(f" Inserting {len(node_rows)} parameter nodes...")
bulk_insert(
session,
"parameter_nodes",
[
"id",
"name",
"label",
"description",
"tax_benefit_model_version_id",
"created_at",
],
node_rows,
)

console.print(
f" [green]✓[/green] Added {len(nodes_to_add)} parameter nodes"
)

return db_version


Expand Down
29 changes: 22 additions & 7 deletions src/policyengine_api/api/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId
from policyengine_api.models import (
Parameter,
ParameterNode,
ParameterRead,
TaxBenefitModel,
TaxBenefitModelVersion,
Expand Down Expand Up @@ -144,14 +145,27 @@ def get_parameter_children(
prefix = f"{parent_path}." if parent_path else ""

# Fetch all parameters under this path
query = (
param_query = (
select(Parameter)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(Parameter.name.startswith(prefix))
)
descendants = session.exec(query).all()
descendants = session.exec(param_query).all()

# Fetch all parameter nodes under this path for labels
node_query = (
select(ParameterNode)
.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == model_name)
.where(ParameterNode.name.startswith(prefix))
)
nodes = session.exec(node_query).all()

# Build a map of node path -> label for quick lookup
node_labels: dict[str, str | None] = {node.name: node.label for node in nodes}

# Group by direct child path
children_map: dict[str, dict] = {}
Expand Down Expand Up @@ -187,12 +201,13 @@ def get_parameter_children(
info = children_map[path]
if info["descendant_count"] > 0:
# Node: has children below it
# Priority: 1) parameter_nodes label, 2) direct_param label, 3) path segment
direct_param = info["direct_param"]
label = (
direct_param.label
if direct_param and direct_param.label
else path.rsplit(".", 1)[-1]
)
label = node_labels.get(path)
if not label and direct_param and direct_param.label:
label = direct_param.label
if not label:
label = path.rsplit(".", 1)[-1]
children.append(
ParameterChild(
path=path,
Expand Down
4 changes: 4 additions & 0 deletions src/policyengine_api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
AggregateType,
)
from .parameter import Parameter, ParameterCreate, ParameterRead
from .parameter_node import ParameterNode, ParameterNodeCreate, ParameterNodeRead
from .parameter_value import (
ParameterValue,
ParameterValueCreate,
Expand Down Expand Up @@ -166,6 +167,9 @@
"IntraDecileImpactRead",
"Parameter",
"ParameterCreate",
"ParameterNode",
"ParameterNodeCreate",
"ParameterNodeRead",
"ParameterRead",
"ParameterValue",
"ParameterValueCreate",
Expand Down
54 changes: 54 additions & 0 deletions src/policyengine_api/models/parameter_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

from sqlmodel import Field, Relationship, SQLModel

if TYPE_CHECKING:
from .tax_benefit_model_version import TaxBenefitModelVersion


class ParameterNodeBase(SQLModel):
    """Base parameter node fields.

    Parameter nodes represent folder/category nodes in the parameter hierarchy
    (e.g., "gov", "gov.hmrc", "gov.hmrc.income_tax"). They provide structure
    and human-readable labels for navigating the parameter tree, but don't
    have values themselves.
    """

    # Required dotted path identifying the node; has no default.
    name: str = Field(description="Full path of the node (e.g., 'gov.hmrc')")
    # Optional display text; defaults to None when no label is supplied.
    label: str | None = Field(
        default=None, description="Human-readable label (e.g., 'HMRC')"
    )
    # Optional free-text description; defaults to None.
    description: str | None = Field(default=None, description="Node description")
    # Required reference to the owning row in tax_benefit_model_versions.
    tax_benefit_model_version_id: UUID = Field(
        foreign_key="tax_benefit_model_versions.id"
    )


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class ParameterNode(ParameterNodeBase, table=True):
    """Parameter node database model."""

    __tablename__ = "parameter_nodes"

    # Primary key; a fresh UUID is generated when none is supplied.
    id: UUID = Field(primary_key=True, default_factory=uuid4)
    # Creation timestamp; defaults to the current UTC time at construction.
    created_at: datetime = Field(default_factory=_utc_now)

    # Relationships
    # Owning model version; `back_populates` names the `parameter_nodes`
    # attribute expected on the TaxBenefitModelVersion side.
    tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship(
        back_populates="parameter_nodes",
    )


class ParameterNodeCreate(ParameterNodeBase):
    """Schema for creating parameter nodes."""

    # Adds no fields of its own: the accepted payload is exactly the set of
    # fields declared on ParameterNodeBase.


class ParameterNodeRead(ParameterNodeBase):
    """Schema for reading parameter nodes."""

    # Both fields are required here (no defaults), on top of the fields
    # inherited from ParameterNodeBase.
    id: UUID
    created_at: datetime
4 changes: 4 additions & 0 deletions src/policyengine_api/models/tax_benefit_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

if TYPE_CHECKING:
from .parameter import Parameter
from .parameter_node import ParameterNode
from .tax_benefit_model import TaxBenefitModel
from .variable import Variable

Expand Down Expand Up @@ -34,6 +35,9 @@ class TaxBenefitModelVersion(TaxBenefitModelVersionBase, table=True):
parameters: list["Parameter"] = Relationship(
back_populates="tax_benefit_model_version"
)
parameter_nodes: list["ParameterNode"] = Relationship(
back_populates="tax_benefit_model_version"
)


class TaxBenefitModelVersionCreate(TaxBenefitModelVersionBase):
Expand Down
Loading