Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/policyengine/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ class Variable(BaseModel):
possible_values: list[Any] | None = None
default_value: Any = None
value_type: type | None = None
adds: list[str] | None = None
subtracts: list[str] | None = None
33 changes: 33 additions & 0 deletions src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ def __init__(self, **kwargs: dict):
var_obj.possible_values._value2member_map_.values(),
)
)
# Extract and resolve adds/subtracts.
# Core stores these as either list[str] or a parameter path string.
# Resolve parameter paths to lists so consumers always get list[str].
if hasattr(var_obj, "adds") and var_obj.adds is not None:
if isinstance(var_obj.adds, str):
try:
from policyengine_core.parameters.operations.get_parameter import (
get_parameter,
)

param = get_parameter(
system.parameters, var_obj.adds
)
variable.adds = list(param("2025-01-01"))
except (ValueError, Exception):
variable.adds = None
else:
variable.adds = var_obj.adds
if hasattr(var_obj, "subtracts") and var_obj.subtracts is not None:
if isinstance(var_obj.subtracts, str):
try:
from policyengine_core.parameters.operations.get_parameter import (
get_parameter,
)

param = get_parameter(
system.parameters, var_obj.subtracts
)
variable.subtracts = list(param("2025-01-01"))
except (ValueError, Exception):
variable.subtracts = None
else:
variable.subtracts = var_obj.subtracts
self.add_variable(variable)

from policyengine_core.parameters import Parameter as CoreParameter
Expand Down
33 changes: 33 additions & 0 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,39 @@ def __init__(self, **kwargs: dict):
var_obj.possible_values._value2member_map_.values(),
)
)
# Extract and resolve adds/subtracts.
# Core stores these as either list[str] or a parameter path string.
# Resolve parameter paths to lists so consumers always get list[str].
if hasattr(var_obj, "adds") and var_obj.adds is not None:
if isinstance(var_obj.adds, str):
try:
from policyengine_core.parameters.operations.get_parameter import (
get_parameter,
)

param = get_parameter(
system.parameters, var_obj.adds
)
variable.adds = list(param("2025-01-01"))
except (ValueError, Exception):
variable.adds = None
else:
variable.adds = var_obj.adds
if hasattr(var_obj, "subtracts") and var_obj.subtracts is not None:
if isinstance(var_obj.subtracts, str):
try:
from policyengine_core.parameters.operations.get_parameter import (
get_parameter,
)

param = get_parameter(
system.parameters, var_obj.subtracts
)
variable.subtracts = list(param("2025-01-01"))
except (ValueError, Exception):
variable.subtracts = None
else:
variable.subtracts = var_obj.subtracts
self.add_variable(variable)

from policyengine_core.parameters import Parameter as CoreParameter
Expand Down
109 changes: 109 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,115 @@ def test__given_bracket_label__then_follows_expected_format(self):
break


class TestVariableAddsSubtracts:
"""Tests for Variable adds/subtracts extraction and parameter path resolution."""

def test_us_variable_with_list_adds_has_list(self):
"""US variables with list-type adds should have list[str] on the Variable."""
# employment_income uses adds as a list of variable names
var = next(
(v for v in us_latest.variables if v.name == "employment_income"),
None,
)
assert var is not None, "employment_income not found in US model"
assert var.adds is not None, "employment_income should have adds"
assert isinstance(var.adds, list), "adds should be a list"
assert len(var.adds) > 0, "adds should not be empty"
assert all(
isinstance(name, str) for name in var.adds
), "all adds entries should be strings"

def test_us_variable_with_parameter_path_adds_resolves_to_list(self):
"""US variables whose core adds is a parameter path should resolve to list[str]."""
# household_state_benefits uses adds as a parameter path string
# "gov.household.household_state_benefits"
var = next(
(
v
for v in us_latest.variables
if v.name == "household_state_benefits"
),
None,
)
assert var is not None, (
"household_state_benefits not found in US model"
)
assert var.adds is not None, (
"household_state_benefits should have adds (resolved from param path)"
)
assert isinstance(var.adds, list), (
"adds should be resolved to a list, not a string"
)
assert len(var.adds) > 0, "resolved adds should not be empty"

def test_us_variable_without_adds_has_none(self):
"""US variables without adds should have adds=None."""
age_var = next(
(v for v in us_latest.variables if v.name == "age"), None
)
assert age_var is not None, "age variable not found in US model"
assert age_var.adds is None, "age should not have adds"

def test_us_variable_without_subtracts_has_none(self):
"""US variables without subtracts should have subtracts=None."""
age_var = next(
(v for v in us_latest.variables if v.name == "age"), None
)
assert age_var is not None, "age variable not found in US model"
assert age_var.subtracts is None, "age should not have subtracts"

def test_us_some_variables_have_adds(self):
"""US model should have many variables with adds populated."""
vars_with_adds = [v for v in us_latest.variables if v.adds is not None]
assert len(vars_with_adds) >= 50, (
f"Expected at least 50 variables with adds, got {len(vars_with_adds)}"
)

def test_uk_variable_with_adds_has_list(self):
"""UK variables with adds should have list[str] on the Variable."""
# total_income is a common UK aggregation variable
total_income_var = next(
(v for v in uk_latest.variables if v.name == "total_income"), None
)
assert total_income_var is not None, "total_income not found in UK model"
assert total_income_var.adds is not None, "total_income should have adds"
assert isinstance(total_income_var.adds, list), "adds should be a list"
assert len(total_income_var.adds) > 0, "adds should not be empty"

def test_uk_variable_without_adds_has_none(self):
"""UK variables without adds should have adds=None."""
age_var = next(
(v for v in uk_latest.variables if v.name == "age"), None
)
assert age_var is not None, "age variable not found in UK model"
assert age_var.adds is None, "age should not have adds"

def test_us_variable_with_subtracts_has_list(self):
"""US variables with subtracts should have list[str] on the Variable."""
var = next(
(v for v in us_latest.variables if v.name == "household_net_income"),
None,
)
assert var is not None, "household_net_income not found in US model"
assert var.subtracts is not None, "household_net_income should have subtracts"
assert isinstance(var.subtracts, list), "subtracts should be a list"
assert len(var.subtracts) > 0, "subtracts should not be empty"

def test_adds_entries_are_valid_variable_names(self):
"""adds entries should reference real variable names in the model."""
all_var_names = {v.name for v in us_latest.variables}
var = next(
(v for v in us_latest.variables if v.name == "employment_income"),
None,
)
assert var is not None
assert var.adds is not None
for component_name in var.adds:
assert component_name in all_var_names, (
f"adds entry '{component_name}' is not a valid variable in the US model"
)


class TestVariableDefaultValue:
"""Tests for Variable default_value and value_type fields."""

Expand Down