diff --git a/CHANGELOG.md b/CHANGELOG.md index 0586f3e..5a97ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +0.4.0 (2026-03-13) + +# Added + +- Standardize all endpoints on `country_id` instead of `tax_benefit_model_name` (#109) +- Default metadata endpoints (variables, parameters, datasets) to latest model version with optional version pinning (#109) +- Dual policy IDs (`baseline_policy_id` / `reform_policy_id`) on reports and `EXECUTION_DEFERRED` report status (#109) +- Auto-start `EXECUTION_DEFERRED` reports on GET endpoint access (#109) +- Convert VARCHAR enum columns to native PostgreSQL enums with `values_callable` (#109) + + 0.3.1 (2026-03-11) # Fixed diff --git a/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py new file mode 100644 index 0000000..3d4f09e --- /dev/null +++ b/alembic/versions/20260309_62385cd8049d_rename_tax_benefit_model_name_to_.py @@ -0,0 +1,85 @@ +"""rename_tax_benefit_model_name_to_country_id + +Revision ID: 62385cd8049d +Revises: 886921687770 +Create Date: 2026-03-09 16:48:30.899791 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "62385cd8049d" +down_revision: Union[str, Sequence[str], None] = "886921687770" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema: rename tax_benefit_model_name → country_id with data migration.""" + # 1. Add country_id columns (nullable initially) + op.add_column( + "households", + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "household_jobs", + sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + + # 2. Populate country_id from tax_benefit_model_name + op.execute(""" + UPDATE households SET country_id = CASE + WHEN tax_benefit_model_name LIKE '%_us' OR tax_benefit_model_name LIKE '%-us' THEN 'us' + WHEN tax_benefit_model_name LIKE '%_uk' OR tax_benefit_model_name LIKE '%-uk' THEN 'uk' + ELSE 'us' + END + """) + op.execute(""" + UPDATE household_jobs SET country_id = CASE + WHEN tax_benefit_model_name LIKE '%_us' OR tax_benefit_model_name LIKE '%-us' THEN 'us' + WHEN tax_benefit_model_name LIKE '%_uk' OR tax_benefit_model_name LIKE '%-uk' THEN 'uk' + ELSE 'us' + END + """) + + # 3. Make country_id non-nullable + op.alter_column("households", "country_id", nullable=False) + op.alter_column("household_jobs", "country_id", nullable=False) + + # 4. Drop old columns + op.drop_column("households", "tax_benefit_model_name") + op.drop_column("household_jobs", "tax_benefit_model_name") + + +def downgrade() -> None: + """Downgrade schema: restore tax_benefit_model_name from country_id.""" + # 1. Re-add tax_benefit_model_name columns (nullable initially) + op.add_column( + "households", sa.Column("tax_benefit_model_name", sa.VARCHAR(), nullable=True) + ) + op.add_column( + "household_jobs", + sa.Column("tax_benefit_model_name", sa.VARCHAR(), nullable=True), + ) + + # 2. Populate from country_id + op.execute( + "UPDATE households SET tax_benefit_model_name = 'policyengine_' || country_id" + ) + op.execute( + "UPDATE household_jobs SET tax_benefit_model_name = 'policyengine_' || country_id" + ) + + # 3. Make non-nullable + op.alter_column("households", "tax_benefit_model_name", nullable=False) + op.alter_column("household_jobs", "tax_benefit_model_name", nullable=False) + + # 4. Drop country_id columns + op.drop_column("households", "country_id") + op.drop_column("household_jobs", "country_id") diff --git a/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py new file mode 100644 index 0000000..45c6b2f --- /dev/null +++ b/alembic/versions/20260310_f887cb5490bc_add_execution_deferred_to_reportstatus.py @@ -0,0 +1,31 @@ +"""add_execution_deferred_to_reportstatus + +Revision ID: f887cb5490bc +Revises: 62385cd8049d +Create Date: 2026-03-10 21:27:32.072364 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f887cb5490bc" +down_revision: Union[str, Sequence[str], None] = "62385cd8049d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add EXECUTION_DEFERRED value to the reportstatus enum.""" + op.execute("ALTER TYPE reportstatus ADD VALUE IF NOT EXISTS 'EXECUTION_DEFERRED'") + + +def downgrade() -> None: + """Downgrade: PostgreSQL does not support removing enum values. + + The 'EXECUTION_DEFERRED' value will remain in the enum type. + To fully remove it, drop and recreate the type (requires migrating data). + """ + pass diff --git a/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py new file mode 100644 index 0000000..6a00cbf --- /dev/null +++ b/alembic/versions/20260311_dac22a838dda_convert_varchar_enums_to_native_pg_enums.py @@ -0,0 +1,92 @@ +"""convert_varchar_enums_to_native_pg_enums + +Revision ID: dac22a838dda +Revises: f887cb5490bc +Create Date: 2026-03-11 01:37:08.928795 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dac22a838dda" +down_revision: Union[str, Sequence[str], None] = "f887cb5490bc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Convert VARCHAR enum columns to native PostgreSQL enum types. + + The enum types may already exist with UPPERCASE values (created by + SQLAlchemy's default create_all behavior). Since the columns are still + VARCHAR, the types are unused — drop and recreate with lowercase values + matching the data and the values_callable convention. + """ + # Drop any pre-existing enum types (unused — columns are still VARCHAR) + op.execute("DROP TYPE IF EXISTS regiontype CASCADE") + op.execute("DROP TYPE IF EXISTS reporttype CASCADE") + op.execute("DROP TYPE IF EXISTS deciletype CASCADE") + + # Create PG enum types with lowercase values + op.execute(""" + CREATE TYPE regiontype AS ENUM ( + 'national', 'country', 'state', 'congressional_district', + 'constituency', 'local_authority', 'city', 'place' + ) + """) + op.execute(""" + CREATE TYPE reporttype AS ENUM ( + 'economy_comparison', 'household_comparison', 'household_single' + ) + """) + op.execute("CREATE TYPE deciletype AS ENUM ('income', 'wealth')") + + # Alter columns from VARCHAR to enum. + # LOWER() handles any databases where values were previously uppercased. + op.execute(""" + ALTER TABLE regions + ALTER COLUMN region_type TYPE regiontype + USING LOWER(region_type)::regiontype + """) + op.execute(""" + ALTER TABLE reports + ALTER COLUMN report_type TYPE reporttype + USING LOWER(report_type)::reporttype + """) + # decile_type has a VARCHAR default that must be dropped before type change + op.execute("ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type DROP DEFAULT") + op.execute(""" + ALTER TABLE intra_decile_impacts + ALTER COLUMN decile_type TYPE deciletype + USING LOWER(decile_type)::deciletype + """) + op.execute( + "ALTER TABLE intra_decile_impacts ALTER COLUMN decile_type SET DEFAULT 'income'::deciletype" + ) + + +def downgrade() -> None: + """Revert native PG enum columns back to VARCHAR.""" + op.execute(""" + ALTER TABLE regions + ALTER COLUMN region_type TYPE VARCHAR + USING region_type::text + """) + op.execute(""" + ALTER TABLE reports + ALTER COLUMN report_type TYPE VARCHAR + USING report_type::text + """) + op.execute(""" + ALTER TABLE intra_decile_impacts + ALTER COLUMN decile_type TYPE VARCHAR + USING decile_type::text + """) + + # Drop the PG enum types + op.execute("DROP TYPE IF EXISTS regiontype") + op.execute("DROP TYPE IF EXISTS reporttype") + op.execute("DROP TYPE IF EXISTS deciletype") diff --git a/scripts/seed_regions.py b/scripts/seed_regions.py index 016b8dc..5ab3d37 100644 --- a/scripts/seed_regions.py +++ b/scripts/seed_regions.py @@ -33,6 +33,7 @@ RegionDatasetLink, TaxBenefitModel, ) +from policyengine_api.models.region import RegionType # noqa: E402 def _group_us_datasets( @@ -195,7 +196,7 @@ def seed_us_regions( db_region = Region( code=pe_region.code, label=pe_region.label, - region_type=pe_region.region_type, + region_type=RegionType(pe_region.region_type), requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, @@ -293,7 +294,7 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]: db_region = Region( code=pe_region.code, label=pe_region.label, - region_type=pe_region.region_type, + region_type=RegionType(pe_region.region_type), requires_filter=pe_region.requires_filter, filter_field=pe_region.filter_field, filter_value=pe_region.filter_value, diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index bcf1ab0..0aa578a 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -53,9 +53,9 @@ def configure_logfire(traceparent: str | None = None): ## CRITICAL: Always filter by country -When searching for parameters or datasets, ALWAYS include tax_benefit_model_name: -- "policyengine-uk" for UK questions -- "policyengine-us" for US questions +When searching for parameters or datasets, ALWAYS include country_id: +- "uk" for UK questions +- "us" for US questions Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results and waste turns finding the right ones. @@ -66,14 +66,14 @@ def configure_logfire(traceparent: str | None = None): - Poll GET /household/calculate/{job_id} until completed 2. **Parameter lookup**: - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk (ALWAYS include country filter) + - GET /parameters/?search=...&country_id=uk (ALWAYS include country filter) - GET /parameter-values/?parameter_id=...¤t=true for the current value 3. **Economic impact analysis** (budget impact, decile impacts): - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk to find parameter_id + - GET /parameters/?search=...&country_id=uk to find parameter_id - POST /policies/ to create reform with parameter_values - - GET /datasets/?tax_benefit_model_name=policyengine-uk to find dataset_id - - POST /analysis/economic-impact with tax_benefit_model_name, policy_id and dataset_id + - GET /datasets/?country_id=uk to find dataset_id + - POST /analysis/economic-impact with country_id, policy_id and dataset_id - GET /analysis/economic-impact/{report_id} for results (includes decile_impacts and program_statistics) ## Response formatting diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index d2e5864..c83c950 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -16,7 +16,7 @@ """ import math -from typing import Literal +from typing import Literal, Union from uuid import UUID, uuid5 import logfire @@ -30,6 +30,7 @@ get_modules_for_country, validate_modules, ) +from policyengine_api.config.constants import MODEL_NAME_TO_COUNTRY, CountryId from policyengine_api.models import ( BudgetSummary, BudgetSummaryRead, @@ -61,6 +62,26 @@ TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, + resolve_country_model, + resolve_model_name, +) + +# Type for API policy inputs: UUID, "current_law", or None (omitted). +PolicyIdInput = Union[UUID, Literal["current_law"], None] + + +def _resolve_policy_input(value: PolicyIdInput) -> UUID | None: + """Convert API policy sentinel to DB policy_id. + + - UUID → that policy ID + - "current_law" → None (current law has no policy record) + - None → None + """ + if value is None or value == "current_law": + return None + return value def get_traceparent() -> str | None: @@ -128,23 +149,17 @@ def list_analysis_options( class EconomicImpactRequest(BaseModel): """Request body for economic impact analysis. - Example with dataset_id: - { - "tax_benefit_model_name": "policyengine_uk", - "dataset_id": "uuid-from-datasets-endpoint", - "policy_id": "uuid-of-reform-policy" - } - Example with region: { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "state/ca", - "policy_id": "uuid-of-reform-policy" + "baseline_policy_id": "current_law", + "reform_policy_id": "uuid-of-reform-policy" } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) dataset_id: UUID | None = Field( default=None, @@ -154,9 +169,13 @@ class EconomicImpactRequest(BaseModel): default=None, description="Region code (e.g., 'state/ca', 'us'). Either dataset_id or region must be provided.", ) - policy_id: UUID | None = Field( - default=None, - description="Reform policy ID to compare against baseline (current law)", + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( + default="current_law", + description="Reform policy. UUID for a specific policy, or 'current_law' for current law.", ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" @@ -165,6 +184,10 @@ class EconomicImpactRequest(BaseModel): default=None, description="Year for the analysis (e.g., 2026). Selects the dataset for that year. Uses latest available if omitted.", ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) @model_validator(mode="after") def check_dataset_or_region(self) -> "EconomicImpactRequest": @@ -215,33 +238,6 @@ class EconomicImpactResponse(BaseModel): intra_wealth_decile: list[IntraDecileImpactRead] | None = None -def _get_model_version( - tax_benefit_model_name: str, session: Session -) -> TaxBenefitModelVersion: - """Get the latest tax benefit model version.""" - model_name = tax_benefit_model_name.replace("_", "-") - - model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) - ).first() - if not model: - raise HTTPException( - status_code=404, detail=f"Tax benefit model {model_name} not found" - ) - - version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - if not version: - raise HTTPException( - status_code=404, detail=f"No version found for model {model_name}" - ) - - return version - - def _get_deterministic_simulation_id( simulation_type: SimulationType, model_version_id: UUID, @@ -1114,13 +1110,14 @@ def build_dynamic(dynamic_id): def _trigger_economy_comparison( job_id: str, - tax_benefit_model_name: str, + country_id: str, session: Session | None = None, modules: list[str] | None = None, ) -> None: """Trigger economy comparison analysis (local or Modal). Args: + country_id: Country code ('us' or 'uk'). modules: Optional list of module names to run. If None, runs all. """ from policyengine_api.config import settings @@ -1129,7 +1126,7 @@ def _trigger_economy_comparison( if not settings.agent_use_modal and session is not None: # Run locally - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": _run_local_economy_comparison_uk(job_id, session, modules=modules) else: _run_local_economy_comparison_us(job_id, session, modules=modules) @@ -1137,7 +1134,7 @@ def _trigger_economy_comparison( # Use Modal (modules param passed for future selective computation) import modal - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": fn = modal.Function.from_name( "policyengine", "economy_comparison_uk", @@ -1184,7 +1181,7 @@ def _resolve_dataset_and_region( """ if request.region: # Look up region by code - model_name = request.tax_benefit_model_name.replace("_", "-") + model_name = resolve_model_name(request.country_id) region = session.exec( select(Region) .join(TaxBenefitModel) @@ -1195,7 +1192,7 @@ def _resolve_dataset_and_region( if not region: raise HTTPException( status_code=404, - detail=f"Region '{request.region}' not found for model {model_name}", + detail=f"Region '{request.region}' not found for country {request.country_id}", ) # Resolve dataset from join table, filtered by year if provided @@ -1262,13 +1259,17 @@ def economic_impact( ) # Get model version - model_version = _get_model_version(request.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(request.country_id, session) + + # Resolve policy sentinels to DB policy IDs + baseline_policy_db = _resolve_policy_input(request.baseline_policy_id) + reform_policy_db = _resolve_policy_input(request.reform_policy_id) # Get or create simulations using the resolved dataset baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=None, + policy_id=baseline_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1282,7 +1283,7 @@ def economic_impact( reform_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=request.policy_id, + policy_id=reform_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1294,20 +1295,23 @@ def economic_impact( ) # Get or create report - label = f"Economic impact: {request.tax_benefit_model_name}" - if request.policy_id: - label += f" (policy {request.policy_id})" + label = f"Economic impact: {request.country_id}" + if request.reform_policy_id and request.reform_policy_id != "current_law": + label += f" (policy {request.reform_policy_id})" report = _get_or_create_report( baseline_sim.id, reform_sim.id, label, "economy_comparison", session ) - # Trigger computation if report is pending - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison( - str(report.id), request.tax_benefit_model_name, session - ) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison(str(report.id), request.country_id, session) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() return _build_response(report, baseline_sim, reform_sim, session, region) @@ -1331,6 +1335,16 @@ def get_economic_impact_status( if not baseline_sim or not reform_sim: raise HTTPException(status_code=500, detail="Simulation data missing") + # Auto-start deferred reports on first access + if report.status == ReportStatus.EXECUTION_DEFERRED: + report.status = ReportStatus.PENDING + session.add(report) + session.commit() + country_id = resolve_country_from_simulation(baseline_sim, session) + with logfire.span("auto_trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison(str(report.id), country_id, session) + session.refresh(report) + region = ( session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None ) @@ -1341,17 +1355,12 @@ def get_economic_impact_status( # POST /analysis/economy-custom — run selected economy modules # --------------------------------------------------------------------------- -_MODEL_TO_COUNTRY = { - "policyengine_uk": "uk", - "policyengine_us": "us", -} - class EconomyCustomRequest(BaseModel): """Request body for custom economy analysis with selected modules.""" - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) dataset_id: UUID | None = Field( default=None, @@ -1361,9 +1370,13 @@ class EconomyCustomRequest(BaseModel): default=None, description="Region code (e.g., 'state/ca', 'us').", ) - policy_id: UUID | None = Field( - default=None, - description="Reform policy ID to compare against baseline (current law)", + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( + default="current_law", + description="Reform policy. UUID for a specific policy, or 'current_law' for current law.", ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID" @@ -1375,6 +1388,10 @@ class EconomyCustomRequest(BaseModel): modules: list[str] = Field( description="List of module names to compute (see GET /analysis/options)" ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) @model_validator(mode="after") def check_dataset_or_region(self) -> "EconomyCustomRequest": @@ -1430,21 +1447,21 @@ def economy_custom( See GET /analysis/options for available module names. """ - country = _MODEL_TO_COUNTRY[request.tax_benefit_model_name] - try: - validate_modules(request.modules, country) + validate_modules(request.modules, request.country_id) except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) # Reuse the same request model for dataset/region resolution impact_request = EconomicImpactRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, dataset_id=request.dataset_id, region=request.region, - policy_id=request.policy_id, + baseline_policy_id=request.baseline_policy_id, + reform_policy_id=request.reform_policy_id, dynamic_id=request.dynamic_id, year=request.year, + run=request.run, ) dataset, region_obj = _resolve_dataset_and_region(impact_request, session) @@ -1461,12 +1478,16 @@ def economy_custom( else None ) - model_version = _get_model_version(request.tax_benefit_model_name, session) + # Resolve policy sentinels to DB policy IDs + baseline_policy_db = _resolve_policy_input(request.baseline_policy_id) + reform_policy_db = _resolve_policy_input(request.reform_policy_id) + + _model, model_version = resolve_country_model(request.country_id, session) baseline_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=None, + policy_id=baseline_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1480,7 +1501,7 @@ def economy_custom( reform_sim = _get_or_create_simulation( simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, - policy_id=request.policy_id, + policy_id=reform_policy_db, dynamic_id=request.dynamic_id, session=session, dataset_id=dataset.id, @@ -1491,22 +1512,28 @@ def economy_custom( year=dataset.year, ) - label = f"Custom analysis: {request.tax_benefit_model_name}" - if request.policy_id: - label += f" (policy {request.policy_id})" + label = f"Custom analysis: {request.country_id}" + if request.reform_policy_id and request.reform_policy_id != "current_law": + label += f" (policy {request.reform_policy_id})" report = _get_or_create_report( baseline_sim.id, reform_sim.id, label, "economy_comparison", session ) - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison( - str(report.id), - request.tax_benefit_model_name, - session, - modules=request.modules, - ) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_economy_comparison", job_id=str(report.id)): + _trigger_economy_comparison( + str(report.id), + request.country_id, + session, + modules=request.modules, + ) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() full_response = _build_response( report, baseline_sim, reform_sim, session, region_obj @@ -1599,7 +1626,7 @@ def rerun_report( if not baseline_sim: raise HTTPException(status_code=400, detail="Report has no baseline simulation") - # 3. Derive tax_benefit_model_name from simulation → model version → model + # 3. Derive country_id from simulation → model version → model model_version = session.get( TaxBenefitModelVersion, baseline_sim.tax_benefit_model_version_id ) @@ -1610,7 +1637,10 @@ def rerun_report( if not model: raise HTTPException(status_code=500, detail="Tax-benefit model not found") - tax_benefit_model_name = model.name.replace("-", "_") + # Reverse-lookup: model.name is "policyengine-us" → country_id is "us" + country_id = MODEL_NAME_TO_COUNTRY.get(model.name) + if not country_id: + raise HTTPException(status_code=500, detail=f"Unknown model name: {model.name}") # 4. Delete all result records for this report result_tables = [ @@ -1648,10 +1678,10 @@ def rerun_report( if is_economy: with logfire.span("rerun_economy_comparison", job_id=str(report.id)): - _trigger_economy_comparison(str(report.id), tax_benefit_model_name, session) + _trigger_economy_comparison(str(report.id), country_id, session) elif is_household: with logfire.span("rerun_household_impact", job_id=str(report.id)): - _trigger_household_impact(str(report.id), tax_benefit_model_name, session) + _trigger_household_impact(str(report.id), country_id, session) else: raise HTTPException( status_code=400, diff --git a/src/policyengine_api/api/datasets.py b/src/policyengine_api/api/datasets.py index 82540b7..e01dc29 100644 --- a/src/policyengine_api/api/datasets.py +++ b/src/policyengine_api/api/datasets.py @@ -11,15 +11,17 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import Dataset, DatasetRead, TaxBenefitModel from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/datasets", tags=["datasets"]) @router.get("/", response_model=List[DatasetRead]) def list_datasets( - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, session: Session = Depends(get_session), ): """List available datasets. @@ -28,16 +30,13 @@ def list_datasets( Each dataset represents population microdata for a specific country and year. Args: - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK datasets. - Use "policyengine-us" for US datasets. + country_id: Filter by country ("us" or "uk"). """ query = select(Dataset) - if tax_benefit_model_name: - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name - ) + if country_id: + model_name = resolve_model_name(country_id) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) datasets = session.exec(query).all() return datasets diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 5fda4b4..3abdc10 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -5,7 +5,7 @@ """ import math -from typing import Any, Literal +from typing import Any from uuid import UUID import logfire @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field from sqlmodel import Session +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Dynamic, HouseholdJob, @@ -61,7 +62,7 @@ class HouseholdCalculateRequest(BaseModel): Example US request (single household, simple): { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"employment_income": 70000, "age": 40}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -70,7 +71,7 @@ class HouseholdCalculateRequest(BaseModel): Example US request (multiple households): { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"person_id": 0, "person_household_id": 0, "person_tax_unit_id": 0, "age": 40, "employment_income": 70000}, {"person_id": 1, "person_household_id": 1, "person_tax_unit_id": 1, "age": 30, "employment_income": 50000} @@ -88,15 +89,15 @@ class HouseholdCalculateRequest(BaseModel): Example UK request: { - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"employment_income": 50000, "age": 30}], "household": [{}], "year": 2026 } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) people: list[dict[str, Any]] = Field( description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." @@ -173,7 +174,7 @@ class HouseholdImpactRequest(BaseModel): Example: { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"employment_income": 70000, "age": 40}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -182,8 +183,8 @@ class HouseholdImpactRequest(BaseModel): } """ - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) people: list[dict[str, Any]] = Field( description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." @@ -726,7 +727,7 @@ def _trigger_modal_household( if not settings.agent_use_modal and session is not None: # Run locally - if request.tax_benefit_model_name == "policyengine_uk": + if request.country_id == "uk": _run_local_household_uk( job_id=job_id, people=request.people, @@ -755,7 +756,7 @@ def _trigger_modal_household( traceparent = get_traceparent() - if request.tax_benefit_model_name == "policyengine_uk": + if request.country_id == "uk": fn = modal.Function.from_name( "policyengine", "simulate_household_uk", @@ -864,7 +865,7 @@ def calculate_household( """ with logfire.span( "create_household_job", - model=request.tax_benefit_model_name, + model=request.country_id, num_people=len(request.people), year=request.year, has_policy=request.policy_id is not None, @@ -876,7 +877,7 @@ def calculate_household( # Create job record job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -995,7 +996,7 @@ def calculate_household_impact_comparison( """ with logfire.span( "create_household_impact_job", - model=request.tax_benefit_model_name, + model=request.country_id, num_people=len(request.people), year=request.year, has_policy=request.policy_id is not None, @@ -1006,7 +1007,7 @@ def calculate_household_impact_comparison( # Create baseline job (no policy) baseline_job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -1026,7 +1027,7 @@ def calculate_household_impact_comparison( # Create reform job (with policy) reform_job = HouseholdJob( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, request_data={ "people": request.people, "benunit": request.benunit, @@ -1055,7 +1056,7 @@ def calculate_household_impact_comparison( # Trigger Modal functions for both baseline_request = HouseholdCalculateRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, people=request.people, benunit=request.benunit, marital_unit=request.marital_unit, @@ -1068,7 +1069,7 @@ def calculate_household_impact_comparison( dynamic_id=request.dynamic_id, ) reform_request = HouseholdCalculateRequest( - tax_benefit_model_name=request.tax_benefit_model_name, + country_id=request.country_id, people=request.people, benunit=request.benunit, marital_unit=request.marital_unit, diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 9c22234..df98653 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -32,11 +32,16 @@ SimulationType, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, + resolve_country_model, +) from .analysis import ( - _get_model_version, + PolicyIdInput, _get_or_create_report, _get_or_create_simulation, + _resolve_policy_input, ) @@ -81,9 +86,9 @@ class CountryConfig: ) -def get_country_config(tax_benefit_model_name: str) -> CountryConfig: - """Get country configuration from model name.""" - if tax_benefit_model_name == "policyengine_uk": +def get_country_config(country_id: str) -> CountryConfig: + """Get country configuration from country_id.""" + if country_id == "uk": return UK_CONFIG return US_CONFIG @@ -136,9 +141,9 @@ def calculate_us_household( ) -def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator: +def get_calculator(country_id: str) -> HouseholdCalculator: """Get the appropriate calculator for a country.""" - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": return calculate_uk_household return calculate_us_household @@ -331,7 +336,7 @@ def run_household_simulation(simulation_id: UUID, session: Session) -> None: mark_simulation_running(simulation, session) try: - calculator = get_calculator(household.tax_benefit_model_name) + calculator = get_calculator(household.country_id) result = calculator(household.household_data, household.year, policy_data) mark_simulation_completed(simulation, result, session) except Exception as e: @@ -413,8 +418,19 @@ def _run_local_household_impact(report_id: str, session: Session) -> None: def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: """Run a single household simulation within an existing session.""" simulation = session.get(Simulation, simulation_id) - if not simulation or simulation.status != SimulationStatus.PENDING: - return + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + if simulation.status == SimulationStatus.COMPLETED: + return # Already done, skip safely + if simulation.status != SimulationStatus.PENDING: + logfire.warn( + "Simulation in unexpected status", + simulation_id=str(simulation_id), + status=simulation.status.value, + ) + raise ValueError( + f"Simulation {simulation_id} in unexpected status: {simulation.status.value}" + ) household = session.get(Household, simulation.household_id) if not household: @@ -428,7 +444,7 @@ def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: session.commit() try: - calculator = get_calculator(household.tax_benefit_model_name) + calculator = get_calculator(household.country_id) result = calculator(household.household_data, household.year, policy_data) simulation.household_result = result @@ -446,9 +462,13 @@ def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: def _trigger_household_impact( - report_id: str, tax_benefit_model_name: str, session: Session | None = None + report_id: str, country_id: str, session: Session | None = None ) -> None: - """Trigger household impact calculation (local or Modal based on settings).""" + """Trigger household impact calculation (local or Modal based on settings). + + Args: + country_id: Country code ('us' or 'uk'). + """ from policyengine_api.config import settings traceparent = get_traceparent() @@ -460,7 +480,7 @@ def _trigger_household_impact( # Use Modal import modal - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": fn = modal.Function.from_name( "policyengine", "household_impact_uk", @@ -508,14 +528,25 @@ class HouseholdImpactRequest(BaseModel): """Request for household impact analysis.""" household_id: UUID = Field(description="ID of the household to analyze") - policy_id: UUID | None = Field( + baseline_policy_id: PolicyIdInput = Field( + default="current_law", + description="Baseline policy. UUID for a specific policy, or 'current_law' for current law.", + ) + reform_policy_id: PolicyIdInput = Field( default=None, - description="Reform policy ID. If None, runs single calculation under current law.", + description=( + "Reform policy. UUID for a specific policy, 'current_law' for current law, " + "or omit entirely for a single calculation (no comparison)." + ), ) dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response specification ID", ) + run: bool = Field( + default=True, + description="If false, create report and simulations with EXECUTION_DEFERRED status without triggering computation.", + ) class HouseholdSimulationInfo(BaseModel): @@ -603,7 +634,7 @@ def _compute_impact_if_comparison( if not household: return None - config = get_country_config(household.tax_benefit_model_name) + config = get_country_config(household.country_id) return compute_household_impact(baseline_result, reform_result, config) @@ -648,39 +679,59 @@ def household_impact( ) -> HouseholdImpactResponse: """Run household impact analysis. - If policy_id is None: single run under current law. - If policy_id is set: comparison (baseline vs reform). + If reform_policy_id is omitted (None): single run under baseline policy. + If reform_policy_id is set (UUID or "current_law"): comparison (baseline vs reform). This is an async operation. The endpoint returns immediately with a report_id and status="pending". Poll GET /analysis/household-impact/{report_id} until status="completed" to get results. """ household = validate_household_exists(request.household_id, session) - validate_policy_exists(request.policy_id, session) - model_version = _get_model_version(household.tax_benefit_model_name, session) + # Resolve policy sentinels to DB policy IDs + baseline_db_id = _resolve_policy_input(request.baseline_policy_id) + reform_db_id = _resolve_policy_input(request.reform_policy_id) + has_reform = request.reform_policy_id is not None + + # Validate policies exist in DB + validate_policy_exists(baseline_db_id, session) + validate_policy_exists(reform_db_id, session) + + _model, model_version = resolve_country_model(household.country_id, session) baseline_sim = _create_baseline_simulation( - household, model_version.id, request.dynamic_id, session + household, + model_version.id, + request.dynamic_id, + session, + policy_id=baseline_db_id, ) - reform_sim = _create_reform_simulation( - household, model_version.id, request.policy_id, request.dynamic_id, session + reform_sim = ( + _create_reform_simulation( + household, model_version.id, reform_db_id, request.dynamic_id, session + ) + if has_reform + else None ) - report_type = "household_comparison" if request.policy_id else "household_single" + report_type = "household_comparison" if has_reform else "household_single" report = _get_or_create_report( baseline_sim_id=baseline_sim.id, reform_sim_id=reform_sim.id if reform_sim else None, - label=f"Household impact: {household.tax_benefit_model_name}", + label=f"Household impact: {household.country_id}", report_type=report_type, session=session, ) - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_household_impact", job_id=str(report.id)): - _trigger_household_impact( - str(report.id), household.tax_benefit_model_name, session - ) + # Trigger computation or defer + if report.status in (ReportStatus.PENDING, ReportStatus.EXECUTION_DEFERRED): + if request.run: + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact(str(report.id), household.country_id, session) + elif report.status == ReportStatus.PENDING: + report.status = ReportStatus.EXECUTION_DEFERRED + session.add(report) + session.commit() return build_household_response(report, baseline_sim, reform_sim, session) @@ -709,6 +760,16 @@ def get_household_impact( if report.reform_simulation_id: reform_sim = session.get(Simulation, report.reform_simulation_id) + # Auto-start deferred reports on first access + if report.status == ReportStatus.EXECUTION_DEFERRED: + report.status = ReportStatus.PENDING + session.add(report) + session.commit() + country_id = resolve_country_from_simulation(baseline_sim, session) + with logfire.span("auto_trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact(str(report.id), country_id, session) + session.refresh(report) + return build_household_response(report, baseline_sim, reform_sim, session) @@ -722,12 +783,13 @@ def _create_baseline_simulation( model_version_id: UUID, dynamic_id: UUID | None, session: Session, + policy_id: UUID | None = None, ) -> Simulation: - """Create baseline simulation (current law, no policy).""" + """Create baseline simulation.""" return _get_or_create_simulation( simulation_type=SimulationType.HOUSEHOLD, model_version_id=model_version_id, - policy_id=None, + policy_id=policy_id, dynamic_id=dynamic_id, session=session, household_id=household.id, diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py index fdee1f7..ea64f9d 100644 --- a/src/policyengine_api/api/households.py +++ b/src/policyengine_api/api/households.py @@ -45,7 +45,7 @@ def _to_read(record: Household) -> HouseholdRead: data = record.household_data return HouseholdRead( id=record.id, - tax_benefit_model_name=record.tax_benefit_model_name, + country_id=record.country_id, year=record.year, label=record.label, people=data["people"], @@ -69,7 +69,7 @@ def create_household(body: HouseholdCreate, session: Session = Depends(get_sessi or /household/impact to run simulations. """ record = Household( - tax_benefit_model_name=body.tax_benefit_model_name, + country_id=body.country_id, year=body.year, label=body.label, household_data=_pack_household_data(body), @@ -82,15 +82,15 @@ def create_household(body: HouseholdCreate, session: Session = Depends(get_sessi @router.get("/", response_model=list[HouseholdRead]) def list_households( - tax_benefit_model_name: str | None = None, + country_id: str | None = None, limit: int = Query(default=50, le=200), offset: int = Query(default=0, ge=0), session: Session = Depends(get_session), ): """List stored households with optional filtering.""" query = select(Household) - if tax_benefit_model_name is not None: - query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + if country_id is not None: + query = query.where(Household.country_id == country_id) query = query.offset(offset).limit(limit) records = session.exec(query).all() return [_to_read(r) for r in records] diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 4668ab8..3311cbf 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -12,8 +12,10 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, or_, select -from policyengine_api.models import ParameterValue, ParameterValueRead +from policyengine_api.config.constants import CountryId +from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/parameter-values", tags=["parameter-values"]) @@ -23,6 +25,8 @@ def list_parameter_values( parameter_id: UUID | None = None, policy_id: UUID | None = None, current: bool = False, + country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, skip: int = 0, limit: int = 100, session: Session = Depends(get_session), @@ -37,6 +41,10 @@ def list_parameter_values( policy_id: Filter by a specific policy reform. current: If true, only return values that are currently in effect (start_date <= now and (end_date is null or end_date > now)). + country_id: Filter to values belonging to parameters from this country. + Defaults to the latest model version. + tax_benefit_model_version_id: Filter to values from a specific model + version. Takes precedence over country_id. """ query = select(ParameterValue) @@ -46,6 +54,12 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) + if version_id: + query = query.join(Parameter).where( + Parameter.tax_benefit_model_version_id == version_id + ) + if current: now = datetime.now(timezone.utc) query = query.where( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index 6a807e7..93d052e 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -14,15 +14,14 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Parameter, ParameterNode, ParameterRead, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -32,7 +31,8 @@ def list_parameters( skip: int = 0, limit: int = 100, search: str | None = None, - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -42,22 +42,18 @@ def list_parameters( Args: search: Filter by parameter name, label, or description. - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK parameters. - Use "policyengine-us" for US parameters. + country_id: Filter by country ("us" or "uk"). + Defaults to the latest model version. + tax_benefit_model_version_id: Pin to a specific model version. + Takes precedence over country_id. """ query = select(Parameter) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Parameter.name.ilike(search_pattern) @@ -77,6 +73,7 @@ class ParameterByNameRequest(BaseModel): names: list[str] country_id: CountryId + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[ParameterRead]) @@ -96,18 +93,15 @@ def get_parameters_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - - query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.in_(request.names)) - .order_by(Parameter.name) + version_id = resolve_version_id( + request.country_id, request.tax_benefit_model_version_id, session ) - return session.exec(query).all() + query = select(Parameter).where(Parameter.name.in_(request.names)) + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Parameter.name)).all() class ParameterChild(BaseModel): @@ -133,6 +127,9 @@ def get_parameter_children( parent_path: str = Query( default="", description="Parent parameter path (e.g. 'gov' or 'gov.hmrc')" ), + tax_benefit_model_version_id: UUID | None = Query( + default=None, description="Optional specific model version ID" + ), session: Session = Depends(get_session), ) -> ParameterChildrenResponse: """Get direct children of a parameter path for tree navigation. @@ -141,27 +138,24 @@ def get_parameter_children( parameters (with full metadata). Use this to lazily load the parameter tree one level at a time. """ - model_name = COUNTRY_MODEL_NAMES[country_id] + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) + prefix = f"{parent_path}." if parent_path else "" # Fetch all parameters under this path - param_query = ( - select(Parameter) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Parameter.name.startswith(prefix)) - ) + param_query = select(Parameter).where(Parameter.name.startswith(prefix)) + if version_id: + param_query = param_query.where( + Parameter.tax_benefit_model_version_id == version_id + ) descendants = session.exec(param_query).all() # Fetch all parameter nodes under this path for labels - node_query = ( - select(ParameterNode) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(ParameterNode.name.startswith(prefix)) - ) + node_query = select(ParameterNode).where(ParameterNode.name.startswith(prefix)) + if version_id: + node_query = node_query.where( + ParameterNode.tax_benefit_model_version_id == version_id + ) nodes = session.exec(node_query).all() # Build a map of node path -> label for quick lookup diff --git a/src/policyengine_api/api/regions.py b/src/policyengine_api/api/regions.py index 1d0a34e..b5a5c6e 100644 --- a/src/policyengine_api/api/regions.py +++ b/src/policyengine_api/api/regions.py @@ -11,20 +11,22 @@ from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import Region, RegionRead, TaxBenefitModel from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_model_name router = APIRouter(prefix="/regions", tags=["regions"]) @router.get("/", response_model=List[RegionRead]) def list_regions( + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), tax_benefit_model_id: UUID | None = Query( None, description="Filter by tax-benefit model ID" ), - tax_benefit_model_name: str | None = Query( - None, description="Filter by tax-benefit model name (e.g., 'policyengine-us')" - ), region_type: str | None = Query( None, description="Filter by region type (e.g., 'state', 'congressional_district')", @@ -37,18 +39,17 @@ def list_regions( Each region represents a geographic area with an associated dataset. Args: - tax_benefit_model_id: Filter by tax-benefit model UUID. - tax_benefit_model_name: Filter by model name (e.g., "policyengine-us"). + country_id: Filter by country ("us" or "uk"). + tax_benefit_model_id: Filter by tax-benefit model UUID (alternative to country_id). region_type: Filter by region type (e.g., "state", "congressional_district"). """ query = select(Region) - if tax_benefit_model_id: + if country_id: + model_name = resolve_model_name(country_id) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + elif tax_benefit_model_id: query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) - elif tax_benefit_model_name: - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name - ) if region_type: query = query.where(Region.region_type == region_type) @@ -69,12 +70,12 @@ def get_region(region_id: UUID, session: Session = Depends(get_session)): @router.get("/by-code/{region_code:path}", response_model=RegionRead) def get_region_by_code( region_code: str, + country_id: CountryId | None = Query( + None, description="Filter by country ('us' or 'uk')" + ), tax_benefit_model_id: UUID | None = Query( None, - description="Tax-benefit model ID (required if multiple models have same region code)", - ), - tax_benefit_model_name: str | None = Query( - None, description="Tax-benefit model name (e.g., 'policyengine-us')" + description="Tax-benefit model ID (alternative to country_id)", ), session: Session = Depends(get_session), ): @@ -84,17 +85,16 @@ def get_region_by_code( Args: region_code: The region code (e.g., "state/ca", "us"). - tax_benefit_model_id: Filter by tax-benefit model UUID. - tax_benefit_model_name: Filter by model name. + country_id: Filter by country ("us" or "uk"). + tax_benefit_model_id: Filter by tax-benefit model UUID (alternative to country_id). """ query = select(Region).where(Region.code == region_code) - if tax_benefit_model_id: + if country_id: + model_name = resolve_model_name(country_id) + query = query.join(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + elif tax_benefit_model_id: query = query.where(Region.tax_benefit_model_id == tax_benefit_model_id) - elif tax_benefit_model_name: - query = query.join(TaxBenefitModel).where( - TaxBenefitModel.name == tax_benefit_model_name - ) region = session.exec(query).first() if not region: diff --git a/src/policyengine_api/api/simulations.py b/src/policyengine_api/api/simulations.py index 8477a68..ba0196d 100644 --- a/src/policyengine_api/api/simulations.py +++ b/src/policyengine_api/api/simulations.py @@ -8,13 +8,14 @@ For baseline-vs-reform comparisons, use the /analysis/ endpoints instead. """ -from typing import Any, List, Literal +from typing import Any, List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field, model_validator from sqlmodel import Session, select +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( Dataset, Household, @@ -28,10 +29,13 @@ TaxBenefitModel, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import ( + resolve_country_model, + resolve_model_name, +) from .analysis import ( RegionInfo, - _get_model_version, _get_or_create_simulation, ) @@ -71,8 +75,8 @@ class HouseholdSimulationResponse(BaseModel): class EconomySimulationRequest(BaseModel): """Request body for creating an economy simulation.""" - tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( - description="Which country model to use" + country_id: CountryId = Field( + description="Which country model to use ('us' or 'uk')" ) region: str | None = Field( default=None, @@ -122,7 +126,7 @@ class EconomySimulationResponse(BaseModel): def _resolve_economy_dataset( - tax_benefit_model_name: str, + country_id: str, region_code: str | None, dataset_id: UUID | None, session: Session, @@ -135,7 +139,7 @@ def _resolve_economy_dataset( otherwise the latest available year is used. """ if region_code: - model_name = tax_benefit_model_name.replace("_", "-") + model_name = resolve_model_name(country_id) region = session.exec( select(Region) .join(TaxBenefitModel) @@ -145,7 +149,7 @@ def _resolve_economy_dataset( if not region: raise HTTPException( status_code=404, - detail=f"Region '{region_code}' not found for model {model_name}", + detail=f"Region '{region_code}' not found for country {country_id}", ) # Resolve dataset from join table @@ -274,7 +278,7 @@ def create_household_simulation( ) # Get model version - model_version = _get_model_version(household.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(household.country_id, session) # Get or create simulation (deterministic UUID) simulation = _get_or_create_simulation( @@ -328,7 +332,7 @@ def create_economy_simulation( """ # Resolve dataset and region dataset, region = _resolve_economy_dataset( - request.tax_benefit_model_name, + request.country_id, request.region, request.dataset_id, session, @@ -349,7 +353,7 @@ def create_economy_simulation( filter_value = region.filter_value if region and region.requires_filter else None # Get model version - model_version = _get_model_version(request.tax_benefit_model_name, session) + _model, model_version = resolve_country_model(request.country_id, session) # Get or create simulation (deterministic UUID) simulation = _get_or_create_simulation( diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index e592820..162e179 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -12,14 +12,13 @@ from pydantic import BaseModel from sqlmodel import Session, select -from policyengine_api.config.constants import COUNTRY_MODEL_NAMES, CountryId +from policyengine_api.config.constants import CountryId from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, Variable, VariableRead, ) from policyengine_api.services.database import get_session +from policyengine_api.services.model_resolver import resolve_version_id router = APIRouter(prefix="/variables", tags=["variables"]) @@ -29,7 +28,8 @@ def list_variables( skip: int = 0, limit: int = 100, search: str | None = None, - tax_benefit_model_name: str | None = None, + country_id: CountryId | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -40,22 +40,18 @@ def list_variables( Args: search: Filter by variable name, label, or description. - tax_benefit_model_name: Filter by country model. - Use "policyengine-uk" for UK variables. - Use "policyengine-us" for US variables. + country_id: Filter by country ("us" or "uk"). + Defaults to the latest model version. + tax_benefit_model_version_id: Pin to a specific model version. + Takes precedence over country_id. """ query = select(Variable) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + version_id = resolve_version_id(country_id, tax_benefit_model_version_id, session) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) if search: - # Case-insensitive search using ILIKE search_pattern = f"%{search}%" search_filter = ( Variable.name.ilike(search_pattern) @@ -75,6 +71,7 @@ class VariableByNameRequest(BaseModel): names: list[str] country_id: CountryId + tax_benefit_model_version_id: UUID | None = None @router.post("/by-name", response_model=List[VariableRead]) @@ -95,17 +92,15 @@ def get_variables_by_name( if not request.names: return [] - model_name = COUNTRY_MODEL_NAMES[request.country_id] - query = ( - select(Variable) - .join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == model_name) - .where(Variable.name.in_(request.names)) - .order_by(Variable.name) + version_id = resolve_version_id( + request.country_id, request.tax_benefit_model_version_id, session ) - return session.exec(query).all() + query = select(Variable).where(Variable.name.in_(request.names)) + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) + + return session.exec(query.order_by(Variable.name)).all() @router.get("/{variable_id}", response_model=VariableRead) diff --git a/src/policyengine_api/config/constants.py b/src/policyengine_api/config/constants.py index a39d827..61673e1 100644 --- a/src/policyengine_api/config/constants.py +++ b/src/policyengine_api/config/constants.py @@ -10,3 +10,6 @@ "uk": "policyengine-uk", "us": "policyengine-us", } + +# Reverse mapping: model name → country ID +MODEL_NAME_TO_COUNTRY: dict[str, str] = {v: k for k, v in COUNTRY_MODEL_NAMES.items()} diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py index 8a96850..194b574 100644 --- a/src/policyengine_api/models/household.py +++ b/src/policyengine_api/models/household.py @@ -1,17 +1,19 @@ """Stored household definition model.""" from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any from uuid import UUID, uuid4 from sqlalchemy import JSON from sqlmodel import Column, Field, SQLModel +from policyengine_api.config.constants import CountryId + class HouseholdBase(SQLModel): """Base household fields.""" - tax_benefit_model_name: str + country_id: str year: int label: str | None = None household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) @@ -34,7 +36,7 @@ class HouseholdCreate(SQLModel): people as an array, entity groups as optional dicts. """ - tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + country_id: CountryId year: int label: str | None = None people: list[dict[str, Any]] diff --git a/src/policyengine_api/models/household_job.py b/src/policyengine_api/models/household_job.py index ac853a9..8294314 100644 --- a/src/policyengine_api/models/household_job.py +++ b/src/policyengine_api/models/household_job.py @@ -21,7 +21,7 @@ class HouseholdJobStatus(str, Enum): class HouseholdJobBase(SQLModel): """Base household job fields.""" - tax_benefit_model_name: str + country_id: str request_data: dict[str, Any] = Field(sa_column=Column(JSON)) policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") diff --git a/src/policyengine_api/models/intra_decile_impact.py b/src/policyengine_api/models/intra_decile_impact.py index 8771d55..a8c958a 100644 --- a/src/policyengine_api/models/intra_decile_impact.py +++ b/src/policyengine_api/models/intra_decile_impact.py @@ -19,6 +19,7 @@ from enum import Enum from uuid import UUID, uuid4 +import sqlalchemy as sa from sqlmodel import Field, SQLModel @@ -35,7 +36,10 @@ class IntraDecileImpactBase(SQLModel): baseline_simulation_id: UUID = Field(foreign_key="simulations.id") reform_simulation_id: UUID = Field(foreign_key="simulations.id") report_id: UUID | None = Field(default=None, foreign_key="reports.id") - decile_type: DecileType = Field(default=DecileType.INCOME) + decile_type: DecileType = Field( + default=DecileType.INCOME, + sa_type=sa.Enum(DecileType, values_callable=lambda x: [e.value for e in x]), + ) decile: int = Field(ge=0, le=10) lose_more_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) lose_less_than_5pct: float | None = Field(default=None, ge=0.0, le=1.0) diff --git a/src/policyengine_api/models/region.py b/src/policyengine_api/models/region.py index 0458284..a3daffc 100644 --- a/src/policyengine_api/models/region.py +++ b/src/policyengine_api/models/region.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from uuid import UUID, uuid4 +import sqlalchemy as sa from pydantic import model_validator from sqlmodel import Field, Relationship, SQLModel @@ -33,7 +34,9 @@ class RegionBase(SQLModel): code: str # e.g., "state/ca", "constituency/Sheffield Central" label: str # e.g., "California", "Sheffield Central" - region_type: RegionType # e.g., RegionType.STATE, RegionType.CONSTITUENCY + region_type: RegionType = Field( + sa_type=sa.Enum(RegionType, values_callable=lambda x: [e.value for e in x]), + ) requires_filter: bool = False filter_field: str | None = None # e.g., "state_code", "place_fips" filter_value: str | None = None # e.g., "CA", "44000" diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index b034dcb..0ec572b 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -2,6 +2,7 @@ from enum import Enum from uuid import UUID, uuid4 +import sqlalchemy as sa from sqlmodel import Column, Field, SQLModel, Text @@ -9,6 +10,7 @@ class ReportStatus(str, Enum): """Report processing status.""" PENDING = "pending" + EXECUTION_DEFERRED = "execution_deferred" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" @@ -27,7 +29,12 @@ class ReportBase(SQLModel): label: str description: str | None = None - report_type: ReportType | None = None + report_type: ReportType | None = Field( + default=None, + sa_type=sa.Enum( + ReportType, values_callable=lambda x: [e.value for e in x], nullable=True + ), + ) user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) status: ReportStatus = ReportStatus.PENDING diff --git a/src/policyengine_api/services/model_resolver.py b/src/policyengine_api/services/model_resolver.py new file mode 100644 index 0000000..d4b97bc --- /dev/null +++ b/src/policyengine_api/services/model_resolver.py @@ -0,0 +1,104 @@ +"""Shared resolver for country_id → tax-benefit model + latest version.""" + +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, select + +from policyengine_api.config.constants import ( + COUNTRY_MODEL_NAMES, + MODEL_NAME_TO_COUNTRY, + CountryId, +) +from policyengine_api.models.simulation import Simulation +from policyengine_api.models.tax_benefit_model import TaxBenefitModel +from policyengine_api.models.tax_benefit_model_version import ( + TaxBenefitModelVersion, +) + + +def resolve_model_name(country_id: CountryId) -> str: + """Resolve country_id → DB model name (with hyphens).""" + model_name = COUNTRY_MODEL_NAMES.get(country_id) + if not model_name: + raise HTTPException( + status_code=400, + detail=f"Unsupported country_id: '{country_id}'. Supported: {list(COUNTRY_MODEL_NAMES.keys())}", + ) + return model_name + + +def resolve_country_model( + country_id: CountryId, session: Session +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Resolve country_id → (model, latest_version). + + Explicitly selects the most recent version by created_at DESC. + """ + model_name = resolve_model_name(country_id) + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, detail=f"Model not found for country: {country_id}" + ) + + version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not version: + raise HTTPException( + status_code=404, + detail=f"No version found for model: {model_name}", + ) + + return model, version + + +def resolve_version_id( + country_id: CountryId | None, + tax_benefit_model_version_id: UUID | None, + session: Session, +) -> UUID | None: + """Resolve the model version ID from country_id or explicit version. + + Priority: + 1. If tax_benefit_model_version_id provided, validate and return it. + 2. If country_id provided, return the latest version's ID. + 3. If neither provided, return None (no filtering). + """ + if tax_benefit_model_version_id: + version = session.get(TaxBenefitModelVersion, tax_benefit_model_version_id) + if not version: + raise HTTPException( + status_code=404, + detail=f"Model version '{tax_benefit_model_version_id}' not found", + ) + return version.id + + if country_id: + _, version = resolve_country_model(country_id, session) + return version.id + + return None + + +def resolve_country_from_simulation(sim: Simulation, session: Session) -> str: + """Derive country_id from a simulation's model version.""" + version = session.get(TaxBenefitModelVersion, sim.tax_benefit_model_version_id) + if not version: + raise HTTPException(status_code=500, detail="Model version not found") + model = session.get(TaxBenefitModel, version.model_id) + if not model: + raise HTTPException(status_code=500, detail="Tax-benefit model not found") + country_id = MODEL_NAME_TO_COUNTRY.get(model.name) + if not country_id: + raise HTTPException( + status_code=500, + detail=f"Unknown model name: '{model.name}'. Expected: {list(MODEL_NAME_TO_COUNTRY.keys())}", + ) + return country_id diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py index 85401af..903f0c2 100644 --- a/test_fixtures/fixtures_household_analysis.py +++ b/test_fixtures/fixtures_household_analysis.py @@ -310,12 +310,12 @@ def create_policy_with_parameter_value( def create_household_for_analysis( session: Session, - tax_benefit_model_name: str = "policyengine_uk", + country_id: str = "uk", year: int = 2024, label: str = "Test household for analysis", ) -> Household: """Create a household suitable for analysis testing.""" - if tax_benefit_model_name == "policyengine_uk": + if country_id == "uk": household_data = { "people": [{"age": 30, "employment_income": 35000}], "benunit": {}, @@ -332,7 +332,7 @@ def create_household_for_analysis( } record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data=household_data, diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py index 4e676f4..c5d4512 100644 --- a/test_fixtures/fixtures_households.py +++ b/test_fixtures/fixtures_households.py @@ -7,7 +7,7 @@ # ----------------------------------------------------------------------------- MOCK_US_HOUSEHOLD_CREATE = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "year": 2024, "label": "US test household", "people": [ @@ -20,7 +20,7 @@ } MOCK_UK_HOUSEHOLD_CREATE = { - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "year": 2024, "label": "UK test household", "people": [ @@ -31,7 +31,7 @@ } MOCK_HOUSEHOLD_MINIMAL = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "year": 2024, "people": [{"age": 25}], } @@ -44,7 +44,7 @@ def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str | None = "Test household", people: list | None = None, @@ -55,7 +55,7 @@ def create_household( household_data.update(entity_groups) record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data=household_data, diff --git a/test_fixtures/fixtures_simulations_standalone.py b/test_fixtures/fixtures_simulations_standalone.py index 2b7ce39..5eb85cd 100644 --- a/test_fixtures/fixtures_simulations_standalone.py +++ b/test_fixtures/fixtures_simulations_standalone.py @@ -54,13 +54,13 @@ def create_uk_model_and_version( def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str = "Test household", ) -> Household: """Create and persist a Household record.""" household = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data={ diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py index 66b0835..a0dc9ee 100644 --- a/test_fixtures/fixtures_user_household_associations.py +++ b/test_fixtures/fixtures_user_household_associations.py @@ -25,13 +25,13 @@ def create_user( def create_household( session, - tax_benefit_model_name: str = "policyengine_us", + country_id: str = "us", year: int = 2024, label: str | None = "Test household", ) -> Household: """Create and persist a Household record.""" record = Household( - tax_benefit_model_name=tax_benefit_model_name, + country_id=country_id, year=year, label=label, household_data={"people": [{"age": 30}]}, diff --git a/tests/conftest.py b/tests/conftest.py index 1c52936..e2c9c36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,7 +73,7 @@ def uk_tax_benefit_model_fixture(session: Session): def simulation_fixture(session: Session): """Create a test simulation with required dependencies.""" # Create model - model = TaxBenefitModel(name="policyengine_uk", description="UK model") + model = TaxBenefitModel(name="policyengine-uk", description="UK model") session.add(model) session.commit() session.refresh(model) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index abd7489..5fa6674 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -53,7 +53,7 @@ def test__given_dataset_id__then_region_is_none(self, session: Session): model = create_tax_benefit_model(session, name="policyengine-uk") dataset = create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset.id, ) @@ -67,7 +67,7 @@ def test__given_dataset_id__then_dataset_is_returned(self, session: Session): model = create_tax_benefit_model(session, name="policyengine-uk") dataset = create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset.id, ) @@ -94,7 +94,7 @@ def test__given_dataset_id_and_region__then_region_takes_precedence( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=dataset1.id, region="uk", ) @@ -126,7 +126,7 @@ def test__given_region_requires_filter__then_returns_filter_fields( filter_value="ENGLAND", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -154,7 +154,7 @@ def test__given_us_state_region__then_returns_state_filter(self, session: Sessio filter_value="CA", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="state/ca", ) @@ -183,7 +183,7 @@ def test__given_region_with_filter__then_dataset_is_resolved( filter_value="ENGLAND", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -209,7 +209,7 @@ def test__given_national_uk_region__then_filter_params_none(self, session: Sessi requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -235,7 +235,7 @@ def test__given_national_us_region__then_filter_params_none(self, session: Sessi requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="us", ) @@ -263,7 +263,7 @@ def test__given_national_region__then_dataset_still_resolved( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -280,7 +280,7 @@ def test__given_nonexistent_region_code__then_raises_404(self, session: Session) model = create_tax_benefit_model(session, name="policyengine-uk") create_dataset(session, model, name="uk_enhanced_frs") request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="nonexistent/region", ) @@ -302,7 +302,7 @@ def test__given_region_for_wrong_model__then_raises_404(self, session: Session): region_type="national", ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_us", + country_id="us", region="uk", ) @@ -318,13 +318,13 @@ def test__given_neither_dataset_nor_region__then_raises_validation_error( with pytest.raises(ValidationError, match="dataset_id or region"): EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", ) def test__given_nonexistent_dataset_id__then_raises_404(self, session: Session): nonexistent_id = uuid4() request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", dataset_id=nonexistent_id, ) @@ -943,7 +943,7 @@ def test__given_region_with_row_filter_strategy__then_region_has_filter_strategy filter_strategy=FILTER_STRATEGIES["ROW_FILTER"], ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="country/england", ) @@ -972,7 +972,7 @@ def test__given_constituency_region__then_region_has_weight_replacement_strategy filter_strategy=FILTER_STRATEGIES["WEIGHT_REPLACEMENT"], ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="constituency/sheffield-central", ) @@ -1001,7 +1001,7 @@ def test__given_national_region__then_filter_strategy_is_none( requires_filter=False, ) request = EconomicImpactRequest( - tax_benefit_model_name="policyengine_uk", + country_id="uk", region="uk", ) @@ -1021,11 +1021,11 @@ def test__given_national_region__then_filter_strategy_is_none( class TestEconomicImpactValidation: """Tests for request validation (no database required).""" - def test_invalid_model_name(self): + def test_invalid_country_id(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "dataset_id": "00000000-0000-0000-0000-000000000000", }, ) @@ -1035,7 +1035,7 @@ def test_missing_dataset_id(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 @@ -1044,7 +1044,7 @@ def test_invalid_uuid(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": "not-a-uuid", }, ) @@ -1059,7 +1059,7 @@ def test_dataset_not_found(self): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": "00000000-0000-0000-0000-000000000000", }, ) @@ -1102,7 +1102,7 @@ def test_uk_economic_impact_baseline_only(self, uk_dataset_id): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": str(uk_dataset_id), }, ) @@ -1126,7 +1126,7 @@ def test_simulations_created(self, uk_dataset_id, session: Session): response = client.post( "/analysis/economic-impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "dataset_id": str(uk_dataset_id), }, ) diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py index 802562d..0e6f485 100644 --- a/tests/test_analysis_household_impact.py +++ b/tests/test_analysis_household_impact.py @@ -201,13 +201,13 @@ class TestGetCountryConfig: """Tests for get_country_config helper.""" def test_uk_model_returns_uk_config(self): - config = get_country_config("policyengine_uk") + config = get_country_config("uk") assert config == UK_CONFIG assert config.name == "uk" assert "benunit" in config.entity_types def test_us_model_returns_us_config(self): - config = get_country_config("policyengine_us") + config = get_country_config("us") assert config == US_CONFIG assert config.name == "us" assert "tax_unit" in config.entity_types @@ -223,13 +223,13 @@ class TestGetCalculator: def test_uk_model_returns_uk_calculator(self): from policyengine_api.api.household_analysis import calculate_uk_household - calc = get_calculator("policyengine_uk") + calc = get_calculator("uk") assert calc == calculate_uk_household def test_us_model_returns_us_calculator(self): from policyengine_api.api.household_analysis import calculate_us_household - calc = get_calculator("policyengine_us") + calc = get_calculator("us") assert calc == calculate_us_household def test_unknown_model_defaults_to_us(self): @@ -297,7 +297,7 @@ def test_policy_not_found(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(uuid4()), + "reform_policy_id": str(uuid4()), }, ) assert response.status_code == 404 @@ -346,7 +346,7 @@ def test_comparison_creates_two_simulations(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy.id), + "reform_policy_id": str(policy.id), }, ) data = response.json() @@ -386,7 +386,7 @@ def test_report_links_simulations(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy.id), + "reform_policy_id": str(policy.id), }, ) data = response.json() @@ -442,7 +442,7 @@ def test_different_policy_creates_different_simulation(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy1.id), + "reform_policy_id": str(policy1.id), }, ) data1 = response1.json() @@ -452,7 +452,7 @@ def test_different_policy_creates_different_simulation(self, client, session): "/analysis/household-impact", json={ "household_id": str(household.id), - "policy_id": str(policy2.id), + "reform_policy_id": str(policy2.id), }, ) data2 = response2.json() @@ -506,9 +506,7 @@ class TestUSHouseholdImpact: def test_us_household_creates_simulation(self, client, session): """US household creates simulation with correct model.""" _, version = setup_us_model_and_version(session) - household = create_household_for_analysis( - session, tax_benefit_model_name="policyengine_us" - ) + household = create_household_for_analysis(session, country_id="us") response = client.post( "/analysis/household-impact", diff --git a/tests/test_economy_custom.py b/tests/test_economy_custom.py index fda65fa..4be2968 100644 --- a/tests/test_economy_custom.py +++ b/tests/test_economy_custom.py @@ -218,7 +218,7 @@ def test_unknown_module_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["nonexistent_module"], }, @@ -230,7 +230,7 @@ def test_wrong_country_module_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["constituency"], }, @@ -242,7 +242,7 @@ def test_multiple_errors_in_module_validation(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["nonexistent", "constituency"], }, @@ -256,7 +256,7 @@ def test_empty_modules_list_passes_validation(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": [], }, @@ -272,7 +272,7 @@ def test_valid_modules_but_missing_region_returns_404(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", "modules": ["decile", "poverty"], }, @@ -284,7 +284,7 @@ def test_missing_modules_field_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", }, ) @@ -294,7 +294,7 @@ def test_invalid_model_name_returns_422(self, client): response = client.post( "/analysis/economy-custom", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "region": "us", "modules": ["decile"], }, diff --git a/tests/test_household.py b/tests/test_household.py index 0cf0288..ddefc15 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -34,7 +34,7 @@ def test_single_adult(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, }, @@ -55,7 +55,7 @@ def test_couple_with_children(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [ {"age": 35, "employment_income": 50000}, {"age": 33, "employment_income": 25000}, @@ -75,7 +75,7 @@ def test_with_household_data(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 40, "employment_income": 45000}], "household": [ { @@ -96,7 +96,7 @@ def test_output_contains_tax_variables(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 50000}], "year": 2026, }, @@ -118,7 +118,7 @@ def test_single_adult(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 30, "employment_income": 60000}], "year": 2024, }, @@ -140,7 +140,7 @@ def test_family_with_children(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"age": 35, "employment_income": 80000}, {"age": 33, "employment_income": 40000}, @@ -164,7 +164,7 @@ def test_multiple_uk_households(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [ # Person in household 0 { @@ -207,7 +207,7 @@ def test_multiple_us_households(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ # Person in household 0 { @@ -272,7 +272,7 @@ def test_invalid_model_name(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "people": [{"age": 30}], }, ) @@ -283,7 +283,7 @@ def test_missing_people(self): response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 @@ -354,7 +354,7 @@ def test_us_reform_changes_household_net_income(self): baseline_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 40, "employment_income": 70000}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -371,7 +371,7 @@ def test_us_reform_changes_household_net_income(self): reform_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 40, "employment_income": 70000}], "tax_unit": [{"state_code": "CA"}], "household": [{"state_fips": 6}], @@ -453,7 +453,7 @@ def test_uk_reform_changes_household_net_income(self): baseline_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, }, @@ -468,7 +468,7 @@ def test_uk_reform_changes_household_net_income(self): reform_response = client.post( "/household/calculate", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, "policy_id": policy_id, diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index e2ac120..4168882 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -18,7 +18,7 @@ def test_single_adult_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 30, "employment_income": 30000}], "year": 2026, # No policy_id means baseline vs baseline @@ -39,7 +39,7 @@ def test_impact_response_structure(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", "people": [{"age": 35, "employment_income": 50000}], "year": 2026, }, @@ -70,7 +70,7 @@ def test_single_adult_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [{"age": 30, "employment_income": 60000}], "year": 2024, }, @@ -86,7 +86,7 @@ def test_family_impact(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "people": [ {"age": 35, "employment_income": 80000}, {"age": 33, "employment_income": 40000}, @@ -110,7 +110,7 @@ def test_invalid_model_name(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "invalid_model", + "country_id": "invalid_model", "people": [{"age": 30}], }, ) @@ -121,7 +121,7 @@ def test_missing_people(self): response = client.post( "/household/impact", json={ - "tax_benefit_model_name": "policyengine_uk", + "country_id": "uk", }, ) assert response.status_code == 422 diff --git a/tests/test_households.py b/tests/test_households.py index 4c60062..4c1c759 100644 --- a/tests/test_households.py +++ b/tests/test_households.py @@ -22,7 +22,7 @@ def test_create_us_household(client): assert "id" in data assert "created_at" in data assert "updated_at" in data - assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["country_id"] == "us" assert data["year"] == 2024 assert data["label"] == "US test household" @@ -44,7 +44,7 @@ def test_create_uk_household(client): response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) assert response.status_code == 201 data = response.json() - assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["country_id"] == "uk" assert data["benunit"] == {"is_married": False} assert data["household"] == {"region": "LONDON"} @@ -59,9 +59,9 @@ def test_create_household_minimal(client): assert data["benunit"] is None -def test_create_household_invalid_model_name(client): - """Reject invalid tax_benefit_model_name.""" - payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} +def test_create_household_invalid_country_id(client): + """Reject invalid country_id.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "country_id": "invalid"} response = client.post("/households", json=payload) assert response.status_code == 422 @@ -78,7 +78,7 @@ def test_get_household(client, session): assert response.status_code == 200 data = response.json() assert data["id"] == str(record.id) - assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["country_id"] == "us" def test_get_household_not_found(client): @@ -111,16 +111,14 @@ def test_list_households_with_data(client, session): assert len(data) == 2 -def test_list_households_filter_by_model_name(client, session): - """Filter households by tax_benefit_model_name.""" - create_household(session, tax_benefit_model_name="policyengine_us") - create_household(session, tax_benefit_model_name="policyengine_uk") - response = client.get( - "/households", params={"tax_benefit_model_name": "policyengine_uk"} - ) +def test_list_households_filter_by_country_id(client, session): + """Filter households by country_id.""" + create_household(session, country_id="us") + create_household(session, country_id="uk") + response = client.get("/households", params={"country_id": "uk"}) data = response.json() assert len(data) == 1 - assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + assert data[0]["country_id"] == "uk" def test_list_households_limit_and_offset(client, session): diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py new file mode 100644 index 0000000..1888bfc --- /dev/null +++ b/tests/test_model_resolver.py @@ -0,0 +1,187 @@ +"""Tests for model_resolver service and related fixes.""" + +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from policyengine_api.models import ( + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) +from policyengine_api.services.model_resolver import ( + resolve_country_from_simulation, + resolve_country_model, + resolve_model_name, + resolve_version_id, +) + +# --------------------------------------------------------------------------- +# resolve_model_name +# --------------------------------------------------------------------------- + + +class TestResolveModelName: + def test_us_returns_policyengine_us(self): + assert resolve_model_name("us") == "policyengine-us" + + def test_uk_returns_policyengine_uk(self): + assert resolve_model_name("uk") == "policyengine-uk" + + def test_invalid_country_raises_400(self): + with pytest.raises(HTTPException) as exc_info: + resolve_model_name("fr") + assert exc_info.value.status_code == 400 + assert "Unsupported country_id" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# resolve_country_model +# --------------------------------------------------------------------------- + + +class TestResolveCountryModel: + def test_returns_model_and_latest_version(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + v1 = TaxBenefitModelVersion(model_id=model.id, version="1.0", description="Old") + session.add(v1) + session.commit() + + v2 = TaxBenefitModelVersion(model_id=model.id, version="2.0", description="New") + session.add(v2) + session.commit() + session.refresh(v2) + + result_model, result_version = resolve_country_model("us", session) + assert result_model.id == model.id + assert result_version.id == v2.id + + def test_missing_model_raises_404(self, session): + with pytest.raises(HTTPException) as exc_info: + resolve_country_model("us", session) + assert exc_info.value.status_code == 404 + assert "Model not found" in exc_info.value.detail + + def test_missing_version_raises_404(self, session): + model = TaxBenefitModel(name="policyengine-uk", description="UK") + session.add(model) + session.commit() + + with pytest.raises(HTTPException) as exc_info: + resolve_country_model("uk", session) + assert exc_info.value.status_code == 404 + assert "No version found" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# resolve_version_id +# --------------------------------------------------------------------------- + + +class TestResolveVersionId: + def test_explicit_version_id_returned(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + result = resolve_version_id(None, version.id, session) + assert result == version.id + + def test_explicit_version_id_not_found_raises_404(self, session): + with pytest.raises(HTTPException) as exc_info: + resolve_version_id(None, uuid4(), session) + assert exc_info.value.status_code == 404 + + def test_country_id_returns_latest_version(self, session): + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + result = resolve_version_id("us", None, session) + assert result == version.id + + def test_neither_returns_none(self, session): + assert resolve_version_id(None, None, session) is None + + +# --------------------------------------------------------------------------- +# resolve_country_from_simulation +# --------------------------------------------------------------------------- + + +class TestResolveCountryFromSimulation: + def _create_simulation(self, session, model_name="policyengine-us"): + model = TaxBenefitModel(name=model_name, description="Test") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion( + model_id=model.id, version="1.0", description="V1" + ) + session.add(version) + session.commit() + session.refresh(version) + + sim = Simulation( + tax_benefit_model_version_id=version.id, + status=SimulationStatus.PENDING, + simulation_type=SimulationType.HOUSEHOLD, + ) + session.add(sim) + session.commit() + session.refresh(sim) + return sim + + def test_us_simulation_returns_us(self, session): + sim = self._create_simulation(session, "policyengine-us") + assert resolve_country_from_simulation(sim, session) == "us" + + def test_uk_simulation_returns_uk(self, session): + sim = self._create_simulation(session, "policyengine-uk") + assert resolve_country_from_simulation(sim, session) == "uk" + + def test_unknown_model_name_raises_500(self, session): + sim = self._create_simulation(session, "policyengine-fr") + with pytest.raises(HTTPException) as exc_info: + resolve_country_from_simulation(sim, session) + assert exc_info.value.status_code == 500 + assert "Unknown model name" in exc_info.value.detail + + def test_missing_version_raises_500(self, session): + sim = Simulation( + tax_benefit_model_version_id=uuid4(), + status=SimulationStatus.PENDING, + simulation_type=SimulationType.HOUSEHOLD, + ) + session.add(sim) + session.commit() + session.refresh(sim) + + with pytest.raises(HTTPException) as exc_info: + resolve_country_from_simulation(sim, session) + assert exc_info.value.status_code == 500 + assert "Model version not found" in exc_info.value.detail diff --git a/tests/test_models.py b/tests/test_models.py index 465856c..05fc6b4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -224,13 +224,13 @@ def test_variable_with_empty_adds(): def test_household_creation(): """Test household model creation.""" household = Household( - tax_benefit_model_name="policyengine_us", + country_id="us", year=2024, label="Test household", household_data={"people": [{"age": 30}], "household": {}}, ) assert household.household_data == {"people": [{"age": 30}], "household": {}} assert household.label == "Test household" - assert household.tax_benefit_model_name == "policyengine_us" + assert household.country_id == "us" assert household.year == 2024 assert household.id is not None diff --git a/tests/test_simulation_status_handling.py b/tests/test_simulation_status_handling.py new file mode 100644 index 0000000..b85b21b --- /dev/null +++ b/tests/test_simulation_status_handling.py @@ -0,0 +1,72 @@ +"""Tests for _run_simulation_in_session status handling (Critical #3 fix).""" + +from uuid import uuid4 + +import pytest + +from policyengine_api.api.household_analysis import _run_simulation_in_session +from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + SimulationType, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def _setup_household_simulation(session, status=SimulationStatus.PENDING): + """Create a household + simulation for testing.""" + model = TaxBenefitModel(name="policyengine-us", description="US") + session.add(model) + session.commit() + session.refresh(model) + + version = TaxBenefitModelVersion(model_id=model.id, version="1.0", description="V1") + session.add(version) + session.commit() + session.refresh(version) + + household = Household( + country_id="us", + household_data={"people": {"you": {"age": {"2024": 30}}}}, + year=2024, + ) + session.add(household) + session.commit() + session.refresh(household) + + sim = Simulation( + tax_benefit_model_version_id=version.id, + status=status, + simulation_type=SimulationType.HOUSEHOLD, + household_id=household.id, + ) + session.add(sim) + session.commit() + session.refresh(sim) + return sim + + +class TestRunSimulationInSession: + def test_missing_simulation_raises_valueerror(self, session): + with pytest.raises(ValueError, match="not found"): + _run_simulation_in_session(uuid4(), session) + + def test_completed_simulation_skips_silently(self, session): + sim = _setup_household_simulation(session, SimulationStatus.COMPLETED) + # Should not raise + _run_simulation_in_session(sim.id, session) + # Status should still be COMPLETED + session.refresh(sim) + assert sim.status == SimulationStatus.COMPLETED + + def test_running_simulation_raises_valueerror(self, session): + sim = _setup_household_simulation(session, SimulationStatus.RUNNING) + with pytest.raises(ValueError, match="unexpected status"): + _run_simulation_in_session(sim.id, session) + + def test_failed_simulation_raises_valueerror(self, session): + sim = _setup_household_simulation(session, SimulationStatus.FAILED) + with pytest.raises(ValueError, match="unexpected status"): + _run_simulation_in_session(sim.id, session) diff --git a/tests/test_simulations_standalone.py b/tests/test_simulations_standalone.py index 5a18414..a448bac 100644 --- a/tests/test_simulations_standalone.py +++ b/tests/test_simulations_standalone.py @@ -139,7 +139,7 @@ def test_create_economy_simulation_with_region(client, session): create_region(session, model, dataset, code="us", label="United States") payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "us", } response = client.post("/simulations/economy", json=payload) @@ -158,7 +158,7 @@ def test_create_economy_simulation_with_dataset(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), } response = client.post("/simulations/economy", json=payload) @@ -187,7 +187,7 @@ def test_create_economy_simulation_with_region_filter(client, session): ) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "state/ca", } response = client.post("/simulations/economy", json=payload) @@ -204,7 +204,7 @@ def test_create_economy_simulation_invalid_region(client, session): model, version = create_us_model_and_version(session) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "region": "nonexistent/region", } response = client.post("/simulations/economy", json=payload) @@ -217,7 +217,7 @@ def test_create_economy_simulation_no_region_or_dataset(client, session): """Creating without region or dataset_id returns 422 (Pydantic validation).""" model, version = create_us_model_and_version(session) - payload = {"tax_benefit_model_name": "policyengine_us"} + payload = {"country_id": "us"} response = client.post("/simulations/economy", json=payload) assert response.status_code == 422 @@ -229,7 +229,7 @@ def test_create_economy_simulation_policy_not_found(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), "policy_id": str(uuid4()), } @@ -245,7 +245,7 @@ def test_economy_simulation_deduplication(client, session): dataset = create_dataset(session, model) payload = { - "tax_benefit_model_name": "policyengine_us", + "country_id": "us", "dataset_id": str(dataset.id), } response1 = client.post("/simulations/economy", json=payload) diff --git a/tests/test_variable_labels.py b/tests/test_variable_labels.py index 4ee7705..879da13 100644 --- a/tests/test_variable_labels.py +++ b/tests/test_variable_labels.py @@ -109,7 +109,7 @@ def test_search_matches_label( "/variables", params={ "search": "Employment", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -135,7 +135,7 @@ def test_search_label_case_insensitive( "/variables", params={ "search": "INCOME TAX", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -159,7 +159,7 @@ def test_search_partial_label_match( "/variables", params={ "search": "income", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) assert response.status_code == 200 @@ -308,14 +308,14 @@ def test_search_by_label_isolated_by_country( "/variables", params={ "search": "tax", - "tax_benefit_model_name": "policyengine-us", + "country_id": "us", }, ) uk_response = client.get( "/variables", params={ "search": "tax", - "tax_benefit_model_name": "policyengine-uk", + "country_id": "uk", }, )