From 392076f3f9e20a16686f309cfa0275431ac9b349 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 9 Mar 2026 17:48:57 +0100 Subject: [PATCH 01/10] feat: Standardize all endpoints on country_id instead of tax_benefit_model_name All API endpoints now accept country_id ('us' or 'uk') instead of various tax_benefit_model_name formats. Adds shared model_resolver service, converts Convention B (raw model name) and Convention C (Literal with underscore-to-hyphen conversion) endpoints, renames Household DB column via Alembic migration, and removes the old _get_model_version() resolver. Co-Authored-By: Claude Opus 4.6 --- ...8049d_rename_tax_benefit_model_name_to_.py | 69 ++++++++++++++ src/policyengine_api/agent_sandbox.py | 14 +-- src/policyengine_api/api/analysis.py | 95 ++++++++----------- src/policyengine_api/api/datasets.py | 13 +-- src/policyengine_api/api/household.py | 37 ++++---- .../api/household_analysis.py | 36 +++---- src/policyengine_api/api/households.py | 10 +- src/policyengine_api/api/parameters.py | 14 +-- src/policyengine_api/api/regions.py | 42 ++++---- src/policyengine_api/api/simulations.py | 24 +++-- src/policyengine_api/api/variables.py | 14 +-- src/policyengine_api/models/household.py | 8 +- src/policyengine_api/models/household_job.py | 2 +- .../services/model_resolver.py | 46 +++++++++ 14 files changed, 267 insertions(+), 157 deletions(-) create mode 100644 alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py create mode 100644 src/policyengine_api/services/model_resolver.py diff --git a/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py new file mode 100644 index 0000000..5923b22 --- /dev/null +++ b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py @@ -0,0 +1,69 @@ +"""rename_tax_benefit_model_name_to_country_id + +Revision ID: 62385cd8049d +Revises: 886921687770 +Create Date: 2026-03-09 16:48:30.899791 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '62385cd8049d' +down_revision: Union[str, Sequence[str], None] = '886921687770' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema: rename tax_benefit_model_name → country_id with data migration.""" + # 1. Add country_id columns (nullable initially) + op.add_column('households', sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column('household_jobs', sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + + # 2. Populate country_id from tax_benefit_model_name + op.execute(""" + UPDATE households SET country_id = CASE + WHEN tax_benefit_model_name LIKE '%_us' OR tax_benefit_model_name LIKE '%-us' THEN 'us' + WHEN tax_benefit_model_name LIKE '%_uk' OR tax_benefit_model_name LIKE '%-uk' THEN 'uk' + ELSE 'us' + END + """) + op.execute(""" + UPDATE household_jobs SET country_id = CASE + WHEN tax_benefit_model_name LIKE '%_us' OR tax_benefit_model_name LIKE '%-us' THEN 'us' + WHEN tax_benefit_model_name LIKE '%_uk' OR tax_benefit_model_name LIKE '%-uk' THEN 'uk' + ELSE 'us' + END + """) + + # 3. Make country_id non-nullable + op.alter_column('households', 'country_id', nullable=False) + op.alter_column('household_jobs', 'country_id', nullable=False) + + # 4. Drop old columns + op.drop_column('households', 'tax_benefit_model_name') + op.drop_column('household_jobs', 'tax_benefit_model_name') + + +def downgrade() -> None: + """Downgrade schema: restore tax_benefit_model_name from country_id.""" + # 1. Re-add tax_benefit_model_name columns (nullable initially) + op.add_column('households', sa.Column('tax_benefit_model_name', sa.VARCHAR(), nullable=True)) + op.add_column('household_jobs', sa.Column('tax_benefit_model_name', sa.VARCHAR(), nullable=True)) + + # 2. Populate from country_id + op.execute("UPDATE households SET tax_benefit_model_name = 'policyengine_' || country_id") + op.execute("UPDATE household_jobs SET tax_benefit_model_name = 'policyengine_' || country_id") + + # 3. Make non-nullable + op.alter_column('households', 'tax_benefit_model_name', nullable=False) + op.alter_column('household_jobs', 'tax_benefit_model_name', nullable=False) + + # 4. Drop country_id columns + op.drop_column('households', 'country_id') + op.drop_column('household_jobs', 'country_id') diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index bcf1ab0..0aa578a 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -53,9 +53,9 @@ def configure_logfire(traceparent: str | None = None): ## CRITICAL: Always filter by country -When searching for parameters or datasets, ALWAYS include tax_benefit_model_name: -- "policyengine-uk" for UK questions -- "policyengine-us" for US questions +When searching for parameters or datasets, ALWAYS include country_id: +- "uk" for UK questions +- "us" for US questions Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results and waste turns finding the right ones. @@ -66,14 +66,14 @@ def configure_logfire(traceparent: str | None = None): - Poll GET /household/calculate/{job_id} until completed 2. **Parameter lookup**: - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk (ALWAYS include country filter) + - GET /parameters/?search=...&country_id=uk (ALWAYS include country filter) - GET /parameter-values/?parameter_id=...¤t=true for the current value 3. **Economic impact analysis** (budget impact, decile impacts): - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk to find parameter_id + - GET /parameters/?search=...&country_id=uk to find parameter_id - POST /policies/ to create reform with parameter_values - - GET /datasets/?tax_benefit_model_name=policyengine-uk to find dataset_id - - POST /analysis/economic-impact with tax_benefit_model_name, policy_id and dataset_id + - GET /datasets/?country_id=uk to find dataset_id + - POST /analysis/economic-impact with country_id, policy_id and dataset_id - GET /analysis/economic-impact/{report_id} for results (includes decile_impacts and program_statistics) ## Response formatting diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index d2e5864..dc7d5b3 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -60,7 +60,12 @@ TaxBenefitModel, TaxBenefitModelVersion, ) +from policyengine_api.config.constants import CountryId from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import ( + resolve_country_model, + resolve_model_name, +) def get_traceparent() -> str | None: @@ -130,21 +135,21 @@ class EconomicImpactRequest(BaseModel): Example with dataset_id: { - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": "uuid-from-datasets-endpoint", "policy_id": "uuid-of-reform-policy" } Example with region: { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "state/ca", "policy_id": "uuid-of-reform-policy" } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) dataset_id: UUID | None = Field( default=None, @@ -215,32 +220,6 @@ class EconomicImpactResponse(BaseModel): intra_wealth_decile: list[IntraDecileImpactRead] | None = None -def _get_model_version( - tax_benefit_model_name: str, session: Session -) -> TaxBenefitModelVersion: - """Get the latest tax benefit model version.""" - model_name = tax_benefit_model_name.replace("_", "-") - - model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) - ).first() - if not model: - raise HTTPException( - status_code=404, detail=f"Tax benefit model {model_name} not found" - ) - - version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - if not version: - raise HTTPException( - status_code=404, detail=f"No version found for model {model_name}" - ) - - return version - def _get_deterministic_simulation_id( simulation_type: SimulationType, @@ -1114,13 +1093,14 @@ def build_dynamic(dynamic_id): def _trigger_economy_comparison( job_id: str, - tax_benefit_model_name: str, + country_id: str, session: Session | None = None, modules: list[str] | None = None, ) -> None: """Trigger economy comparison analysis (local or Modal). Args: + country_id: Country code ('us' or 'uk'). modules: Optional list of module names to run. If None, runs all. """ from policyengine_api.config import settings @@ -1129,7 +1109,7 @@ def _trigger_economy_comparison( if not settings.agent_use_modal and session is not None: # Run locally - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": _run_local_economy_comparison_uk(job_id, session, modules=modules) else: _run_local_economy_comparison_us(job_id, session, modules=modules) @@ -1137,7 +1117,7 @@ def _trigger_economy_comparison( # Use Modal (modules param passed for future selective computation) import modal - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": fn = modal.Function.from_name( "policyengine", "economy_comparison_uk", @@ -1184,7 +1164,7 @@ def _resolve_dataset_and_region( """ if request.region: # Look up region by code - model_name = request.tax_benefit_model_name.replace("_", "-") + model_name = resolve_model_name(request.country_id) region = session.exec( select(Region) .join(TaxBenefitModel) @@ -1195,7 +1175,7 @@ def _resolve_dataset_and_region( if not region: raise HTTPException( status_code=404, - detail=f"Region '{request.region}' not found for model {model_name}", + detail=f"Region '{request.region}' not found for country {request.country_id}", ) # Resolve dataset from join table, filtered by year if provided @@ -1262,7 +1242,7 @@ def economic_impact( ) # Get model version - model_version = _get_model_version(request.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(request.country_id, session) # Get or create simulations using the resolved dataset baseline_sim = _get_or_create_simulation( @@ -1294,7 +1274,7 @@ def economic_impact( ) # Get or create report - label = f"Economic impact: {request.tax_benefit_model_name}" + label = f"Economic impact: {request.country_id}" if request.policy_id: label += f" (policy {request.policy_id})" @@ -1306,7 +1286,7 @@ def economic_impact( if report.status == ReportStatus.PENDING: with logfire.span("trigger_economy_comparison", job_id=str(report.id)): _trigger_economy_comparison( - str(report.id), request.tax_benefit_model_name, session + str(report.id), request.country_id, session ) return _build_response(report, baseline_sim, reform_sim, session, region) @@ -1341,17 +1321,11 @@ def get_economic_impact_status( # POST /analysis/economy-custom — run selected economy modules # --------------------------------------------------------------------------- -_MODEL_TO_COUNTRY = { - "policyengine_uk": "uk", - "policyengine_us": "us", -} - - class EconomyCustomRequest(BaseModel): """Request body for custom economy analysis with selected modules.""" - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) dataset_id: UUID | None = Field( default=None, @@ -1430,16 +1404,14 @@ def economy_custom( See GET /analysis/options for available module names. """ - country = _MODEL_TO_COUNTRY[request.tax_benefit_model_name] - try: - validate_modules(request.modules, country) + validate_modules(request.modules, request.country_id) except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) # Reuse the same request model for dataset/region resolution impact_request = EconomicImpactRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, dataset_id=request.dataset_id, region=request.region, policy_id=request.policy_id, @@ -1461,7 +1433,7 @@ def economy_custom( else None ) - model_version = _get_model_version(request.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(request.country_id, session) baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, @@ -1491,7 +1463,7 @@ def economy_custom( year=dataset.year, ) - label = f"Custom analysis: {request.tax_benefit_model_name}" + label = f"Custom analysis: {request.country_id}" if request.policy_id: label += f" (policy {request.policy_id})" @@ -1503,7 +1475,7 @@ def economy_custom( with logfire.span("trigger_economy_comparison", job_id=str(report.id)): _trigger_economy_comparison( str(report.id), - request.tax_benefit_model_name, + request.country_id, session, modules=request.modules, ) @@ -1599,7 +1571,7 @@ def rerun_report( if not baseline_sim: raise HTTPException(status_code=400, detail="Report has no baseline simulation") - # 3. Derive tax_benefit_model_name from simulation → model version → model + # 3. Derive country_id from simulation → model version → model model_version = session.get( TaxBenefitModelVersion, baseline_sim.tax_benefit_model_version_id ) @@ -1610,7 +1582,16 @@ def rerun_report( if not model: raise HTTPException(status_code=500, detail="Tax-benefit model not found") - tax_benefit_model_name = model.name.replace("-", "_") + # Reverse-lookup: model.name is "policyengine-us" → country_id is "us" + from policyengine_api.config.constants import COUNTRY_MODEL_NAMES + + country_id = next( + (k for k, v in COUNTRY_MODEL_NAMES.items() if v == model.name), None + ) + if not country_id: + raise HTTPException( + status_code=500, detail=f"Unknown model name: {model.name}" + ) # 4. Delete all result records for this report result_tables = [ @@ -1648,10 +1629,10 @@ def rerun_report( if is_economy: with logfire.span("rerun_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison(str(report.id), tax_benefit_model_name, session) + _trigger_economy_comparison(str(report.id), country_id, session) elif is_household: with logfire.span("rerun_household_impact", job_id=str(report.id)): - _trigger_household_impact(str(report.id), tax_benefit_model_name, session) + _trigger_household_impact(str(report.id), country_id, session) else: raise HTTPException( status_code=400, diff --git a/src/policyengine_api/api/datasets.py b/src/policyengine_api/api/datasets.py index 82540b7..f7ac2a8 100644 --- a/src/policyengine_api/api/datasets.py +++ b/src/policyengine_api/api/datasets.py @@ -11,15 +11,17 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import Dataset, DatasetRead, TaxBenefitModel from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/datasets", tags=["datasets"]) @router.get("/", response_model=List[DatasetRead]) def list_datasets( - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, session: Session = Depends(get_session), ): """List available datasets. @@ -28,15 +30,14 @@ def list_datasets( Each dataset represents population microdata for a specific country and year. Args: - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK datasets. - Use "policyengine-us" for US datasets. + country_id: Filter by country ("us" or "uk"). """ query = select(Dataset) - if tax_benefit_model_name: + if country_id: + model_name = resolve_model_name(country_id) query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name + TaxBenefitModel.name == model_name ) datasets = session.exec(query).all() diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 5fda4b4..4fcaab0 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -5,7 +5,7 @@ """ import math -from typing import Any, Literal +from typing import Any from uuid import UUID import logfire @@ -20,6 +20,7 @@ HouseholdJobStatus, Policy, ) +from policyengine_api.config.constants import CountryId from policyengine_api.services.database import get_session @@ -61,7 +62,7 @@ class HouseholdCalculateRequest(BaseModel): Example US request (single household, simple): { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"employment_income": 70000, "age": 40}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -70,7 +71,7 @@ class HouseholdCalculateRequest(BaseModel): Example US request (multiple households): { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"person_id": 0, "person_household_id": 0, "person_tax_unit_id": 0, "age": 40, "employment_income": 70000}, {"person_id": 1, "person_household_id": 1, "person_tax_unit_id": 1, "age": 30, "employment_income": 50000} @@ -88,15 +89,15 @@ class HouseholdCalculateRequest(BaseModel): Example UK request: { - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"employment_income": 50000, "age": 30}], "household": [{}], "year": 2026 } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) people: list[dict[str, Any]] = Field( description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." @@ -173,7 +174,7 @@ class HouseholdImpactRequest(BaseModel): Example: { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"employment_income": 70000, "age": 40}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -182,8 +183,8 @@ class HouseholdImpactRequest(BaseModel): } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) people: list[dict[str, Any]] = Field( description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." @@ -726,7 +727,7 @@ def _trigger_modal_household( if not settings.agent_use_modal and session is not None: # Run locally - if request.tax_benefit_model_name == "policyengine_uk": + if request.country_id == "uk": _run_local_household_uk( job_id=job_id, people=request.people, @@ -755,7 +756,7 @@ def _trigger_modal_household( traceparent = get_traceparent() - if request.tax_benefit_model_name == "policyengine_uk": + if request.country_id == "uk": fn = modal.Function.from_name( "policyengine", "simulate_household_uk", @@ -864,7 +865,7 @@ def calculate_household( """ with logfire.span( "create_household_job", - model=request.tax_benefit_model_name, + model=request.country_id, num_people=len(request.people), year=request.year, has_policy=request.policy_id is not None, @@ -876,7 +877,7 @@ def calculate_household( # Create job record job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -995,7 +996,7 @@ def calculate_household_impact_comparison( """ with logfire.span( "create_household_impact_job", - model=request.tax_benefit_model_name, + model=request.country_id, num_people=len(request.people), year=request.year, has_policy=request.policy_id is not None, @@ -1006,7 +1007,7 @@ def calculate_household_impact_comparison( # Create baseline job (no policy) baseline_job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -1026,7 +1027,7 @@ def calculate_household_impact_comparison( # Create reform job (with policy) reform_job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -1055,7 +1056,7 @@ def calculate_household_impact_comparison( # Trigger Modal functions for both baseline_request = HouseholdCalculateRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, people=request.people, benunit=request.benunit, marital_unit=request.marital_unit, @@ -1068,7 +1069,7 @@ def calculate_household_impact_comparison( dynamic_id=request.dynamic_id, ) reform_request = HouseholdCalculateRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, people=request.people, benunit=request.benunit, marital_unit=request.marital_unit, diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 9c22234..82fce2a 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -32,9 +32,9 @@ SimulationType, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_country_model from .analysis import ( - _get_model_version, _get_or_create_report, _get_or_create_simulation, ) @@ -81,9 +81,9 @@ class CountryConfig: ) -def get_country_config(tax_benefit_model_name: str) -> CountryConfig: - """Get country configuration from model name.""" - if tax_benefit_model_name == "policyengine_uk": +def get_country_config(country_id: str) -> CountryConfig: + """Get country configuration from country_id.""" + if country_id == "uk": return UK_CONFIG return US_CONFIG @@ -136,9 +136,9 @@ def calculate_us_household( ) -def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator: +def get_calculator(country_id: str) -> HouseholdCalculator: """Get the appropriate calculator for a country.""" - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": return calculate_uk_household return calculate_us_household @@ -331,7 +331,7 @@ def run_household_simulation(simulation_id: UUID, session: Session) -> None: mark_simulation_running(simulation, session) try: - calculator = get_calculator(household.tax_benefit_model_name) + calculator = get_calculator(household.country_id) result = calculator(household.household_data, household.year, policy_data) mark_simulation_completed(simulation, result, session) except Exception as e: @@ -428,7 +428,7 @@ def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: session.commit() try: - calculator = get_calculator(household.tax_benefit_model_name) + calculator = get_calculator(household.country_id) result = calculator(household.household_data, household.year, policy_data) simulation.household_result = result @@ -446,9 +446,13 @@ def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: def _trigger_household_impact( - report_id: str, tax_benefit_model_name: str, session: Session | None = None + report_id: str, country_id: str, session: Session | None = None ) -> None: - """Trigger household impact calculation (local or Modal based on settings).""" + """Trigger household impact calculation (local or Modal based on settings). + + Args: + country_id: Country code ('us' or 'uk'). + """ from policyengine_api.config import settings traceparent = get_traceparent() @@ -460,7 +464,7 @@ def _trigger_household_impact( # Use Modal import modal - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": fn = modal.Function.from_name( "policyengine", "household_impact_uk", @@ -603,7 +607,7 @@ def _compute_impact_if_comparison( if not household: return None - config = get_country_config(household.tax_benefit_model_name) + config = get_country_config(household.country_id) return compute_household_impact(baseline_result, reform_result, config) @@ -658,7 +662,7 @@ def household_impact( household = validate_household_exists(request.household_id, session) validate_policy_exists(request.policy_id, session) - model_version = _get_model_version(household.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(household.country_id, session) baseline_sim = _create_baseline_simulation( household, model_version.id, request.dynamic_id, session @@ -671,16 +675,14 @@ def household_impact( report = _get_or_create_report( baseline_sim_id=baseline_sim.id, reform_sim_id=reform_sim.id if reform_sim else None, - label=f"Household impact: {household.tax_benefit_model_name}", + label=f"Household impact: {household.country_id}", report_type=report_type, session=session, ) if report.status == ReportStatus.PENDING: with logfire.span("trigger_household_impact", job_id=str(report.id)): - _trigger_household_impact( - str(report.id), household.tax_benefit_model_name, session - ) + _trigger_household_impact(str(report.id), household.country_id, session) return build_household_response(report, baseline_sim, reform_sim, session) diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py index fdee1f7..ea64f9d 100644 --- a/src/policyengine_api/api/households.py +++ b/src/policyengine_api/api/households.py @@ -45,7 +45,7 @@ def _to_read(record: Household) -> HouseholdRead: data = record.household_data return HouseholdRead( id=record.id, - tax_benefit_model_name=record.tax_benefit_model_name, + country_id=record.country_id, year=record.year, label=record.label, people=data["people"], @@ -69,7 +69,7 @@ def create_household(body: HouseholdCreate, session: Session = Depends(get_sessi or /household/impact to run simulations. """ record = Household( - tax_benefit_model_name=body.tax_benefit_model_name, + country_id=body.country_id, year=body.year, label=body.label, household_data=_pack_household_data(body), @@ -82,15 +82,15 @@ def create_household(body: HouseholdCreate, session: Session = Depends(get_sessi @router.get("/", response_model=list[HouseholdRead]) def list_households( - tax_benefit_model_name: str | None = None, + country_id: str | None = None, limit: int = Query(default=50, le=200), offset: int = Query(default=0, ge=0), session: Session = Depends(get_session), ): """List stored households with optional filtering.""" query = select(Household) - if tax_benefit_model_name is not None: - query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + if country_id is not None: + query = query.where(Household.country_id == country_id) query = query.offset(offset).limit(limit) records = session.exec(query).all() return [_to_read(r) for r in records] diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 6a807e7..d3ee710 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -23,6 +23,7 @@ TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -32,7 +33,7 @@ def list_parameters( skip: int = 0, limit: int = 100, search: str | None = None, - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -42,18 +43,17 @@ def list_parameters( Args: search: Filter by parameter name, label, or description. - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK parameters. - Use "policyengine-us" for US parameters. + country_id: Filter by country ("us" or "uk"). """ query = select(Parameter) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: + # Filter by country + if country_id: + model_name = resolve_model_name(country_id) query = ( query.join(TaxBenefitModelVersion) .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) + .where(TaxBenefitModel.name == model_name) ) if search: diff --git a/src/policyengine_api/api/regions.py b/src/policyengine_api/api/regions.py index 1d0a34e..88e9597 100644 --- a/src/policyengine_api/api/regions.py +++ b/src/policyengine_api/api/regions.py @@ -11,20 +11,22 @@ from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import Region, RegionRead, TaxBenefitModel from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/regions", tags=["regions"]) @router.get("/", response_model=List[RegionRead]) def list_regions( + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), tax_benefit_model_id: UUID | None = Query( None, description="Filter by tax-benefit model ID" ), - tax_benefit_model_name: str | None = Query( - None, description="Filter by tax-benefit model name (e.g., 'policyengine-us')" - ), region_type: str | None = Query( None, description="Filter by region type (e.g., 'state', 'congressional_district')", @@ -37,18 +39,19 @@ def list_regions( Each region represents a geographic area with an associated dataset. Args: - tax_benefit_model_id: Filter by tax-benefit model UUID. - tax_benefit_model_name: Filter by model name (e.g., "policyengine-us"). + country_id: Filter by country ("us" or "uk"). + tax_benefit_model_id: Filter by tax-benefit model UUID (alternative to country_id). region_type: Filter by region type (e.g., "state", "congressional_district"). """ query = select(Region) - if tax_benefit_model_id: - query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) - elif tax_benefit_model_name: + if country_id: + model_name = resolve_model_name(country_id) query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name + TaxBenefitModel.name == model_name ) + elif tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) if region_type: query = query.where(Region.region_type == region_type) @@ -69,12 +72,12 @@ def get_region(region_id: UUID, session: Session = Depends(get_session)): @router.get("/by-code/{region_code:path}", response_model=RegionRead) def get_region_by_code( region_code: str, + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), tax_benefit_model_id: UUID | None = Query( None, - description="Tax-benefit model ID (required if multiple models have same region code)", - ), - tax_benefit_model_name: str | None = Query( - None, description="Tax-benefit model name (e.g., 'policyengine-us')" + description="Tax-benefit model ID (alternative to country_id)", ), session: Session = Depends(get_session), ): @@ -84,17 +87,18 @@ def get_region_by_code( Args: region_code: The region code (e.g., "state/ca", "us"). - tax_benefit_model_id: Filter by tax-benefit model UUID. - tax_benefit_model_name: Filter by model name. + country_id: Filter by country ("us" or "uk"). + tax_benefit_model_id: Filter by tax-benefit model UUID (alternative to country_id). """ query = select(Region).where(Region.code == region_code) - if tax_benefit_model_id: - query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) - elif tax_benefit_model_name: + if country_id: + model_name = resolve_model_name(country_id) query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name + TaxBenefitModel.name == model_name ) + elif tax_benefit_model_id: + query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) region = session.exec(query).first() if not region: diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 8477a68..b3f40d4 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -8,7 +8,7 @@ For baseline-vs-reform comparisons, use the /analysis/ endpoints instead. """ -from typing import Any, List, Literal +from typing import Any, List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query @@ -27,11 +27,15 @@ SimulationType, TaxBenefitModel, ) +from policyengine_api.config.constants import CountryId from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import ( + resolve_country_model, + resolve_model_name, +) from .analysis import ( RegionInfo, - _get_model_version, _get_or_create_simulation, ) @@ -71,8 +75,8 @@ class HouseholdSimulationResponse(BaseModel): class EconomySimulationRequest(BaseModel): """Request body for creating an economy simulation.""" - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) region: str | None = Field( default=None, @@ -122,7 +126,7 @@ class EconomySimulationResponse(BaseModel): def _resolve_economy_dataset( - tax_benefit_model_name: str, + country_id: str, region_code: str | None, dataset_id: UUID | None, session: Session, @@ -135,7 +139,7 @@ def _resolve_economy_dataset( otherwise the latest available year is used. """ if region_code: - model_name = tax_benefit_model_name.replace("_", "-") + model_name = resolve_model_name(country_id) region = session.exec( select(Region) .join(TaxBenefitModel) @@ -145,7 +149,7 @@ def _resolve_economy_dataset( if not region: raise HTTPException( status_code=404, - detail=f"Region '{region_code}' not found for model {model_name}", + detail=f"Region '{region_code}' not found for country {country_id}", ) # Resolve dataset from join table @@ -274,7 +278,7 @@ def create_household_simulation( ) # Get model version - model_version = _get_model_version(household.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(household.country_id, session) # Get or create simulation (deterministic UUID) simulation = _get_or_create_simulation( @@ -328,7 +332,7 @@ def create_economy_simulation( """ # Resolve dataset and region dataset, region = _resolve_economy_dataset( - request.tax_benefit_model_name, + request.country_id, request.region, request.dataset_id, session, @@ -349,7 +353,7 @@ def create_economy_simulation( filter_value = region.filter_value if region and region.requires_filter else None # Get model version - model_version = _get_model_version(request.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(request.country_id, session) # Get or create simulation (deterministic UUID) simulation = _get_or_create_simulation( diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index e592820..efdc32f 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -20,6 +20,7 @@ VariableRead, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/variables", tags=["variables"]) @@ -29,7 +30,7 @@ def list_variables( skip: int = 0, limit: int = 100, search: str | None = None, - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -40,18 +41,17 @@ def list_variables( Args: search: Filter by variable name, label, or description. - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK variables. - Use "policyengine-us" for US variables. + country_id: Filter by country ("us" or "uk"). """ query = select(Variable) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: + # Filter by country + if country_id: + model_name = resolve_model_name(country_id) query = ( query.join(TaxBenefitModelVersion) .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) + .where(TaxBenefitModel.name == model_name) ) if search: diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py index 8a96850..194b574 100644 --- a/src/policyengine_api/models/household.py +++ b/src/policyengine_api/models/household.py @@ -1,17 +1,19 @@ """Stored household definition model.""" from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any from uuid import UUID, uuid4 from sqlalchemy import JSON from sqlmodel import Column, Field, SQLModel +from policyengine_api.config.constants import CountryId + class HouseholdBase(SQLModel): """Base household fields.""" - tax_benefit_model_name: str + country_id: str year: int label: str | None = None household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) @@ -34,7 +36,7 @@ class HouseholdCreate(SQLModel): people as an array, entity groups as optional dicts. """ - tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + country_id: CountryId year: int label: str | None = None people: list[dict[str, Any]] diff --git a/src/policyengine_api/models/household_job.py b/src/policyengine_api/models/household_job.py index ac853a9..8294314 100644 --- a/src/policyengine_api/models/household_job.py +++ b/src/policyengine_api/models/household_job.py @@ -21,7 +21,7 @@ class HouseholdJobStatus(str, Enum): class HouseholdJobBase(SQLModel): """Base household job fields.""" - tax_benefit_model_name: str + country_id: str request_data: dict[str, Any] = Field(sa_column=Column(JSON)) policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") diff --git a/src/policyengine_api/services/model_resolver.py b/src/policyengine_api/services/model_resolver.py new file mode 100644 index 0000000..abeb92b --- /dev/null +++ b/src/policyengine_api/services/model_resolver.py @@ -0,0 +1,46 @@ +"""Shared resolver for country_id → tax-benefit model + latest version.""" + +from fastapi import HTTPException +from sqlmodel import Session, select + +from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.models.tax_benefit_model import TaxBenefitModel +from policyengine_api.models.tax_benefit_model_version import ( + TaxBenefitModelVersion, +) + + +def resolve_model_name(country_id: CountryId) -> str: + """Resolve country_id → DB model name (with hyphens).""" + return COUNTRY_MODEL_NAMES[country_id] + + +def resolve_country_model( + country_id: CountryId, session: Session +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Resolve country_id → (model, latest_version). + + Explicitly selects the most recent version by created_at DESC. + """ + model_name = COUNTRY_MODEL_NAMES[country_id] + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, detail=f"Model not found for country: {country_id}" + ) + + version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not version: + raise HTTPException( + status_code=404, + detail=f"No version found for model: {model_name}", + ) + + return model, version From ace856971376cb1b476dbd15275e80cb3d3a223a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 9 Mar 2026 18:44:34 +0100 Subject: [PATCH 02/10] feat: Default metadata endpoints to latest model version with version pinning All metadata endpoints (parameters, variables, parameter-values) now resolve country_id to the latest model version by default, preventing duplicate results when multiple versions exist. Adds optional tax_benefit_model_version_id param for pinning to a specific version. Supersedes #99 by combining version resolution with the country_id standardization. Closes #98 Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/parameter_values.py | 18 ++++- src/policyengine_api/api/parameters.py | 74 +++++++++---------- src/policyengine_api/api/variables.py | 41 +++++----- .../services/model_resolver.py | 30 ++++++++ 4 files changed, 102 insertions(+), 61 deletions(-) diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 4668ab8..11ccded 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -12,8 +12,10 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, or_, select -from policyengine_api.models import ParameterValue, ParameterValueRead +from policyengine_api.config.constants import CountryId +from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/parameter-values", tags=["parameter-values"]) @@ -23,6 +25,8 @@ def list_parameter_values( parameter_id: UUID | None = None, policy_id: UUID | None = None, current: bool = False, + country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, skip: int = 0, limit: int = 100, session: Session = Depends(get_session), @@ -37,6 +41,10 @@ def list_parameter_values( policy_id: Filter by a specific policy reform. current: If true, only return values that are currently in effect (start_date <= now and (end_date is null or end_date > now)). + country_id: Filter to values belonging to parameters from this country. + Defaults to the latest model version. + tax_benefit_model_version_id: Filter to values from a specific model + version. Takes precedence over country_id. """ query = select(ParameterValue) @@ -46,6 +54,14 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) + version_id = resolve_version_id( + country_id, tax_benefit_model_version_id, session + ) + if version_id: + query = query.join(Parameter).where( + Parameter.tax_benefit_model_version_id == version_id + ) + if current: now = datetime.now(timezone.utc) query = query.where( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index d3ee710..3075719 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -14,16 +14,14 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Parameter, ParameterNode, ParameterRead, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session -from policyengine_api.services.model_resolver import resolve_model_name +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -34,6 +32,7 @@ def list_parameters( limit: int = 100, search: str | None = None, country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -44,20 +43,19 @@ def list_parameters( Args: search: Filter by parameter name, label, or description. country_id: Filter by country ("us" or "uk"). + Defaults to the latest model version. + tax_benefit_model_version_id: Pin to a specific model version. + Takes precedence over country_id. """ query = select(Parameter) - # Filter by country - if country_id: - model_name = resolve_model_name(country_id) - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - ) + version_id = resolve_version_id( + country_id, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Parameter.name.ilike(search_pattern) @@ -77,6 +75,7 @@ class ParameterByNameRequest(BaseModel): names: list[str] country_id: CountryId + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[ParameterRead]) @@ -96,18 +95,15 @@ def get_parameters_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.in_(request.names)) - .order_by(Parameter.name) + version_id = resolve_version_id( + request.country_id, request.tax_benefit_model_version_id, session ) - return session.exec(query).all() + query = select(Parameter).where(Parameter.name.in_(request.names)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Parameter.name)).all() class ParameterChild(BaseModel): @@ -133,6 +129,9 @@ def get_parameter_children( parent_path: str = Query( default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" ), + tax_benefit_model_version_id: UUID | None = Query( + default=None, description="Optional specific model version ID" + ), session: Session = Depends(get_session), ) -> ParameterChildrenResponse: """Get direct children of a parameter path for tree navigation. @@ -141,27 +140,26 @@ def get_parameter_children( parameters (with full metadata). Use this to lazily load the parameter tree one level at a time. """ - model_name = COUNTRY_MODEL_NAMES[country_id] + version_id = resolve_version_id( + country_id, tax_benefit_model_version_id, session + ) + prefix = f"{parent_path}." if parent_path else "" # Fetch all parameters under this path - param_query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.startswith(prefix)) - ) + param_query = select(Parameter).where(Parameter.name.startswith(prefix)) + if version_id: + param_query = param_query.where( + Parameter.tax_benefit_model_version_id == version_id + ) descendants = session.exec(param_query).all() # Fetch all parameter nodes under this path for labels - node_query = ( - select(ParameterNode) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(ParameterNode.name.startswith(prefix)) - ) + node_query = select(ParameterNode).where(ParameterNode.name.startswith(prefix)) + if version_id: + node_query = node_query.where( + ParameterNode.tax_benefit_model_version_id == version_id + ) nodes = session.exec(node_query).all() # Build a map of node path -> label for quick lookup diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index efdc32f..cc5195e 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -12,15 +12,13 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, Variable, VariableRead, ) from policyengine_api.services.database import get_session -from policyengine_api.services.model_resolver import resolve_model_name +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/variables", tags=["variables"]) @@ -31,6 +29,7 @@ def list_variables( limit: int = 100, search: str | None = None, country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -42,20 +41,19 @@ def list_variables( Args: search: Filter by variable name, label, or description. country_id: Filter by country ("us" or "uk"). + Defaults to the latest model version. + tax_benefit_model_version_id: Pin to a specific model version. + Takes precedence over country_id. """ query = select(Variable) - # Filter by country - if country_id: - model_name = resolve_model_name(country_id) - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - ) + version_id = resolve_version_id( + country_id, tax_benefit_model_version_id, session + ) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Variable.name.ilike(search_pattern) @@ -75,6 +73,7 @@ class VariableByNameRequest(BaseModel): names: list[str] country_id: CountryId + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[VariableRead]) @@ -95,17 +94,15 @@ def get_variables_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - query = ( - select(Variable) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Variable.name.in_(request.names)) - .order_by(Variable.name) + version_id = resolve_version_id( + request.country_id, request.tax_benefit_model_version_id, session ) - return session.exec(query).all() + query = select(Variable).where(Variable.name.in_(request.names)) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Variable.name)).all() @router.get("/{variable_id}", response_model=VariableRead) diff --git a/src/policyengine_api/services/model_resolver.py b/src/policyengine_api/services/model_resolver.py index abeb92b..4eed5a4 100644 --- a/src/policyengine_api/services/model_resolver.py +++ b/src/policyengine_api/services/model_resolver.py @@ -1,5 +1,7 @@ """Shared resolver for country_id → tax-benefit model + latest version.""" +from uuid import UUID + from fastapi import HTTPException from sqlmodel import Session, select @@ -44,3 +46,31 @@ def resolve_country_model( ) return model, version + + +def resolve_version_id( + country_id: CountryId | None, + tax_benefit_model_version_id: UUID | None, + session: Session, +) -> UUID | None: + """Resolve the model version ID from country_id or explicit version. + + Priority: + 1. If tax_benefit_model_version_id provided, validate and return it. + 2. If country_id provided, return the latest version's ID. + 3. If neither provided, return None (no filtering). + """ + if tax_benefit_model_version_id: + version = session.get(TaxBenefitModelVersion, tax_benefit_model_version_id) + if not version: + raise HTTPException( + status_code=404, + detail=f"Model version '{tax_benefit_model_version_id}' not found", + ) + return version.id + + if country_id: + _, version = resolve_country_model(country_id, session) + return version.id + + return None From 23a85fb15027e7d70c32bdc2b4e5cbd4ffc7e0d9 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 10 Mar 2026 21:43:37 +0100 Subject: [PATCH 03/10] feat: Dual policy IDs, EXECUTION_DEFERRED status, and deferred execution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace single policy_id with baseline_policy_id + reform_policy_id on all analysis endpoints. Add PolicyIdInput sentinel type (UUID, "current_law", or None) with _resolve_policy_input() converter. Add run=True parameter to defer computation — when run=False, report status is set to EXECUTION_DEFERRED instead of triggering simulation. Includes Alembic migration to add EXECUTION_DEFERRED to reportstatus enum. Co-Authored-By: Claude Opus 4.6 --- ..._add_execution_deferred_to_reportstatus.py | 31 +++++ src/policyengine_api/api/analysis.py | 123 ++++++++++++------ .../api/household_analysis.py | 57 ++++++-- src/policyengine_api/models/report.py | 1 + 4 files changed, 160 insertions(+), 52 deletions(-) create mode 100644 alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py diff --git a/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py new file mode 100644 index 0000000..25440d2 --- /dev/null +++ b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py @@ -0,0 +1,31 @@ +"""add_execution_deferred_to_reportstatus + +Revision ID: f887cb5490bc +Revises: 62385cd8049d +Create Date: 2026-03-10 21:27:32.072364 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = 'f887cb5490bc' +down_revision: Union[str, Sequence[str], None] = '62385cd8049d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add EXECUTION_DEFERRED value to the reportstatus enum.""" + op.execute("ALTER TYPE reportstatus ADD VALUE IF NOT EXISTS 'EXECUTION_DEFERRED'") + + +def downgrade() -> None: + """Downgrade: PostgreSQL does not support removing enum values. + + The 'EXECUTION_DEFERRED' value will remain in the enum type. + To fully remove it, drop and recreate the type (requires migrating data). + """ + pass diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index dc7d5b3..ec212c8 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -16,7 +16,7 @@ """ import math -from typing import Literal +from typing import Literal, Union from uuid import UUID, uuid5 import logfire @@ -68,6 +68,22 @@ ) +# Type for API policy inputs: UUID, "current_law", or None (omitted). +PolicyIdInput = Union[UUID, Literal["current_law"], None] + + +def _resolve_policy_input(value: PolicyIdInput) -> UUID | None: + """Convert API policy sentinel to DB policy_id. + + - UUID → that policy ID + - "current_law" → None (current law has no policy record) + - None → None + """ + if value is None or value == "current_law": + return None + return value + + def get_traceparent() -> str | None: """Get the current W3C traceparent header for distributed tracing.""" carrier: dict[str, str] = {} @@ -133,18 +149,12 @@ def list_analysis_options( class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. - Example with dataset_id: - { - "country_id": "uk", - "dataset_id": "uuid-from-datasets-endpoint", - "policy_id": "uuid-of-reform-policy" - } - Example with region: { "country_id": "us", "region": "state/ca", - "policy_id": "uuid-of-reform-policy" + "baseline_policy_id": "current_law", + "reform_policy_id": "uuid-of-reform-policy" } """ @@ -159,9 +169,13 @@ class EconomicImpactRequest(BaseModel): default=None, description="Region code (e.g., 'state/ca', 'us'). Either dataset_id or region must be provided.", ) - policy_id: UUID | None = Field( - default=None, - description="Reform policy ID to compare against baseline (current law)", + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( + default="current_law", + description="Reform policy. UUID for a specific policy, or 'current_law' for current law.", ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" @@ -170,6 +184,10 @@ class EconomicImpactRequest(BaseModel): default=None, description="Year for the analysis (e.g., 2026). Selects the dataset for that year. Uses latest available if omitted.", ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) @model_validator(mode="after") def check_dataset_or_region(self) -> "EconomicImpactRequest": @@ -1244,11 +1262,15 @@ def economic_impact( # Get model version _model, model_version = resolve_country_model(request.country_id, session) + # Resolve policy sentinels to DB policy IDs + baseline_policy_db = _resolve_policy_input(request.baseline_policy_id) + reform_policy_db = _resolve_policy_input(request.reform_policy_id) + # Get or create simulations using the resolved dataset baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=None, + policy_id=baseline_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1262,7 +1284,7 @@ def economic_impact( reform_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=request.policy_id, + policy_id=reform_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1275,19 +1297,24 @@ def economic_impact( # Get or create report label = f"Economic impact: {request.country_id}" - if request.policy_id: - label += f" (policy {request.policy_id})" + if request.reform_policy_id and request.reform_policy_id != "current_law": + label += f" (policy {request.reform_policy_id})" report = _get_or_create_report( baseline_sim.id, reform_sim.id, label, "economy_comparison", session ) - # Trigger computation if report is pending - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison( - str(report.id), request.country_id, session - ) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( + str(report.id), request.country_id, session + ) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() return _build_response(report, baseline_sim, reform_sim, session, region) @@ -1335,9 +1362,13 @@ class EconomyCustomRequest(BaseModel): default=None, description="Region code (e.g., 'state/ca', 'us').", ) - policy_id: UUID | None = Field( - default=None, - description="Reform policy ID to compare against baseline (current law)", + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( + default="current_law", + description="Reform policy. UUID for a specific policy, or 'current_law' for current law.", ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" @@ -1349,6 +1380,10 @@ class EconomyCustomRequest(BaseModel): modules: list[str] = Field( description="List of module names to compute (see GET /analysis/options)" ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) @model_validator(mode="after") def check_dataset_or_region(self) -> "EconomyCustomRequest": @@ -1414,9 +1449,11 @@ def economy_custom( country_id=request.country_id, dataset_id=request.dataset_id, region=request.region, - policy_id=request.policy_id, + baseline_policy_id=request.baseline_policy_id, + reform_policy_id=request.reform_policy_id, dynamic_id=request.dynamic_id, year=request.year, + run=request.run, ) dataset, region_obj = _resolve_dataset_and_region(impact_request, session) @@ -1433,12 +1470,16 @@ def economy_custom( else None ) + # Resolve policy sentinels to DB policy IDs + baseline_policy_db = _resolve_policy_input(request.baseline_policy_id) + reform_policy_db = _resolve_policy_input(request.reform_policy_id) + _model, model_version = resolve_country_model(request.country_id, session) baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=None, + policy_id=baseline_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1452,7 +1493,7 @@ def economy_custom( reform_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=request.policy_id, + policy_id=reform_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1464,21 +1505,27 @@ def economy_custom( ) label = f"Custom analysis: {request.country_id}" - if request.policy_id: - label += f" (policy {request.policy_id})" + if request.reform_policy_id and request.reform_policy_id != "current_law": + label += f" (policy {request.reform_policy_id})" report = _get_or_create_report( baseline_sim.id, reform_sim.id, label, "economy_comparison", session ) - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison( - str(report.id), - request.country_id, - session, - modules=request.modules, - ) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( + str(report.id), + request.country_id, + session, + modules=request.modules, + ) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() full_response = _build_response( report, baseline_sim, reform_sim, session, region_obj diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 82fce2a..e0854cf 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -35,8 +35,10 @@ from policyengine_api.services.model_resolver import resolve_country_model from .analysis import ( + PolicyIdInput, _get_or_create_report, _get_or_create_simulation, + _resolve_policy_input, ) @@ -512,14 +514,25 @@ class HouseholdImpactRequest(BaseModel): """Request for household impact analysis.""" household_id: UUID = Field(description="ID of the household to analyze") - policy_id: UUID | None = Field( + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( default=None, - description="Reform policy ID. If None, runs single calculation under current law.", + description=( + "Reform policy. UUID for a specific policy, 'current_law' for current law, " + "or omit entirely for a single calculation (no comparison)." + ), ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID", ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) class HouseholdSimulationInfo(BaseModel): @@ -652,26 +665,35 @@ def household_impact( ) -> HouseholdImpactResponse: """Run household impact analysis. - If policy_id is None: single run under current law. - If policy_id is set: comparison (baseline vs reform). + If reform_policy_id is omitted (None): single run under baseline policy. + If reform_policy_id is set (UUID or "current_law"): comparison (baseline vs reform). This is an async operation. The endpoint returns immediately with a report_id and status="pending". Poll GET /analysis/household-impact/{report_id} until status="completed" to get results. """ household = validate_household_exists(request.household_id, session) - validate_policy_exists(request.policy_id, session) + + # Resolve policy sentinels to DB policy IDs + baseline_db_id = _resolve_policy_input(request.baseline_policy_id) + reform_db_id = _resolve_policy_input(request.reform_policy_id) + has_reform = request.reform_policy_id is not None + + # Validate policies exist in DB + validate_policy_exists(baseline_db_id, session) + validate_policy_exists(reform_db_id, session) _model, model_version = resolve_country_model(household.country_id, session) baseline_sim = _create_baseline_simulation( - household, model_version.id, request.dynamic_id, session + household, model_version.id, request.dynamic_id, session, + policy_id=baseline_db_id, ) reform_sim = _create_reform_simulation( - household, model_version.id, request.policy_id, request.dynamic_id, session - ) + household, model_version.id, reform_db_id, request.dynamic_id, session + ) if has_reform else None - report_type = "household_comparison" if request.policy_id else "household_single" + report_type = "household_comparison" if has_reform else "household_single" report = _get_or_create_report( baseline_sim_id=baseline_sim.id, reform_sim_id=reform_sim.id if reform_sim else None, @@ -680,9 +702,15 @@ def household_impact( session=session, ) - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_household_impact", job_id=str(report.id)): - _trigger_household_impact(str(report.id), household.country_id, session) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact(str(report.id), household.country_id, session) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() return build_household_response(report, baseline_sim, reform_sim, session) @@ -724,12 +752,13 @@ def _create_baseline_simulation( model_version_id: UUID, dynamic_id: UUID | None, session: Session, + policy_id: UUID | None = None, ) -> Simulation: - """Create baseline simulation (current law, no policy).""" + """Create baseline simulation.""" return _get_or_create_simulation( simulation_type=SimulationType.HOUSEHOLD, model_version_id=model_version_id, - policy_id=None, + policy_id=policy_id, dynamic_id=dynamic_id, session=session, household_id=household.id, diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index b034dcb..68724e4 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -9,6 +9,7 @@ class ReportStatus(str, Enum): """Report processing status.""" PENDING = "pending" + EXECUTION_DEFERRED = "execution_deferred" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" From e6727dc4585dccdf4fa595470b12fabaceb254f6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 18:35:25 +0100 Subject: [PATCH 04/10] feat: Convert VARCHAR enum columns to native PG enums with values_callable Three enum columns (region_type, report_type, decile_type) were stored as VARCHAR but needed native PG enum types for proper SQLAlchemy deserialization. Adds values_callable to store lowercase values and includes Alembic migration that drops pre-existing uppercase enum types before recreating with lowercase. Also fixes seed_regions.py to pass RegionType enum members instead of raw strings. Co-Authored-By: Claude Opus 4.6 --- ...onvert_varchar_enums_to_native_pg_enums.py | 90 +++++++++++++++++++ scripts/seed_regions.py | 5 +- .../models/intra_decile_impact.py | 6 +- src/policyengine_api/models/region.py | 5 +- src/policyengine_api/models/report.py | 6 +- 5 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py diff --git a/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py new file mode 100644 index 0000000..3307f43 --- /dev/null +++ b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py @@ -0,0 +1,90 @@ +"""convert_varchar_enums_to_native_pg_enums + +Revision ID: dac22a838dda +Revises: f887cb5490bc +Create Date: 2026-03-11 01:37:08.928795 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = 'dac22a838dda' +down_revision: Union[str, Sequence[str], None] = 'f887cb5490bc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Convert VARCHAR enum columns to native PostgreSQL enum types. + + The enum types may already exist with UPPERCASE values (created by + SQLAlchemy's default create_all behavior). Since the columns are still + VARCHAR, the types are unused — drop and recreate with lowercase values + matching the data and the values_callable convention. + """ + # Drop any pre-existing enum types (unused — columns are still VARCHAR) + op.execute("DROP TYPE IF EXISTS regiontype CASCADE") + op.execute("DROP TYPE IF EXISTS reporttype CASCADE") + op.execute("DROP TYPE IF EXISTS deciletype CASCADE") + + # Create PG enum types with lowercase values + op.execute(""" + CREATE TYPE regiontype AS ENUM ( + 'national', 'country', 'state', 'congressional_district', + 'constituency', 'local_authority', 'city', 'place' + ) + """) + op.execute(""" + CREATE TYPE reporttype AS ENUM ( + 'economy_comparison', 'household_comparison', 'household_single' + ) + """) + op.execute("CREATE TYPE deciletype AS ENUM ('income', 'wealth')") + + # Alter columns from VARCHAR to enum. + # LOWER() handles any databases where values were previously uppercased. + op.execute(""" + ALTER TABLE regions + ALTER COLUMN region_type TYPE regiontype + USING LOWER(region_type)::regiontype + """) + op.execute(""" + ALTER TABLE reports + ALTER COLUMN report_type TYPE reporttype + USING LOWER(report_type)::reporttype + """) + # decile_type has a VARCHAR default that must be dropped before type change + op.execute("ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type DROP DEFAULT") + op.execute(""" + ALTER TABLE intra_decile_impacts + ALTER COLUMN decile_type TYPE deciletype + USING LOWER(decile_type)::deciletype + """) + op.execute("ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type SET DEFAULT 'income'::deciletype") + + +def downgrade() -> None: + """Revert native PG enum columns back to VARCHAR.""" + op.execute(""" + ALTER TABLE regions + ALTER COLUMN region_type TYPE VARCHAR + USING region_type::text + """) + op.execute(""" + ALTER TABLE reports + ALTER COLUMN report_type TYPE VARCHAR + USING report_type::text + """) + op.execute(""" + ALTER TABLE intra_decile_impacts + ALTER COLUMN decile_type TYPE VARCHAR + USING decile_type::text + """) + + # Drop the PG enum types + op.execute("DROP TYPE IF EXISTS regiontype") + op.execute("DROP TYPE IF EXISTS reporttype") + op.execute("DROP TYPE IF EXISTS deciletype") diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index 016b8dc..5ab3d37 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -33,6 +33,7 @@ RegionDatasetLink, TaxBenefitModel, ) +from policyengine_api.models.region import RegionType # noqa: E402 def _group_us_datasets( @@ -195,7 +196,7 @@ def seed_us_regions( db_region = Region( code=pe_region.code, label=pe_region.label, - region_type=pe_region.region_type, + region_type=RegionType(pe_region.region_type), requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, @@ -293,7 +294,7 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]: db_region = Region( code=pe_region.code, label=pe_region.label, - region_type=pe_region.region_type, + region_type=RegionType(pe_region.region_type), requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, diff --git a/src/policyengine_api/models/intra_decile_impact.py b/src/policyengine_api/models/intra_decile_impact.py index 8771d55..a8c958a 100644 --- a/src/policyengine_api/models/intra_decile_impact.py +++ b/src/policyengine_api/models/intra_decile_impact.py @@ -19,6 +19,7 @@ from enum import Enum from uuid import UUID, uuid4 +import sqlalchemy as sa from sqlmodel import Field, SQLModel @@ -35,7 +36,10 @@ class IntraDecileImpactBase(SQLModel): baseline_simulation_id: UUID = Field(foreign_key="simulations.id") reform_simulation_id: UUID = Field(foreign_key="simulations.id") report_id: UUID | None = Field(default=None, foreign_key="reports.id") - decile_type: DecileType = Field(default=DecileType.INCOME) + decile_type: DecileType = Field( + default=DecileType.INCOME, + sa_type=sa.Enum(DecileType, values_callable=lambda x: [e.value for e in x]), + ) decile: int = Field(ge=0, le=10) lose_more_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) lose_less_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py index 0458284..a3daffc 100644 --- a/src/policyengine_api/models/region.py +++ b/src/policyengine_api/models/region.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from uuid import UUID, uuid4 +import sqlalchemy as sa from pydantic import model_validator from sqlmodel import Field, Relationship, SQLModel @@ -33,7 +34,9 @@ class RegionBase(SQLModel): code: str # e.g., "state/ca", "constituency/Sheffield Central" label: str # e.g., "California", "Sheffield Central" - region_type: RegionType # e.g., RegionType.STATE, RegionType.CONSTITUENCY + region_type: RegionType = Field( + sa_type=sa.Enum(RegionType, values_callable=lambda x: [e.value for e in x]), + ) requires_filter: bool = False filter_field: str | None = None # e.g., "state_code", "place_fips" filter_value: str | None = None # e.g., "CA", "44000" diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index 68724e4..57b0262 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -2,6 +2,7 @@ from enum import Enum from uuid import UUID, uuid4 +import sqlalchemy as sa from sqlmodel import Column, Field, SQLModel, Text @@ -28,7 +29,10 @@ class ReportBase(SQLModel): label: str description: str | None = None - report_type: ReportType | None = None + report_type: ReportType | None = Field( + default=None, + sa_type=sa.Enum(ReportType, values_callable=lambda x: [e.value for e in x], nullable=True), + ) user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) status: ReportStatus = ReportStatus.PENDING From f1254ec772f02cffe6c535e0053498b348a52e56 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Mar 2026 20:32:28 +0100 Subject: [PATCH 05/10] feat: Auto-start EXECUTION_DEFERRED reports on GET endpoint access GET endpoints for economic-impact and household-impact now detect EXECUTION_DEFERRED status and automatically trigger computation on first access. Adds MODEL_NAME_TO_COUNTRY reverse lookup and _resolve_country_from_simulation helper to derive country_id from the simulation's model version chain. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 20 +++++++++++++++++- .../api/household_analysis.py | 21 +++++++++++++++++++ src/policyengine_api/config/constants.py | 3 +++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index ec212c8..a6ec0ee 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -60,7 +60,7 @@ TaxBenefitModel, TaxBenefitModelVersion, ) -from policyengine_api.config.constants import CountryId +from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY, CountryId from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import ( resolve_country_model, @@ -1338,12 +1338,30 @@ def get_economic_impact_status( if not baseline_sim or not reform_sim: raise HTTPException(status_code=500, detail="Simulation data missing") + # Auto-start deferred reports on first access + if report.status == ReportStatus.EXECUTION_DEFERRED: + country_id = _resolve_country_from_simulation(baseline_sim, session) + with logfire.span("auto_trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison(str(report.id), country_id, session) + session.refresh(report) + region = ( session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None ) return _build_response(report, baseline_sim, reform_sim, session, region) +def _resolve_country_from_simulation(sim: Simulation, session: Session) -> str: + """Derive country_id from a simulation's model version.""" + version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) + if not version: + raise HTTPException(status_code=500, detail="Model version not found") + model = session.get(TaxBenefitModel, version.model_id) + if not model: + raise HTTPException(status_code=500, detail="Tax-benefit model not found") + return MODEL_NAME_TO_COUNTRY[model.name] + + # --------------------------------------------------------------------------- # POST /analysis/economy-custom — run selected economy modules # --------------------------------------------------------------------------- diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index e0854cf..fffd2f6 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -30,7 +30,10 @@ Simulation, SimulationStatus, SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, ) +from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import resolve_country_model @@ -739,9 +742,27 @@ def get_household_impact( if report.reform_simulation_id: reform_sim = session.get(Simulation, report.reform_simulation_id) + # Auto-start deferred reports on first access + if report.status == ReportStatus.EXECUTION_DEFERRED: + country_id = _resolve_country_from_simulation(baseline_sim, session) + with logfire.span("auto_trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact(str(report.id), country_id, session) + session.refresh(report) + return build_household_response(report, baseline_sim, reform_sim, session) +def _resolve_country_from_simulation(sim: Simulation, session: Session) -> str: + """Derive country_id from a simulation's model version.""" + version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) + if not version: + raise HTTPException(status_code=500, detail="Model version not found") + model = session.get(TaxBenefitModel, version.model_id) + if not model: + raise HTTPException(status_code=500, detail="Tax-benefit model not found") + return MODEL_NAME_TO_COUNTRY[model.name] + + # ============================================================================= # Simulation Creation Helpers # ============================================================================= diff --git a/src/policyengine_api/config/constants.py b/src/policyengine_api/config/constants.py index a39d827..61673e1 100644 --- a/src/policyengine_api/config/constants.py +++ b/src/policyengine_api/config/constants.py @@ -10,3 +10,6 @@ "uk": "policyengine-uk", "us": "policyengine-us", } + +# Reverse mapping: model name → country ID +MODEL_NAME_TO_COUNTRY: dict[str, str] = {v: k for k, v in COUNTRY_MODEL_NAMES.items()} From 6592db39ce2f7d8daf74f5e925aeaf05c05e868f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 13 Mar 2026 15:20:23 +0100 Subject: [PATCH 06/10] chore: Fix import sorting and apply ruff formatting Co-Authored-By: Claude Opus 4.6 --- ...8049d_rename_tax_benefit_model_name_to_.py | 50 ++++++++++++------- ..._add_execution_deferred_to_reportstatus.py | 6 +-- ...onvert_varchar_enums_to_native_pg_enums.py | 10 ++-- src/policyengine_api/api/analysis.py | 13 ++--- src/policyengine_api/api/datasets.py | 4 +- src/policyengine_api/api/household.py | 2 +- .../api/household_analysis.py | 17 +++++-- src/policyengine_api/api/parameter_values.py | 4 +- src/policyengine_api/api/parameters.py | 8 +-- src/policyengine_api/api/regions.py | 8 +-- src/policyengine_api/api/simulations.py | 2 +- src/policyengine_api/api/variables.py | 4 +- src/policyengine_api/models/report.py | 4 +- 13 files changed, 70 insertions(+), 62 deletions(-) diff --git a/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py index 5923b22..3d4f09e 100644 --- a/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py +++ b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py @@ -5,16 +5,17 @@ Create Date: 2026-03-09 16:48:30.899791 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes +from alembic import op # revision identifiers, used by Alembic. -revision: str = '62385cd8049d' -down_revision: Union[str, Sequence[str], None] = '886921687770' +revision: str = "62385cd8049d" +down_revision: Union[str, Sequence[str], None] = "886921687770" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,8 +23,14 @@ def upgrade() -> None: """Upgrade schema: rename tax_benefit_model_name → country_id with data migration.""" # 1. Add country_id columns (nullable initially) - op.add_column('households', sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) - op.add_column('household_jobs', sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column( + "households", + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "household_jobs", + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) # 2. Populate country_id from tax_benefit_model_name op.execute(""" @@ -42,28 +49,37 @@ def upgrade() -> None: """) # 3. Make country_id non-nullable - op.alter_column('households', 'country_id', nullable=False) - op.alter_column('household_jobs', 'country_id', nullable=False) + op.alter_column("households", "country_id", nullable=False) + op.alter_column("household_jobs", "country_id", nullable=False) # 4. Drop old columns - op.drop_column('households', 'tax_benefit_model_name') - op.drop_column('household_jobs', 'tax_benefit_model_name') + op.drop_column("households", "tax_benefit_model_name") + op.drop_column("household_jobs", "tax_benefit_model_name") def downgrade() -> None: """Downgrade schema: restore tax_benefit_model_name from country_id.""" # 1. Re-add tax_benefit_model_name columns (nullable initially) - op.add_column('households', sa.Column('tax_benefit_model_name', sa.VARCHAR(), nullable=True)) - op.add_column('household_jobs', sa.Column('tax_benefit_model_name', sa.VARCHAR(), nullable=True)) + op.add_column( + "households", sa.Column("tax_benefit_model_name", sa.VARCHAR(), nullable=True) + ) + op.add_column( + "household_jobs", + sa.Column("tax_benefit_model_name", sa.VARCHAR(), nullable=True), + ) # 2. Populate from country_id - op.execute("UPDATE households SET tax_benefit_model_name = 'policyengine_' || country_id") - op.execute("UPDATE household_jobs SET tax_benefit_model_name = 'policyengine_' || country_id") + op.execute( + "UPDATE households SET tax_benefit_model_name = 'policyengine_' || country_id" + ) + op.execute( + "UPDATE household_jobs SET tax_benefit_model_name = 'policyengine_' || country_id" + ) # 3. Make non-nullable - op.alter_column('households', 'tax_benefit_model_name', nullable=False) - op.alter_column('household_jobs', 'tax_benefit_model_name', nullable=False) + op.alter_column("households", "tax_benefit_model_name", nullable=False) + op.alter_column("household_jobs", "tax_benefit_model_name", nullable=False) # 4. Drop country_id columns - op.drop_column('households', 'country_id') - op.drop_column('household_jobs', 'country_id') + op.drop_column("households", "country_id") + op.drop_column("household_jobs", "country_id") diff --git a/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py index 25440d2..45c6b2f 100644 --- a/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py +++ b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py @@ -5,14 +5,14 @@ Create Date: 2026-03-10 21:27:32.072364 """ + from typing import Sequence, Union from alembic import op - # revision identifiers, used by Alembic. -revision: str = 'f887cb5490bc' -down_revision: Union[str, Sequence[str], None] = '62385cd8049d' +revision: str = "f887cb5490bc" +down_revision: Union[str, Sequence[str], None] = "62385cd8049d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py index 3307f43..6a00cbf 100644 --- a/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py +++ b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py @@ -5,14 +5,14 @@ Create Date: 2026-03-11 01:37:08.928795 """ + from typing import Sequence, Union from alembic import op - # revision identifiers, used by Alembic. -revision: str = 'dac22a838dda' -down_revision: Union[str, Sequence[str], None] = 'f887cb5490bc' +revision: str = "dac22a838dda" +down_revision: Union[str, Sequence[str], None] = "f887cb5490bc" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -63,7 +63,9 @@ def upgrade() -> None: ALTER COLUMN decile_type TYPE deciletype USING LOWER(decile_type)::deciletype """) - op.execute("ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type SET DEFAULT 'income'::deciletype") + op.execute( + "ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type SET DEFAULT 'income'::deciletype" + ) def downgrade() -> None: diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index a6ec0ee..dc441d8 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -30,6 +30,7 @@ get_modules_for_country, validate_modules, ) +from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY, CountryId from policyengine_api.models import ( BudgetSummary, BudgetSummaryRead, @@ -60,14 +61,12 @@ TaxBenefitModel, TaxBenefitModelVersion, ) -from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY, CountryId from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import ( resolve_country_model, resolve_model_name, ) - # Type for API policy inputs: UUID, "current_law", or None (omitted). PolicyIdInput = Union[UUID, Literal["current_law"], None] @@ -238,7 +237,6 @@ class EconomicImpactResponse(BaseModel): intra_wealth_decile: list[IntraDecileImpactRead] | None = None - def _get_deterministic_simulation_id( simulation_type: SimulationType, model_version_id: UUID, @@ -1308,9 +1306,7 @@ def economic_impact( if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): if request.run: with logfire.span("trigger_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison( - str(report.id), request.country_id, session - ) + _trigger_economy_comparison(str(report.id), request.country_id, session) elif report.status == ReportStatus.PENDING: report.status = ReportStatus.EXECUTION_DEFERRED session.add(report) @@ -1366,6 +1362,7 @@ def _resolve_country_from_simulation(sim: Simulation, session: Session) -> str: # POST /analysis/economy-custom — run selected economy modules # --------------------------------------------------------------------------- + class EconomyCustomRequest(BaseModel): """Request body for custom economy analysis with selected modules.""" @@ -1654,9 +1651,7 @@ def rerun_report( (k for k, v in COUNTRY_MODEL_NAMES.items() if v == model.name), None ) if not country_id: - raise HTTPException( - status_code=500, detail=f"Unknown model name: {model.name}" - ) + raise HTTPException(status_code=500, detail=f"Unknown model name: {model.name}") # 4. Delete all result records for this report result_tables = [ diff --git a/src/policyengine_api/api/datasets.py b/src/policyengine_api/api/datasets.py index f7ac2a8..e01dc29 100644 --- a/src/policyengine_api/api/datasets.py +++ b/src/policyengine_api/api/datasets.py @@ -36,9 +36,7 @@ def list_datasets( if country_id: model_name = resolve_model_name(country_id) - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == model_name - ) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) datasets = session.exec(query).all() return datasets diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 4fcaab0..3abdc10 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -14,13 +14,13 @@ from pydantic import BaseModel, Field from sqlmodel import Session +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Dynamic, HouseholdJob, HouseholdJobStatus, Policy, ) -from policyengine_api.config.constants import CountryId from policyengine_api.services.database import get_session diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index fffd2f6..fc77163 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -22,6 +22,7 @@ from pydantic import BaseModel, Field from sqlmodel import Session +from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY from policyengine_api.models import ( Household, Policy, @@ -33,7 +34,6 @@ TaxBenefitModel, TaxBenefitModelVersion, ) -from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import resolve_country_model @@ -689,12 +689,19 @@ def household_impact( _model, model_version = resolve_country_model(household.country_id, session) baseline_sim = _create_baseline_simulation( - household, model_version.id, request.dynamic_id, session, + household, + model_version.id, + request.dynamic_id, + session, policy_id=baseline_db_id, ) - reform_sim = _create_reform_simulation( - household, model_version.id, reform_db_id, request.dynamic_id, session - ) if has_reform else None + reform_sim = ( + _create_reform_simulation( + household, model_version.id, reform_db_id, request.dynamic_id, session + ) + if has_reform + else None + ) report_type = "household_comparison" if has_reform else "household_single" report = _get_or_create_report( diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 11ccded..3311cbf 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -54,9 +54,7 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) - version_id = resolve_version_id( - country_id, tax_benefit_model_version_id, session - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) if version_id: query = query.join(Parameter).where( Parameter.tax_benefit_model_version_id == version_id diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 3075719..93d052e 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -49,9 +49,7 @@ def list_parameters( """ query = select(Parameter) - version_id = resolve_version_id( - country_id, tax_benefit_model_version_id, session - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) if version_id: query = query.where(Parameter.tax_benefit_model_version_id == version_id) @@ -140,9 +138,7 @@ def get_parameter_children( parameters (with full metadata). Use this to lazily load the parameter tree one level at a time. """ - version_id = resolve_version_id( - country_id, tax_benefit_model_version_id, session - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) prefix = f"{parent_path}." if parent_path else "" diff --git a/src/policyengine_api/api/regions.py b/src/policyengine_api/api/regions.py index 88e9597..b5a5c6e 100644 --- a/src/policyengine_api/api/regions.py +++ b/src/policyengine_api/api/regions.py @@ -47,9 +47,7 @@ def list_regions( if country_id: model_name = resolve_model_name(country_id) - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == model_name - ) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) elif tax_benefit_model_id: query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) @@ -94,9 +92,7 @@ def get_region_by_code( if country_id: model_name = resolve_model_name(country_id) - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == model_name - ) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) elif tax_benefit_model_id: query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index b3f40d4..ba0196d 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field, model_validator from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Dataset, Household, @@ -27,7 +28,6 @@ SimulationType, TaxBenefitModel, ) -from policyengine_api.config.constants import CountryId from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import ( resolve_country_model, diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index cc5195e..162e179 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -47,9 +47,7 @@ def list_variables( """ query = select(Variable) - version_id = resolve_version_id( - country_id, tax_benefit_model_version_id, session - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) if version_id: query = query.where(Variable.tax_benefit_model_version_id == version_id) diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index 57b0262..0ec572b 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -31,7 +31,9 @@ class ReportBase(SQLModel): description: str | None = None report_type: ReportType | None = Field( default=None, - sa_type=sa.Enum(ReportType, values_callable=lambda x: [e.value for e in x], nullable=True), + sa_type=sa.Enum( + ReportType, values_callable=lambda x: [e.value for e in x], nullable=True + ), ) user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) From 858e45bc4bf94293553f2284e1bec77daab51b55 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 13 Mar 2026 15:23:57 +0100 Subject: [PATCH 07/10] docs: Add changelog for v0.4.0 Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0586f3e..5a97ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +0.4.0 (2026-03-13) + +# Added + +- Standardize all endpoints on `country_id` instead of `tax_benefit_model_name` (#109) +- Default metadata endpoints (variables, parameters, datasets) to latest model version with optional version pinning (#109) +- Dual policy IDs (`baseline_policy_id` / `reform_policy_id`) on reports and `EXECUTION_DEFERRED` report status (#109) +- Auto-start `EXECUTION_DEFERRED` reports on GET endpoint access (#109) +- Convert VARCHAR enum columns to native PostgreSQL enums with `values_callable` (#109) + + 0.3.1 (2026-03-11) # Fixed From 404c07bca2b132efc9f518519cc38047b8046b6a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 13 Mar 2026 20:23:50 +0100 Subject: [PATCH 08/10] fix: Migrate test fixtures and tests from tax_benefit_model_name to country_id Update all test files to match the standardized country_id API: - Replace tax_benefit_model_name with country_id in request payloads - Change "policyengine_us"/"policyengine_uk" to "us"/"uk" - Update policy_id to reform_policy_id in household impact tests - Fix get_country_config/get_calculator calls to use short country IDs - Fix model name in conftest fixture (underscore to hyphen) Co-Authored-By: Claude Opus 4.6 --- test_fixtures/fixtures_household_analysis.py | 6 +-- test_fixtures/fixtures_households.py | 10 ++-- .../fixtures_simulations_standalone.py | 4 +- .../fixtures_user_household_associations.py | 4 +- tests/conftest.py | 2 +- tests/test_analysis.py | 46 +++++++++---------- tests/test_analysis_household_impact.py | 22 ++++----- tests/test_economy_custom.py | 14 +++--- tests/test_household.py | 28 +++++------ tests/test_household_impact.py | 12 ++--- tests/test_households.py | 26 +++++------ tests/test_models.py | 4 +- tests/test_simulations_standalone.py | 14 +++--- tests/test_variable_labels.py | 10 ++-- 14 files changed, 99 insertions(+), 103 deletions(-) diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py index 85401af..903f0c2 100644 --- a/test_fixtures/fixtures_household_analysis.py +++ b/test_fixtures/fixtures_household_analysis.py @@ -310,12 +310,12 @@ def create_policy_with_parameter_value( def create_household_for_analysis( session: Session, - tax_benefit_model_name: str = "policyengine_uk", + country_id: str = "uk", year: int = 2024, label: str = "Test household for analysis", ) -> Household: """Create a household suitable for analysis testing.""" - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": household_data = { "people": [{"age": 30, "employment_income": 35000}], "benunit": {}, @@ -332,7 +332,7 @@ def create_household_for_analysis( } record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data=household_data, diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py index 4e676f4..c5d4512 100644 --- a/test_fixtures/fixtures_households.py +++ b/test_fixtures/fixtures_households.py @@ -7,7 +7,7 @@ # ----------------------------------------------------------------------------- MOCK_US_HOUSEHOLD_CREATE = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "year": 2024, "label": "US test household", "people": [ @@ -20,7 +20,7 @@ } MOCK_UK_HOUSEHOLD_CREATE = { - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "year": 2024, "label": "UK test household", "people": [ @@ -31,7 +31,7 @@ } MOCK_HOUSEHOLD_MINIMAL = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "year": 2024, "people": [{"age": 25}], } @@ -44,7 +44,7 @@ def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str | None = "Test household", people: list | None = None, @@ -55,7 +55,7 @@ def create_household( household_data.update(entity_groups) record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data=household_data, diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py index 2b7ce39..5eb85cd 100644 --- a/test_fixtures/fixtures_simulations_standalone.py +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -54,13 +54,13 @@ def create_uk_model_and_version( def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str = "Test household", ) -> Household: """Create and persist a Household record.""" household = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data={ diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py index 66b0835..a0dc9ee 100644 --- a/test_fixtures/fixtures_user_household_associations.py +++ b/test_fixtures/fixtures_user_household_associations.py @@ -25,13 +25,13 @@ def create_user( def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str | None = "Test household", ) -> Household: """Create and persist a Household record.""" record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data={"people": [{"age": 30}]}, diff --git a/tests/conftest.py b/tests/conftest.py index 1c52936..e2c9c36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,7 +73,7 @@ def uk_tax_benefit_model_fixture(session: Session): def simulation_fixture(session: Session): """Create a test simulation with required dependencies.""" # Create model - model = TaxBenefitModel(name="policyengine_uk", description="UK model") + model = TaxBenefitModel(name="policyengine-uk", description="UK model") session.add(model) session.commit() session.refresh(model) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index abd7489..5fa6674 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -53,7 +53,7 @@ def test__given_dataset_id__then_region_is_none(self, session: Session): model = create_tax_benefit_model(session, name="policyengine-uk") dataset = create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset.id, ) @@ -67,7 +67,7 @@ def test__given_dataset_id__then_dataset_is_returned(self, session: Session): model = create_tax_benefit_model(session, name="policyengine-uk") dataset = create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset.id, ) @@ -94,7 +94,7 @@ def test__given_dataset_id_and_region__then_region_takes_precedence( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset1.id, region="uk", ) @@ -126,7 +126,7 @@ def test__given_region_requires_filter__then_returns_filter_fields( filter_value="ENGLAND", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -154,7 +154,7 @@ def test__given_us_state_region__then_returns_state_filter(self, session: Sessio filter_value="CA", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="state/ca", ) @@ -183,7 +183,7 @@ def test__given_region_with_filter__then_dataset_is_resolved( filter_value="ENGLAND", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -209,7 +209,7 @@ def test__given_national_uk_region__then_filter_params_none(self, session: Sessi requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -235,7 +235,7 @@ def test__given_national_us_region__then_filter_params_none(self, session: Sessi requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="us", ) @@ -263,7 +263,7 @@ def test__given_national_region__then_dataset_still_resolved( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -280,7 +280,7 @@ def test__given_nonexistent_region_code__then_raises_404(self, session: Session) model = create_tax_benefit_model(session, name="policyengine-uk") create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="nonexistent/region", ) @@ -302,7 +302,7 @@ def test__given_region_for_wrong_model__then_raises_404(self, session: Session): region_type="national", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="uk", ) @@ -318,13 +318,13 @@ def test__given_neither_dataset_nor_region__then_raises_validation_error( with pytest.raises(ValidationError, match="dataset_id or region"): EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", ) def test__given_nonexistent_dataset_id__then_raises_404(self, session: Session): nonexistent_id = uuid4() request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=nonexistent_id, ) @@ -943,7 +943,7 @@ def test__given_region_with_row_filter_strategy__then_region_has_filter_strategy filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -972,7 +972,7 @@ def test__given_constituency_region__then_region_has_weight_replacement_strategy filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="constituency/sheffield-central", ) @@ -1001,7 +1001,7 @@ def test__given_national_region__then_filter_strategy_is_none( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -1021,11 +1021,11 @@ def test__given_national_region__then_filter_strategy_is_none( class TestEconomicImpactValidation: """Tests for request validation (no database required).""" - def test_invalid_model_name(self): + def test_invalid_country_id(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "dataset_id": "00000000-0000-0000-0000-000000000000", }, ) @@ -1035,7 +1035,7 @@ def test_missing_dataset_id(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 @@ -1044,7 +1044,7 @@ def test_invalid_uuid(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": "not-a-uuid", }, ) @@ -1059,7 +1059,7 @@ def test_dataset_not_found(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": "00000000-0000-0000-0000-000000000000", }, ) @@ -1102,7 +1102,7 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": str(uk_dataset_id), }, ) @@ -1126,7 +1126,7 @@ def test_simulations_created(self, uk_dataset_id, session: Session): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": str(uk_dataset_id), }, ) diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py index 802562d..0e6f485 100644 --- a/tests/test_analysis_household_impact.py +++ b/tests/test_analysis_household_impact.py @@ -201,13 +201,13 @@ class TestGetCountryConfig: """Tests for get_country_config helper.""" def test_uk_model_returns_uk_config(self): - config = get_country_config("policyengine_uk") + config = get_country_config("uk") assert config == UK_CONFIG assert config.name == "uk" assert "benunit" in config.entity_types def test_us_model_returns_us_config(self): - config = get_country_config("policyengine_us") + config = get_country_config("us") assert config == US_CONFIG assert config.name == "us" assert "tax_unit" in config.entity_types @@ -223,13 +223,13 @@ class TestGetCalculator: def test_uk_model_returns_uk_calculator(self): from policyengine_api.api.household_analysis import calculate_uk_household - calc = get_calculator("policyengine_uk") + calc = get_calculator("uk") assert calc == calculate_uk_household def test_us_model_returns_us_calculator(self): from policyengine_api.api.household_analysis import calculate_us_household - calc = get_calculator("policyengine_us") + calc = get_calculator("us") assert calc == calculate_us_household def test_unknown_model_defaults_to_us(self): @@ -297,7 +297,7 @@ def test_policy_not_found(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(uuid4()), + "reform_policy_id": str(uuid4()), }, ) assert response.status_code == 404 @@ -346,7 +346,7 @@ def test_comparison_creates_two_simulations(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy.id), + "reform_policy_id": str(policy.id), }, ) data = response.json() @@ -386,7 +386,7 @@ def test_report_links_simulations(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy.id), + "reform_policy_id": str(policy.id), }, ) data = response.json() @@ -442,7 +442,7 @@ def test_different_policy_creates_different_simulation(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy1.id), + "reform_policy_id": str(policy1.id), }, ) data1 = response1.json() @@ -452,7 +452,7 @@ def test_different_policy_creates_different_simulation(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy2.id), + "reform_policy_id": str(policy2.id), }, ) data2 = response2.json() @@ -506,9 +506,7 @@ class TestUSHouseholdImpact: def test_us_household_creates_simulation(self, client, session): """US household creates simulation with correct model.""" _, version = setup_us_model_and_version(session) - household = create_household_for_analysis( - session, tax_benefit_model_name="policyengine_us" - ) + household = create_household_for_analysis(session, country_id="us") response = client.post( "/analysis/household-impact", diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py index fda65fa..4be2968 100644 --- a/tests/test_economy_custom.py +++ b/tests/test_economy_custom.py @@ -218,7 +218,7 @@ def test_unknown_module_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["nonexistent_module"], }, @@ -230,7 +230,7 @@ def test_wrong_country_module_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["constituency"], }, @@ -242,7 +242,7 @@ def test_multiple_errors_in_module_validation(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["nonexistent", "constituency"], }, @@ -256,7 +256,7 @@ def test_empty_modules_list_passes_validation(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": [], }, @@ -272,7 +272,7 @@ def test_valid_modules_but_missing_region_returns_404(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["decile", "poverty"], }, @@ -284,7 +284,7 @@ def test_missing_modules_field_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", }, ) @@ -294,7 +294,7 @@ def test_invalid_model_name_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "region": "us", "modules": ["decile"], }, diff --git a/tests/test_household.py b/tests/test_household.py index 0cf0288..ddefc15 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -34,7 +34,7 @@ def test_single_adult(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, }, @@ -55,7 +55,7 @@ def test_couple_with_children(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [ {"age": 35, "employment_income": 50000}, {"age": 33, "employment_income": 25000}, @@ -75,7 +75,7 @@ def test_with_household_data(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 40, "employment_income": 45000}], "household": [ { @@ -96,7 +96,7 @@ def test_output_contains_tax_variables(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 50000}], "year": 2026, }, @@ -118,7 +118,7 @@ def test_single_adult(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 30, "employment_income": 60000}], "year": 2024, }, @@ -140,7 +140,7 @@ def test_family_with_children(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"age": 35, "employment_income": 80000}, {"age": 33, "employment_income": 40000}, @@ -164,7 +164,7 @@ def test_multiple_uk_households(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [ # Person in household 0 { @@ -207,7 +207,7 @@ def test_multiple_us_households(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ # Person in household 0 { @@ -272,7 +272,7 @@ def test_invalid_model_name(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "people": [{"age": 30}], }, ) @@ -283,7 +283,7 @@ def test_missing_people(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 @@ -354,7 +354,7 @@ def test_us_reform_changes_household_net_income(self): baseline_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 40, "employment_income": 70000}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -371,7 +371,7 @@ def test_us_reform_changes_household_net_income(self): reform_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 40, "employment_income": 70000}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -453,7 +453,7 @@ def test_uk_reform_changes_household_net_income(self): baseline_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, }, @@ -468,7 +468,7 @@ def test_uk_reform_changes_household_net_income(self): reform_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, "policy_id": policy_id, diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index e2ac120..4168882 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -18,7 +18,7 @@ def test_single_adult_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, # No policy_id means baseline vs baseline @@ -39,7 +39,7 @@ def test_impact_response_structure(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 35, "employment_income": 50000}], "year": 2026, }, @@ -70,7 +70,7 @@ def test_single_adult_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 30, "employment_income": 60000}], "year": 2024, }, @@ -86,7 +86,7 @@ def test_family_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"age": 35, "employment_income": 80000}, {"age": 33, "employment_income": 40000}, @@ -110,7 +110,7 @@ def test_invalid_model_name(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "people": [{"age": 30}], }, ) @@ -121,7 +121,7 @@ def test_missing_people(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 diff --git a/tests/test_households.py b/tests/test_households.py index 4c60062..4c1c759 100644 --- a/tests/test_households.py +++ b/tests/test_households.py @@ -22,7 +22,7 @@ def test_create_us_household(client): assert "id" in data assert "created_at" in data assert "updated_at" in data - assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["country_id"] == "us" assert data["year"] == 2024 assert data["label"] == "US test household" @@ -44,7 +44,7 @@ def test_create_uk_household(client): response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) assert response.status_code == 201 data = response.json() - assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["country_id"] == "uk" assert data["benunit"] == {"is_married": False} assert data["household"] == {"region": "LONDON"} @@ -59,9 +59,9 @@ def test_create_household_minimal(client): assert data["benunit"] is None -def test_create_household_invalid_model_name(client): - """Reject invalid tax_benefit_model_name.""" - payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} +def test_create_household_invalid_country_id(client): + """Reject invalid country_id.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "country_id": "invalid"} response = client.post("/households", json=payload) assert response.status_code == 422 @@ -78,7 +78,7 @@ def test_get_household(client, session): assert response.status_code == 200 data = response.json() assert data["id"] == str(record.id) - assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["country_id"] == "us" def test_get_household_not_found(client): @@ -111,16 +111,14 @@ def test_list_households_with_data(client, session): assert len(data) == 2 -def test_list_households_filter_by_model_name(client, session): - """Filter households by tax_benefit_model_name.""" - create_household(session, tax_benefit_model_name="policyengine_us") - create_household(session, tax_benefit_model_name="policyengine_uk") - response = client.get( - "/households", params={"tax_benefit_model_name": "policyengine_uk"} - ) +def test_list_households_filter_by_country_id(client, session): + """Filter households by country_id.""" + create_household(session, country_id="us") + create_household(session, country_id="uk") + response = client.get("/households", params={"country_id": "uk"}) data = response.json() assert len(data) == 1 - assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + assert data[0]["country_id"] == "uk" def test_list_households_limit_and_offset(client, session): diff --git a/tests/test_models.py b/tests/test_models.py index 465856c..05fc6b4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -224,13 +224,13 @@ def test_variable_with_empty_adds(): def test_household_creation(): """Test household model creation.""" household = Household( - tax_benefit_model_name="policyengine_us", + country_id="us", year=2024, label="Test household", household_data={"people": [{"age": 30}], "household": {}}, ) assert household.household_data == {"people": [{"age": 30}], "household": {}} assert household.label == "Test household" - assert household.tax_benefit_model_name == "policyengine_us" + assert household.country_id == "us" assert household.year == 2024 assert household.id is not None diff --git a/tests/test_simulations_standalone.py b/tests/test_simulations_standalone.py index 5a18414..a448bac 100644 --- a/tests/test_simulations_standalone.py +++ b/tests/test_simulations_standalone.py @@ -139,7 +139,7 @@ def test_create_economy_simulation_with_region(client, session): create_region(session, model, dataset, code="us", label="United States") payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", } response = client.post("/simulations/economy", json=payload) @@ -158,7 +158,7 @@ def test_create_economy_simulation_with_dataset(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), } response = client.post("/simulations/economy", json=payload) @@ -187,7 +187,7 @@ def test_create_economy_simulation_with_region_filter(client, session): ) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "state/ca", } response = client.post("/simulations/economy", json=payload) @@ -204,7 +204,7 @@ def test_create_economy_simulation_invalid_region(client, session): model, version = create_us_model_and_version(session) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "nonexistent/region", } response = client.post("/simulations/economy", json=payload) @@ -217,7 +217,7 @@ def test_create_economy_simulation_no_region_or_dataset(client, session): """Creating without region or dataset_id returns 422 (Pydantic validation).""" model, version = create_us_model_and_version(session) - payload = {"tax_benefit_model_name": "policyengine_us"} + payload = {"country_id": "us"} response = client.post("/simulations/economy", json=payload) assert response.status_code == 422 @@ -229,7 +229,7 @@ def test_create_economy_simulation_policy_not_found(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), "policy_id": str(uuid4()), } @@ -245,7 +245,7 @@ def test_economy_simulation_deduplication(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), } response1 = client.post("/simulations/economy", json=payload) diff --git a/tests/test_variable_labels.py b/tests/test_variable_labels.py index 4ee7705..879da13 100644 --- a/tests/test_variable_labels.py +++ b/tests/test_variable_labels.py @@ -109,7 +109,7 @@ def test_search_matches_label( "/variables", params={ "search": "Employment", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -135,7 +135,7 @@ def test_search_label_case_insensitive( "/variables", params={ "search": "INCOME TAX", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -159,7 +159,7 @@ def test_search_partial_label_match( "/variables", params={ "search": "income", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -308,14 +308,14 @@ def test_search_by_label_isolated_by_country( "/variables", params={ "search": "tax", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) uk_response = client.get( "/variables", params={ "search": "tax", - "tax_benefit_model_name": "policyengine-uk", + "country_id": "uk", }, ) From 2fe870bc508c09fdfa916a773333071470aab7b6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 14 Mar 2026 01:28:43 +0100 Subject: [PATCH 09/10] fix: Guard dict lookups, prevent duplicate auto-triggers, and reject unexpected simulation states - Fix race condition: set report status to PENDING before triggering deferred computation, preventing duplicate Modal spawns on repeated GET polls - Guard MODEL_NAME_TO_COUNTRY and COUNTRY_MODEL_NAMES dict lookups with .get() and explicit HTTP errors instead of unhandled KeyError - Move _resolve_country_from_simulation to model_resolver.py, eliminating duplication across analysis.py and household_analysis.py - Reject non-PENDING simulations with ValueError instead of silently skipping, so reports are correctly marked FAILED - Add 18 tests covering model_resolver and simulation status handling Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 23 +-- .../api/household_analysis.py | 39 ++-- .../services/model_resolver.py | 34 +++- tests/test_model_resolver.py | 192 ++++++++++++++++++ tests/test_simulation_status_handling.py | 74 +++++++ 5 files changed, 324 insertions(+), 38 deletions(-) create mode 100644 tests/test_model_resolver.py create mode 100644 tests/test_simulation_status_handling.py diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index dc441d8..c83c950 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -63,6 +63,7 @@ ) from policyengine_api.services.database import get_session from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, resolve_country_model, resolve_model_name, ) @@ -1336,7 +1337,10 @@ def get_economic_impact_status( # Auto-start deferred reports on first access if report.status == ReportStatus.EXECUTION_DEFERRED: - country_id = _resolve_country_from_simulation(baseline_sim, session) + report.status = ReportStatus.PENDING + session.add(report) + session.commit() + country_id = resolve_country_from_simulation(baseline_sim, session) with logfire.span("auto_trigger_economy_comparison", job_id=str(report.id)): _trigger_economy_comparison(str(report.id), country_id, session) session.refresh(report) @@ -1347,17 +1351,6 @@ def get_economic_impact_status( return _build_response(report, baseline_sim, reform_sim, session, region) -def _resolve_country_from_simulation(sim: Simulation, session: Session) -> str: - """Derive country_id from a simulation's model version.""" - version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) - if not version: - raise HTTPException(status_code=500, detail="Model version not found") - model = session.get(TaxBenefitModel, version.model_id) - if not model: - raise HTTPException(status_code=500, detail="Tax-benefit model not found") - return MODEL_NAME_TO_COUNTRY[model.name] - - # --------------------------------------------------------------------------- # POST /analysis/economy-custom — run selected economy modules # --------------------------------------------------------------------------- @@ -1645,11 +1638,7 @@ def rerun_report( raise HTTPException(status_code=500, detail="Tax-benefit model not found") # Reverse-lookup: model.name is "policyengine-us" → country_id is "us" - from policyengine_api.config.constants import COUNTRY_MODEL_NAMES - - country_id = next( - (k for k, v in COUNTRY_MODEL_NAMES.items() if v == model.name), None - ) + country_id = MODEL_NAME_TO_COUNTRY.get(model.name) if not country_id: raise HTTPException(status_code=500, detail=f"Unknown model name: {model.name}") diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index fc77163..df98653 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -22,7 +22,6 @@ from pydantic import BaseModel, Field from sqlmodel import Session -from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY from policyengine_api.models import ( Household, Policy, @@ -31,11 +30,12 @@ Simulation, SimulationStatus, SimulationType, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session -from policyengine_api.services.model_resolver import resolve_country_model +from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, + resolve_country_model, +) from .analysis import ( PolicyIdInput, @@ -418,8 +418,19 @@ def _run_local_household_impact(report_id: str, session: Session) -> None: def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: """Run a single household simulation within an existing session.""" simulation = session.get(Simulation, simulation_id) - if not simulation or simulation.status != SimulationStatus.PENDING: - return + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + if simulation.status == SimulationStatus.COMPLETED: + return # Already done, skip safely + if simulation.status != SimulationStatus.PENDING: + logfire.warn( + "Simulation in unexpected status", + simulation_id=str(simulation_id), + status=simulation.status.value, + ) + raise ValueError( + f"Simulation {simulation_id} in unexpected status: {simulation.status.value}" + ) household = session.get(Household, simulation.household_id) if not household: @@ -751,7 +762,10 @@ def get_household_impact( # Auto-start deferred reports on first access if report.status == ReportStatus.EXECUTION_DEFERRED: - country_id = _resolve_country_from_simulation(baseline_sim, session) + report.status = ReportStatus.PENDING + session.add(report) + session.commit() + country_id = resolve_country_from_simulation(baseline_sim, session) with logfire.span("auto_trigger_household_impact", job_id=str(report.id)): _trigger_household_impact(str(report.id), country_id, session) session.refresh(report) @@ -759,17 +773,6 @@ def get_household_impact( return build_household_response(report, baseline_sim, reform_sim, session) -def _resolve_country_from_simulation(sim: Simulation, session: Session) -> str: - """Derive country_id from a simulation's model version.""" - version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) - if not version: - raise HTTPException(status_code=500, detail="Model version not found") - model = session.get(TaxBenefitModel, version.model_id) - if not model: - raise HTTPException(status_code=500, detail="Tax-benefit model not found") - return MODEL_NAME_TO_COUNTRY[model.name] - - # ============================================================================= # Simulation Creation Helpers # ============================================================================= diff --git a/src/policyengine_api/services/model_resolver.py b/src/policyengine_api/services/model_resolver.py index 4eed5a4..d4b97bc 100644 --- a/src/policyengine_api/services/model_resolver.py +++ b/src/policyengine_api/services/model_resolver.py @@ -5,7 +5,12 @@ from fastapi import HTTPException from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.config.constants import ( + COUNTRY_MODEL_NAMES, + MODEL_NAME_TO_COUNTRY, + CountryId, +) +from policyengine_api.models.simulation import Simulation from policyengine_api.models.tax_benefit_model import TaxBenefitModel from policyengine_api.models.tax_benefit_model_version import ( TaxBenefitModelVersion, @@ -14,7 +19,13 @@ def resolve_model_name(country_id: CountryId) -> str: """Resolve country_id → DB model name (with hyphens).""" - return COUNTRY_MODEL_NAMES[country_id] + model_name = COUNTRY_MODEL_NAMES.get(country_id) + if not model_name: + raise HTTPException( + status_code=400, + detail=f"Unsupported country_id: '{country_id}'. Supported: {list(COUNTRY_MODEL_NAMES.keys())}", + ) + return model_name def resolve_country_model( @@ -24,7 +35,7 @@ def resolve_country_model( Explicitly selects the most recent version by created_at DESC. """ - model_name = COUNTRY_MODEL_NAMES[country_id] + model_name = resolve_model_name(country_id) model = session.exec( select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) @@ -74,3 +85,20 @@ def resolve_version_id( return version.id return None + + +def resolve_country_from_simulation(sim: Simulation, session: Session) -> str: + """Derive country_id from a simulation's model version.""" + version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) + if not version: + raise HTTPException(status_code=500, detail="Model version not found") + model = session.get(TaxBenefitModel, version.model_id) + if not model: + raise HTTPException(status_code=500, detail="Tax-benefit model not found") + country_id = MODEL_NAME_TO_COUNTRY.get(model.name) + if not country_id: + raise HTTPException( + status_code=500, + detail=f"Unknown model name: '{model.name}'. Expected: {list(MODEL_NAME_TO_COUNTRY.keys())}", + ) + return country_id diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py new file mode 100644 index 0000000..96e2bab --- /dev/null +++ b/tests/test_model_resolver.py @@ -0,0 +1,192 @@ +"""Tests for model_resolver service and related fixes.""" + +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from policyengine_api.models import ( + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) +from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, + resolve_country_model, + resolve_model_name, + resolve_version_id, +) + + +# --------------------------------------------------------------------------- +# resolve_model_name +# --------------------------------------------------------------------------- + + +class TestResolveModelName: + def test_us_returns_policyengine_us(self): + assert resolve_model_name("us") == "policyengine-us" + + def test_uk_returns_policyengine_uk(self): + assert resolve_model_name("uk") == "policyengine-uk" + + def test_invalid_country_raises_400(self): + with pytest.raises(HTTPException) as exc_info: + resolve_model_name("fr") + assert exc_info.value.status_code == 400 + assert "Unsupported country_id" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# resolve_country_model +# --------------------------------------------------------------------------- + + +class TestResolveCountryModel: + def test_returns_model_and_latest_version(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="Old" + ) + session.add(v1) + session.commit() + + v2 = TaxBenefitModelVersion( + model_id=model.id, version="2.0", description="New" + ) + session.add(v2) + session.commit() + session.refresh(v2) + + result_model, result_version = resolve_country_model("us", session) + assert result_model.id == model.id + assert result_version.id == v2.id + + def test_missing_model_raises_404(self, session): + with pytest.raises(HTTPException) as exc_info: + resolve_country_model("us", session) + assert exc_info.value.status_code == 404 + assert "Model not found" in exc_info.value.detail + + def test_missing_version_raises_404(self, session): + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + + with pytest.raises(HTTPException) as exc_info: + resolve_country_model("uk", session) + assert exc_info.value.status_code == 404 + assert "No version found" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# resolve_version_id +# --------------------------------------------------------------------------- + + +class TestResolveVersionId: + def test_explicit_version_id_returned(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + result = resolve_version_id(None, version.id, session) + assert result == version.id + + def test_explicit_version_id_not_found_raises_404(self, session): + with pytest.raises(HTTPException) as exc_info: + resolve_version_id(None, uuid4(), session) + assert exc_info.value.status_code == 404 + + def test_country_id_returns_latest_version(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + result = resolve_version_id("us", None, session) + assert result == version.id + + def test_neither_returns_none(self, session): + assert resolve_version_id(None, None, session) is None + + +# --------------------------------------------------------------------------- +# resolve_country_from_simulation +# --------------------------------------------------------------------------- + + +class TestResolveCountryFromSimulation: + def _create_simulation(self, session, model_name="policyengine-us"): + model = TaxBenefitModel(name=model_name, description="Test") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + sim = Simulation( + tax_benefit_model_version_id=version.id, + status=SimulationStatus.PENDING, + simulation_type=SimulationType.HOUSEHOLD, + ) + session.add(sim) + session.commit() + session.refresh(sim) + return sim + + def test_us_simulation_returns_us(self, session): + sim = self._create_simulation(session, "policyengine-us") + assert resolve_country_from_simulation(sim, session) == "us" + + def test_uk_simulation_returns_uk(self, session): + sim = self._create_simulation(session, "policyengine-uk") + assert resolve_country_from_simulation(sim, session) == "uk" + + def test_unknown_model_name_raises_500(self, session): + sim = self._create_simulation(session, "policyengine-fr") + with pytest.raises(HTTPException) as exc_info: + resolve_country_from_simulation(sim, session) + assert exc_info.value.status_code == 500 + assert "Unknown model name" in exc_info.value.detail + + def test_missing_version_raises_500(self, session): + sim = Simulation( + tax_benefit_model_version_id=uuid4(), + status=SimulationStatus.PENDING, + simulation_type=SimulationType.HOUSEHOLD, + ) + session.add(sim) + session.commit() + session.refresh(sim) + + with pytest.raises(HTTPException) as exc_info: + resolve_country_from_simulation(sim, session) + assert exc_info.value.status_code == 500 + assert "Model version not found" in exc_info.value.detail diff --git a/tests/test_simulation_status_handling.py b/tests/test_simulation_status_handling.py new file mode 100644 index 0000000..9314de9 --- /dev/null +++ b/tests/test_simulation_status_handling.py @@ -0,0 +1,74 @@ +"""Tests for _run_simulation_in_session status handling (Critical #3 fix).""" + +from uuid import uuid4 + +import pytest + +from policyengine_api.api.household_analysis import _run_simulation_in_session +from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def _setup_household_simulation(session, status=SimulationStatus.PENDING): + """Create a household + simulation for testing.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + household = Household( + country_id="us", + household_data={"people": {"you": {"age": {"2024": 30}}}}, + year=2024, + ) + session.add(household) + session.commit() + session.refresh(household) + + sim = Simulation( + tax_benefit_model_version_id=version.id, + status=status, + simulation_type=SimulationType.HOUSEHOLD, + household_id=household.id, + ) + session.add(sim) + session.commit() + session.refresh(sim) + return sim + + +class TestRunSimulationInSession: + def test_missing_simulation_raises_valueerror(self, session): + with pytest.raises(ValueError, match="not found"): + _run_simulation_in_session(uuid4(), session) + + def test_completed_simulation_skips_silently(self, session): + sim = _setup_household_simulation(session, SimulationStatus.COMPLETED) + # Should not raise + _run_simulation_in_session(sim.id, session) + # Status should still be COMPLETED + session.refresh(sim) + assert sim.status == SimulationStatus.COMPLETED + + def test_running_simulation_raises_valueerror(self, session): + sim = _setup_household_simulation(session, SimulationStatus.RUNNING) + with pytest.raises(ValueError, match="unexpected status"): + _run_simulation_in_session(sim.id, session) + + def test_failed_simulation_raises_valueerror(self, session): + sim = _setup_household_simulation(session, SimulationStatus.FAILED) + with pytest.raises(ValueError, match="unexpected status"): + _run_simulation_in_session(sim.id, session) From 9fd3030f0f30f5cec6376a8dcb8913e5a628bb6c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 14 Mar 2026 02:04:09 +0100 Subject: [PATCH 10/10] chore: Lint and format test files Co-Authored-By: Claude Opus 4.6 --- tests/test_model_resolver.py | 9 ++------- tests/test_simulation_status_handling.py | 4 +--- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py index 96e2bab..1888bfc 100644 --- a/tests/test_model_resolver.py +++ b/tests/test_model_resolver.py @@ -19,7 +19,6 @@ resolve_version_id, ) - # --------------------------------------------------------------------------- # resolve_model_name # --------------------------------------------------------------------------- @@ -51,15 +50,11 @@ def test_returns_model_and_latest_version(self, session): session.commit() session.refresh(model) - v1 = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="Old" - ) + v1 = TaxBenefitModelVersion(model_id=model.id, version="1.0", description="Old") session.add(v1) session.commit() - v2 = TaxBenefitModelVersion( - model_id=model.id, version="2.0", description="New" - ) + v2 = TaxBenefitModelVersion(model_id=model.id, version="2.0", description="New") session.add(v2) session.commit() session.refresh(v2) diff --git a/tests/test_simulation_status_handling.py b/tests/test_simulation_status_handling.py index 9314de9..b85b21b 100644 --- a/tests/test_simulation_status_handling.py +++ b/tests/test_simulation_status_handling.py @@ -22,9 +22,7 @@ def _setup_household_simulation(session, status=SimulationStatus.PENDING): session.commit() session.refresh(model) - version = TaxBenefitModelVersion( - model_id=model.id, version="1.0", description="V1" - ) + version = TaxBenefitModelVersion(model_id=model.id, version="1.0", description="V1") session.add(version) session.commit() session.refresh(version)