From da33af28afe73820c8734ee8b7a1d560f1b843df Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Fri, 6 Mar 2026 20:05:52 +0100
Subject: [PATCH] feat: Add adds/subtracts to Variable model with parameter path resolution

Variables using declarative aggregation (adds/subtracts) now have these
fields populated on the Variable model. Parameter path strings are
resolved to list[str] at init time so downstream consumers always get
arrays.

Closes #247

Co-Authored-By: Claude Opus 4.6
---
Reviewer notes (this section is not part of the commit message):
- The line structure was rebuilt from a whitespace-collapsed copy of this
  patch. The indentation of unchanged context lines is inferred and must
  be checked against the target files before applying.
- uk/us model hunks: the duplicated adds/subtracts branches are folded
  into one loop, the redundant `except (ValueError, Exception)` is reduced
  to `except Exception`, and list values are copied instead of aliasing
  the list held by the core variable.
- Hunk line counts and the diffstat are updated for the shorter hunks; the
  `index` blob ids still describe the original post-image.
- NOTE(review): parameter paths are evaluated at the fixed instant
  "2025-01-01" -- confirm this is the intended reference date.

 src/policyengine/core/variable.py             |   2 +
 .../tax_benefit_models/uk/model.py            |  26 +++++
 .../tax_benefit_models/us/model.py            |  26 +++++
 tests/test_models.py                          | 109 ++++++++++++++++++
 4 files changed, 163 insertions(+)

diff --git a/src/policyengine/core/variable.py b/src/policyengine/core/variable.py
index fa8e50e2..69830122 100644
--- a/src/policyengine/core/variable.py
+++ b/src/policyengine/core/variable.py
@@ -15,3 +15,5 @@ class Variable(BaseModel):
     possible_values: list[Any] | None = None
     default_value: Any = None
     value_type: type | None = None
+    adds: list[str] | None = None
+    subtracts: list[str] | None = None
diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py
index 04860e58..47198bd3 100644
--- a/src/policyengine/tax_benefit_models/uk/model.py
+++ b/src/policyengine/tax_benefit_models/uk/model.py
@@ -192,6 +192,32 @@ def __init__(self, **kwargs: dict):
                         var_obj.possible_values._value2member_map_.values(),
                     )
                 )
+            # Extract and resolve adds/subtracts.
+            # Core stores these as either list[str] or a parameter path string.
+            # Resolve parameter paths to lists so consumers always get list[str].
+            for aggregation_attr in ("adds", "subtracts"):
+                aggregation_spec = getattr(var_obj, aggregation_attr, None)
+                if aggregation_spec is None:
+                    continue
+                if isinstance(aggregation_spec, str):
+                    try:
+                        from policyengine_core.parameters.operations.get_parameter import (
+                            get_parameter,
+                        )
+
+                        param = get_parameter(
+                            system.parameters, aggregation_spec
+                        )
+                        resolved = list(param("2025-01-01"))
+                    except Exception:
+                        # Best effort: an unresolvable parameter path
+                        # leaves the field at None.
+                        resolved = None
+                else:
+                    # Copy so the model never aliases the list stored
+                    # on the core variable.
+                    resolved = list(aggregation_spec)
+                setattr(variable, aggregation_attr, resolved)
             self.add_variable(variable)
 
         from policyengine_core.parameters import Parameter as CoreParameter
diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py
index b80a4d3e..64145673 100644
--- a/src/policyengine/tax_benefit_models/us/model.py
+++ b/src/policyengine/tax_benefit_models/us/model.py
@@ -165,6 +165,32 @@ def __init__(self, **kwargs: dict):
                         var_obj.possible_values._value2member_map_.values(),
                     )
                 )
+            # Extract and resolve adds/subtracts.
+            # Core stores these as either list[str] or a parameter path string.
+            # Resolve parameter paths to lists so consumers always get list[str].
+            for aggregation_attr in ("adds", "subtracts"):
+                aggregation_spec = getattr(var_obj, aggregation_attr, None)
+                if aggregation_spec is None:
+                    continue
+                if isinstance(aggregation_spec, str):
+                    try:
+                        from policyengine_core.parameters.operations.get_parameter import (
+                            get_parameter,
+                        )
+
+                        param = get_parameter(
+                            system.parameters, aggregation_spec
+                        )
+                        resolved = list(param("2025-01-01"))
+                    except Exception:
+                        # Best effort: an unresolvable parameter path
+                        # leaves the field at None.
+                        resolved = None
+                else:
+                    # Copy so the model never aliases the list stored
+                    # on the core variable.
+                    resolved = list(aggregation_spec)
+                setattr(variable, aggregation_attr, resolved)
             self.add_variable(variable)
 
         from policyengine_core.parameters import Parameter as CoreParameter
diff --git a/tests/test_models.py b/tests/test_models.py
index e5b4484e..0ab5461e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -178,6 +178,115 @@ def test__given_bracket_label__then_follows_expected_format(self):
                 break
 
 
+class TestVariableAddsSubtracts:
+    """Tests for Variable adds/subtracts extraction and parameter path resolution."""
+
+    def test_us_variable_with_list_adds_has_list(self):
+        """US variables with list-type adds should have list[str] on the Variable."""
+        # employment_income uses adds as a list of variable names
+        var = next(
+            (v for v in us_latest.variables if v.name == "employment_income"),
+            None,
+        )
+        assert var is not None, "employment_income not found in US model"
+        assert var.adds is not None, "employment_income should have adds"
+        assert isinstance(var.adds, list), "adds should be a list"
+        assert len(var.adds) > 0, "adds should not be empty"
+        assert all(
+            isinstance(name, str) for name in var.adds
+        ), "all adds entries should be strings"
+
+    def test_us_variable_with_parameter_path_adds_resolves_to_list(self):
+        """US variables whose core adds is a parameter path should resolve to list[str]."""
+        # household_state_benefits uses adds as a parameter path string
+        # "gov.household.household_state_benefits"
+        var = next(
+            (
+                v
+                for v in us_latest.variables
+                if v.name == "household_state_benefits"
+            ),
+            None,
+        )
+        assert var is not None, (
+            "household_state_benefits not found in US model"
+        )
+        assert var.adds is not None, (
+            "household_state_benefits should have adds (resolved from param path)"
+        )
+        assert isinstance(var.adds, list), (
+            "adds should be resolved to a list, not a string"
+        )
+        assert len(var.adds) > 0, "resolved adds should not be empty"
+
+    def test_us_variable_without_adds_has_none(self):
+        """US variables without adds should have adds=None."""
+        age_var = next(
+            (v for v in us_latest.variables if v.name == "age"), None
+        )
+        assert age_var is not None, "age variable not found in US model"
+        assert age_var.adds is None, "age should not have adds"
+
+    def test_us_variable_without_subtracts_has_none(self):
+        """US variables without subtracts should have subtracts=None."""
+        age_var = next(
+            (v for v in us_latest.variables if v.name == "age"), None
+        )
+        assert age_var is not None, "age variable not found in US model"
+        assert age_var.subtracts is None, "age should not have subtracts"
+
+    def test_us_some_variables_have_adds(self):
+        """US model should have many variables with adds populated."""
+        vars_with_adds = [v for v in us_latest.variables if v.adds is not None]
+        assert len(vars_with_adds) >= 50, (
+            f"Expected at least 50 variables with adds, got {len(vars_with_adds)}"
+        )
+
+    def test_uk_variable_with_adds_has_list(self):
+        """UK variables with adds should have list[str] on the Variable."""
+        # total_income is a common UK aggregation variable
+        total_income_var = next(
+            (v for v in uk_latest.variables if v.name == "total_income"), None
+        )
+        assert total_income_var is not None, "total_income not found in UK model"
+        assert total_income_var.adds is not None, "total_income should have adds"
+        assert isinstance(total_income_var.adds, list), "adds should be a list"
+        assert len(total_income_var.adds) > 0, "adds should not be empty"
+
+    def test_uk_variable_without_adds_has_none(self):
+        """UK variables without adds should have adds=None."""
+        age_var = next(
+            (v for v in uk_latest.variables if v.name == "age"), None
+        )
+        assert age_var is not None, "age variable not found in UK model"
+        assert age_var.adds is None, "age should not have adds"
+
+    def test_us_variable_with_subtracts_has_list(self):
+        """US variables with subtracts should have list[str] on the Variable."""
+        var = next(
+            (v for v in us_latest.variables if v.name == "household_net_income"),
+            None,
+        )
+        assert var is not None, "household_net_income not found in US model"
+        assert var.subtracts is not None, "household_net_income should have subtracts"
+        assert isinstance(var.subtracts, list), "subtracts should be a list"
+        assert len(var.subtracts) > 0, "subtracts should not be empty"
+
+    def test_adds_entries_are_valid_variable_names(self):
+        """adds entries should reference real variable names in the model."""
+        all_var_names = {v.name for v in us_latest.variables}
+        var = next(
+            (v for v in us_latest.variables if v.name == "employment_income"),
+            None,
+        )
+        assert var is not None
+        assert var.adds is not None
+        for component_name in var.adds:
+            assert component_name in all_var_names, (
+                f"adds entry '{component_name}' is not a valid variable in the US model"
+            )
+
+
 class TestVariableDefaultValue:
     """Tests for Variable default_value and value_type fields."""
 