Skip to content

Commit 5b2cc84

Browse files
authored
Merge pull request #105 from PolicyEngine/feat/enable-uk-districts
feat: Add filter_strategy column and strategy reconstruction
2 parents 6fc2fb6 + 3c3fe4f commit 5b2cc84

11 files changed

Lines changed: 1093 additions & 2 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""add filter_strategy to regions and simulations
2+
3+
Revision ID: add_filter_strategy
4+
Revises: 886921687770
5+
Create Date: 2026-03-08
6+
7+
Adds filter_strategy column to regions and simulations tables.
8+
Values are 'row_filter' or 'weight_replacement', indicating which
9+
scoping strategy to use when running simulations for that region.
10+
11+
Data migration:
12+
- Existing regions with filter_field != 'household_weight' -> 'row_filter'
13+
- Existing regions with filter_field = 'household_weight' -> 'weight_replacement'
14+
- Simulations inherit from their region's strategy
15+
"""
16+
17+
from typing import Sequence, Union
18+
19+
import sqlalchemy as sa
20+
import sqlmodel.sql.sqltypes
21+
22+
from alembic import op
23+
24+
# revision identifiers, used by Alembic.
25+
revision: str = "add_filter_strategy"
26+
down_revision: Union[str, Sequence[str], None] = "886921687770"
27+
branch_labels: Union[str, Sequence[str], None] = None
28+
depends_on: Union[str, Sequence[str], None] = None
29+
30+
31+
def upgrade() -> None:
32+
"""Add filter_strategy column and backfill existing data."""
33+
# Add column to regions
34+
op.add_column(
35+
"regions",
36+
sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
37+
)
38+
39+
# Add column to simulations
40+
op.add_column(
41+
"simulations",
42+
sa.Column("filter_strategy", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
43+
)
44+
45+
# Backfill regions: set strategy based on existing filter_field
46+
conn = op.get_bind()
47+
48+
# Regions with filter_field = 'household_weight' use weight replacement
49+
conn.execute(
50+
sa.text(
51+
"UPDATE regions SET filter_strategy = 'weight_replacement' "
52+
"WHERE filter_field = 'household_weight'"
53+
)
54+
)
55+
56+
# Regions with other non-null filter_field use row filtering
57+
conn.execute(
58+
sa.text(
59+
"UPDATE regions SET filter_strategy = 'row_filter' "
60+
"WHERE filter_field IS NOT NULL AND filter_field != 'household_weight'"
61+
)
62+
)
63+
64+
# Backfill simulations based on their region's strategy
65+
conn.execute(
66+
sa.text(
67+
"UPDATE simulations SET filter_strategy = regions.filter_strategy "
68+
"FROM regions "
69+
"WHERE simulations.region_id = regions.id "
70+
"AND regions.filter_strategy IS NOT NULL"
71+
)
72+
)
73+
74+
75+
def downgrade() -> None:
76+
"""Remove filter_strategy columns."""
77+
op.drop_column("simulations", "filter_strategy")
78+
op.drop_column("regions", "filter_strategy")

scripts/seed_regions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ def seed_us_regions(
199199
requires_filter=pe_region.requires_filter,
200200
filter_field=pe_region.filter_field,
201201
filter_value=pe_region.filter_value,
202+
filter_strategy=(
203+
pe_region.scoping_strategy.strategy_type
204+
if pe_region.scoping_strategy
205+
else None
206+
),
202207
parent_code=pe_region.parent_code,
203208
state_code=pe_region.state_code,
204209
state_name=pe_region.state_name,
@@ -292,6 +297,11 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]:
292297
requires_filter=pe_region.requires_filter,
293298
filter_field=pe_region.filter_field,
294299
filter_value=pe_region.filter_value,
300+
filter_strategy=(
301+
pe_region.scoping_strategy.strategy_type
302+
if pe_region.scoping_strategy
303+
else None
304+
),
295305
parent_code=pe_region.parent_code,
296306
state_code=None,
297307
state_name=None,

src/policyengine_api/api/analysis.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,15 @@ def _get_deterministic_simulation_id(
251251
household_id: UUID | None = None,
252252
filter_field: str | None = None,
253253
filter_value: str | None = None,
254+
filter_strategy: str | None = None,
254255
) -> UUID:
255256
"""Generate a deterministic UUID from simulation parameters."""
256257
if simulation_type == SimulationType.ECONOMY:
257258
key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}:{filter_field}:{filter_value}"
259+
# Only append filter_strategy when non-null to preserve backward
260+
# compatibility with existing simulation IDs
261+
if filter_strategy is not None:
262+
key += f":{filter_strategy}"
258263
else:
259264
key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}"
260265
return uuid5(SIMULATION_NAMESPACE, key)
@@ -279,6 +284,7 @@ def _get_or_create_simulation(
279284
household_id: UUID | None = None,
280285
filter_field: str | None = None,
281286
filter_value: str | None = None,
287+
filter_strategy: str | None = None,
282288
region_id: UUID | None = None,
283289
year: int | None = None,
284290
) -> Simulation:
@@ -292,6 +298,7 @@ def _get_or_create_simulation(
292298
household_id=household_id,
293299
filter_field=filter_field,
294300
filter_value=filter_value,
301+
filter_strategy=filter_strategy,
295302
)
296303

297304
existing = session.get(Simulation, sim_id)
@@ -309,6 +316,7 @@ def _get_or_create_simulation(
309316
status=SimulationStatus.PENDING,
310317
filter_field=filter_field,
311318
filter_value=filter_value,
319+
filter_strategy=filter_strategy,
312320
region_id=region_id,
313321
year=year,
314322
)
@@ -846,12 +854,32 @@ def build_dynamic(dynamic_id):
846854
year=dataset.year,
847855
)
848856

849-
# Run simulations (with optional regional filtering)
857+
# Reconstruct scoping strategy from DB columns (if applicable)
858+
from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy
859+
860+
baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None
861+
baseline_strategy = reconstruct_strategy(
862+
filter_strategy=baseline_sim.filter_strategy,
863+
filter_field=baseline_sim.filter_field,
864+
filter_value=baseline_sim.filter_value,
865+
region_type=baseline_region.region_type.value if baseline_region else None,
866+
)
867+
868+
reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None
869+
reform_strategy = reconstruct_strategy(
870+
filter_strategy=reform_sim.filter_strategy,
871+
filter_field=reform_sim.filter_field,
872+
filter_value=reform_sim.filter_value,
873+
region_type=reform_region.region_type.value if reform_region else None,
874+
)
875+
876+
# Run simulations (with optional regional scoping)
850877
pe_baseline_sim = PESimulation(
851878
dataset=pe_dataset,
852879
tax_benefit_model_version=pe_model_version,
853880
policy=baseline_policy,
854881
dynamic=baseline_dynamic,
882+
scoping_strategy=baseline_strategy,
855883
filter_field=baseline_sim.filter_field,
856884
filter_value=baseline_sim.filter_value,
857885
)
@@ -862,6 +890,7 @@ def build_dynamic(dynamic_id):
862890
tax_benefit_model_version=pe_model_version,
863891
policy=reform_policy,
864892
dynamic=reform_dynamic,
893+
scoping_strategy=reform_strategy,
865894
filter_field=reform_sim.filter_field,
866895
filter_value=reform_sim.filter_value,
867896
)
@@ -1006,12 +1035,32 @@ def build_dynamic(dynamic_id):
10061035
year=dataset.year,
10071036
)
10081037

1009-
# Run simulations (with optional regional filtering)
1038+
# Reconstruct scoping strategy from DB columns (if applicable)
1039+
from policyengine_api.utils.strategy_reconstruction import reconstruct_strategy
1040+
1041+
baseline_region = session.get(Region, baseline_sim.region_id) if baseline_sim.region_id else None
1042+
baseline_strategy = reconstruct_strategy(
1043+
filter_strategy=baseline_sim.filter_strategy,
1044+
filter_field=baseline_sim.filter_field,
1045+
filter_value=baseline_sim.filter_value,
1046+
region_type=baseline_region.region_type.value if baseline_region else None,
1047+
)
1048+
1049+
reform_region = session.get(Region, reform_sim.region_id) if reform_sim.region_id else None
1050+
reform_strategy = reconstruct_strategy(
1051+
filter_strategy=reform_sim.filter_strategy,
1052+
filter_field=reform_sim.filter_field,
1053+
filter_value=reform_sim.filter_value,
1054+
region_type=reform_region.region_type.value if reform_region else None,
1055+
)
1056+
1057+
# Run simulations (with optional regional scoping)
10101058
pe_baseline_sim = PESimulation(
10111059
dataset=pe_dataset,
10121060
tax_benefit_model_version=pe_model_version,
10131061
policy=baseline_policy,
10141062
dynamic=baseline_dynamic,
1063+
scoping_strategy=baseline_strategy,
10151064
filter_field=baseline_sim.filter_field,
10161065
filter_value=baseline_sim.filter_value,
10171066
)
@@ -1022,6 +1071,7 @@ def build_dynamic(dynamic_id):
10221071
tax_benefit_model_version=pe_model_version,
10231072
policy=reform_policy,
10241073
dynamic=reform_dynamic,
1074+
scoping_strategy=reform_strategy,
10251075
filter_field=reform_sim.filter_field,
10261076
filter_value=reform_sim.filter_value,
10271077
)
@@ -1199,6 +1249,7 @@ def economic_impact(
11991249
# Extract filter parameters from region (if present)
12001250
filter_field = region.filter_field if region and region.requires_filter else None
12011251
filter_value = region.filter_value if region and region.requires_filter else None
1252+
filter_strategy = region.filter_strategy if region and region.requires_filter else None
12021253

12031254
# Get model version
12041255
model_version = _get_model_version(request.tax_benefit_model_name, session)
@@ -1213,6 +1264,7 @@ def economic_impact(
12131264
dataset_id=dataset.id,
12141265
filter_field=filter_field,
12151266
filter_value=filter_value,
1267+
filter_strategy=filter_strategy,
12161268
region_id=region.id if region else None,
12171269
year=dataset.year,
12181270
)
@@ -1226,6 +1278,7 @@ def economic_impact(
12261278
dataset_id=dataset.id,
12271279
filter_field=filter_field,
12281280
filter_value=filter_value,
1281+
filter_strategy=filter_strategy,
12291282
region_id=region.id if region else None,
12301283
year=dataset.year,
12311284
)
@@ -1392,6 +1445,9 @@ def economy_custom(
13921445
filter_value = (
13931446
region_obj.filter_value if region_obj and region_obj.requires_filter else None
13941447
)
1448+
filter_strategy = (
1449+
region_obj.filter_strategy if region_obj and region_obj.requires_filter else None
1450+
)
13951451

13961452
model_version = _get_model_version(request.tax_benefit_model_name, session)
13971453

@@ -1404,6 +1460,7 @@ def economy_custom(
14041460
dataset_id=dataset.id,
14051461
filter_field=filter_field,
14061462
filter_value=filter_value,
1463+
filter_strategy=filter_strategy,
14071464
region_id=region_obj.id if region_obj else None,
14081465
year=dataset.year,
14091466
)
@@ -1417,6 +1474,7 @@ def economy_custom(
14171474
dataset_id=dataset.id,
14181475
filter_field=filter_field,
14191476
filter_value=filter_value,
1477+
filter_strategy=filter_strategy,
14201478
region_id=region_obj.id if region_obj else None,
14211479
year=dataset.year,
14221480
)

src/policyengine_api/models/region.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class RegionBase(SQLModel):
3737
requires_filter: bool = False
3838
filter_field: str | None = None # e.g., "state_code", "place_fips"
3939
filter_value: str | None = None # e.g., "CA", "44000"
40+
filter_strategy: str | None = None # "row_filter" or "weight_replacement"
4041
parent_code: str | None = None # e.g., "us", "state/ca"
4142
state_code: str | None = None # For US regions
4243
state_name: str | None = None # For US regions

src/policyengine_api/models/simulation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class SimulationBase(SQLModel):
6060
default=None,
6161
description="Value to match when filtering (e.g., '44000', 'ENGLAND')",
6262
)
63+
filter_strategy: str | None = Field(
64+
default=None,
65+
description="Scoping strategy: 'row_filter' or 'weight_replacement'",
66+
)
6367

6468
year: int | None = None
6569

@@ -118,6 +122,7 @@ class SimulationCreate(SQLModel):
118122
region_id: UUID | None = None
119123
filter_field: str | None = None
120124
filter_value: str | None = None
125+
filter_strategy: str | None = None
121126
year: int | None = None
122127

123128
@model_validator(mode="after")

src/policyengine_api/utils/__init__.py

Whitespace-only changes.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Reconstruct policyengine.py scoping strategy objects from DB columns.
2+
3+
Rather than storing serialized strategy objects in the database, we store
4+
a simple filter_strategy string ('row_filter' or 'weight_replacement')
5+
and reconstruct the full strategy object at runtime using the existing
6+
filter_field, filter_value, and region_type columns plus a constant
7+
config mapping for weight matrix locations.
8+
"""
9+
10+
# GCS locations for weight matrices, keyed by region type
11+
WEIGHT_MATRIX_CONFIG: dict[str, dict[str, str]] = {
12+
"constituency": {
13+
"weight_matrix_bucket": "policyengine-uk-data-private",
14+
"weight_matrix_key": "parliamentary_constituency_weights.h5",
15+
"lookup_csv_bucket": "policyengine-uk-data-private",
16+
"lookup_csv_key": "constituencies_2024.csv",
17+
},
18+
"local_authority": {
19+
"weight_matrix_bucket": "policyengine-uk-data-private",
20+
"weight_matrix_key": "local_authority_weights.h5",
21+
"lookup_csv_bucket": "policyengine-uk-data-private",
22+
"lookup_csv_key": "local_authorities_2021.csv",
23+
},
24+
}
25+
26+
27+
def reconstruct_strategy(
28+
filter_strategy: str | None,
29+
filter_field: str | None,
30+
filter_value: str | None,
31+
region_type: str | None,
32+
) -> object | None:
33+
"""Reconstruct a ScopingStrategy from DB columns.
34+
35+
Imports from policyengine.core.scoping_strategy are deferred to avoid
36+
import errors when the published policyengine package does not yet
37+
include the scoping_strategy module.
38+
39+
Args:
40+
filter_strategy: Strategy type ('row_filter' or 'weight_replacement').
41+
filter_field: The household variable name (for row_filter).
42+
filter_value: The value to match or region code.
43+
region_type: The region type (e.g., 'constituency', 'local_authority').
44+
45+
Returns:
46+
A ScopingStrategy instance, or None if no strategy is needed.
47+
"""
48+
if filter_strategy is None:
49+
return None
50+
51+
from policyengine.core.scoping_strategy import (
52+
RowFilterStrategy,
53+
WeightReplacementStrategy,
54+
)
55+
56+
if filter_strategy == "row_filter":
57+
if not filter_field or not filter_value:
58+
return None
59+
return RowFilterStrategy(
60+
variable_name=filter_field,
61+
variable_value=filter_value,
62+
)
63+
64+
if filter_strategy == "weight_replacement":
65+
if not filter_value or not region_type:
66+
return None
67+
config = WEIGHT_MATRIX_CONFIG.get(region_type)
68+
if not config:
69+
raise ValueError(
70+
f"No weight matrix config for region type '{region_type}'. "
71+
f"Known types: {list(WEIGHT_MATRIX_CONFIG.keys())}"
72+
)
73+
return WeightReplacementStrategy(
74+
region_code=filter_value,
75+
**config,
76+
)
77+
78+
raise ValueError(
79+
f"Unknown filter_strategy '{filter_strategy}'. "
80+
f"Expected 'row_filter' or 'weight_replacement'."
81+
)

0 commit comments

Comments
 (0)