diff --git a/changelog.d/484.fixed.md b/changelog.d/484.fixed.md new file mode 100644 index 00000000..d8dc57f9 --- /dev/null +++ b/changelog.d/484.fixed.md @@ -0,0 +1 @@ +Preserve values passed through dispatched `set_input` handlers across cache invalidation. diff --git a/changelog.d/490.fixed.md b/changelog.d/490.fixed.md new file mode 100644 index 00000000..a2dbf761 --- /dev/null +++ b/changelog.d/490.fixed.md @@ -0,0 +1 @@ +Apply chained parameter uprating dependencies in deterministic dependency order. diff --git a/policyengine_core/holders/helpers.py b/policyengine_core/holders/helpers.py index 21a9ebdb..30e324c6 100644 --- a/policyengine_core/holders/helpers.py +++ b/policyengine_core/holders/helpers.py @@ -10,6 +10,18 @@ log = logging.getLogger(__name__) +def get_input_branch(holder: Holder) -> str: + simulation = getattr(holder, "simulation", None) + user_input_contexts = getattr(simulation, "_user_input_contexts", None) + if user_input_contexts: + return user_input_contexts[-1] + return "default" + + +def get_stored_array(holder: Holder, period: Period, branch_name: str) -> ArrayLike: + return holder._get_array_from_storage(period, branch_name) + + def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLike): """ This function can be declared as a ``set_input`` attribute of a variable. @@ -35,11 +47,12 @@ def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLik after_instant = period.start.offset(period_size, period_unit) # Cache the input data, skipping the existing cached months + branch_name = get_input_branch(holder) sub_period = period.start.period(cached_period_unit) while sub_period.start < after_instant: - existing_array = holder.get_array(sub_period) + existing_array = get_stored_array(holder, sub_period, branch_name) if existing_array is None: - holder._set(sub_period, array) + holder._set(sub_period, array, branch_name) else: # The array of the current sub-period is reused for the next ones. # TODO: refactor or document this behavior @@ -72,11 +85,12 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike) after_instant = period.start.offset(period_size, period_unit) # Count the number of elementary periods to change, and the difference with what is already known. + branch_name = get_input_branch(holder) remaining_array = array.copy() sub_period = period.start.period(cached_period_unit) sub_periods_count = 0 while sub_period.start < after_instant: - existing_array = holder.get_array(sub_period) + existing_array = get_stored_array(holder, sub_period, branch_name) if existing_array is not None: remaining_array -= existing_array else: @@ -88,8 +102,8 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike) divided_array = remaining_array / sub_periods_count sub_period = period.start.period(cached_period_unit) while sub_period.start < after_instant: - if holder.get_array(sub_period) is None: - holder._set(sub_period, divided_array) + if get_stored_array(holder, sub_period, branch_name) is None: + holder._set(sub_period, divided_array, branch_name) sub_period = sub_period.offset(1) elif not (remaining_array == 0).all(): raise ValueError( diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py index 6ee42d1b..fa2aac5e 100644 --- a/policyengine_core/holders/holder.py +++ b/policyengine_core/holders/holder.py @@ -94,6 +94,14 @@ def delete_arrays( if self._disk_storage: self._disk_storage.delete(period, branch_name) + def _get_array_from_storage( + self, period: Period, branch_name: str = "default" + ) -> ArrayLike: + value = self._memory_storage.get(period, branch_name) + if value is None and self._disk_storage: + value = self._disk_storage.get(period, branch_name) + return value + def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike: """ Get the value of the variable for the given period. @@ -102,7 +110,9 @@ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike: """ if self.variable.is_neutralized: return self.default_array() - value = self._memory_storage.get(period, branch_name) + value = self._get_array_from_storage(period, branch_name) + if value is not None: + return value if value is None and branch_name != "default": # Walk up ``simulation.parent_branch`` so nested branches inherit # values from their parent (e.g. a ``no_salt`` branch cloned @@ -121,17 +131,16 @@ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike: else None ) while parent is not None: - ancestor_value = self._memory_storage.get(period, parent.branch_name) + ancestor_value = self._get_array_from_storage( + period, + parent.branch_name, + ) if ancestor_value is not None: return ancestor_value parent = getattr(parent, "parent_branch", None) - default_value = self._memory_storage.get(period, "default") + default_value = self._get_array_from_storage(period, "default") if default_value is not None: return default_value - if value is not None: - return value - if self._disk_storage: - return self._disk_storage.get(period, branch_name) def get_memory_usage(self) -> dict: """ @@ -241,21 +250,23 @@ def set_input( return warnings.warn(warning_message, Warning) if self.variable.value_type in (float, int) and isinstance(array, str): array = tools.eval_expression(array) - # Track user-provided inputs on the simulation so - # ``Simulation._invalidate_all_caches`` can preserve them across - # ``apply_reform``. ``Simulation.set_input`` also records this, but - # ``SimulationBuilder.finalize_variables_init`` (the situation-dict - # path) and country-package dataset loaders call - # ``holder.set_input`` directly, bypassing the simulation-level hook. - # Recording here covers both paths. simulation = getattr(self, "simulation", None) if simulation is not None: if not hasattr(simulation, "_user_input_keys"): simulation._user_input_keys = set() - simulation._user_input_keys.add((self.variable.name, branch_name, period)) - if self.variable.set_input and period.unit != self.variable.definition_period: - return self.variable.set_input(self, period, array) - return self._set(period, array, branch_name) + if not hasattr(simulation, "_user_input_contexts"): + simulation._user_input_contexts = [] + simulation._user_input_contexts.append(branch_name) + try: + if ( + self.variable.set_input + and period.unit != self.variable.definition_period + ): + return self.variable.set_input(self, period, array) + return self._set(period, array, branch_name) + finally: + if simulation is not None: + simulation._user_input_contexts.pop() def _to_array(self, value: Any) -> ArrayLike: if not isinstance(value, numpy.ndarray): @@ -295,6 +306,10 @@ def _to_array(self, value: Any) -> ArrayLike: def _set( self, period: Period, value: ArrayLike, branch_name: str = "default" ) -> None: + simulation = getattr(self, "simulation", None) + user_input_contexts = getattr(simulation, "_user_input_contexts", None) + if user_input_contexts and branch_name == "default": + branch_name = user_input_contexts[-1] value = self._to_array(value) if self.variable.definition_period != periods.ETERNITY: if period is None: @@ -313,6 +328,10 @@ def _set( self._disk_storage.put(value, period, branch_name) else: self._memory_storage.put(value, period, branch_name) + if user_input_contexts: + if not hasattr(simulation, "_user_input_keys"): + simulation._user_input_keys = set() + simulation._user_input_keys.add((self.variable.name, branch_name, period)) def put_in_cache( self, value: ArrayLike, period: Period, branch_name: str = "default" diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py index 415b9fbf..e9c10ecc 100644 --- a/policyengine_core/parameters/operations/uprate_parameters.py +++ b/policyengine_core/parameters/operations/uprate_parameters.py @@ -6,6 +6,7 @@ from dateutil.relativedelta import relativedelta from dateutil.parser import parse from datetime import datetime +from typing import Optional, Union from policyengine_core.parameters.operations.get_parameter import get_parameter from policyengine_core.parameters.parameter import Parameter @@ -13,8 +14,7 @@ ParameterAtInstant, ) from policyengine_core.parameters.parameter_node import ParameterNode -from policyengine_core.parameters.parameter_scale import ParameterScale -from policyengine_core.periods import instant, Instant +from policyengine_core.periods import instant def uprate_parameters(root: ParameterNode) -> ParameterNode: @@ -27,135 +27,179 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode: ParameterNode: The same root, with uprating applied to descendants. """ - descendants = list(root.get_descendants()) - - scales = list(filter(lambda p: isinstance(p, ParameterScale), descendants)) - for scale in scales: - for bracket in scale.brackets: - for allowed_key in bracket._allowed_keys: - if hasattr(bracket, allowed_key): - descendants.append(getattr(bracket, allowed_key)) - - for parameter in descendants: - if isinstance(parameter, Parameter): - if parameter.metadata.get("uprating") is not None: - # Pull the uprating definition dict - meta = parameter.metadata["uprating"] - - # If defined in short method (i.e. "uprating: PARAM"), - # redefine this as dict with param key - if meta == "self": - meta = dict(parameter="self") - elif isinstance(meta, str): - meta = dict(parameter=meta) - - # If param is "self", construct the uprating table - if meta["parameter"] == "self": - uprating_parameter = construct_uprater_self( - parameter, - meta, - ) - # Otherwise, pull uprating table from YAML - else: - uprating_parameter = get_parameter(root, meta["parameter"]) - - # If uprating with a set candence, ensure that all - # required values are present - cadence_meta = meta.get("at_defined_interval") - cadence_options = {} - if cadence_meta: - cadence_options_test = [ - cadence_meta.get("start"), - cadence_meta.get("end"), - cadence_meta.get("enactment"), - ] - - # Ensure that all options are properly defined - if not all(cadence_options_test): - raise SyntaxError( - f"Failed to uprate {parameter.name} using cadence; start, end, and enactment must all be provided" - ) - - # Construct cadence options object - cadence_options = construct_cadence_options(cadence_meta, parameter) - - # Ensure that end comes after start and enactment comes after end - if cadence_options["end"] <= cadence_options["start"]: - raise ValueError( - f"Failed to uprate {parameter.name} using {uprating_parameter.name}: end must come after start" - ) - if cadence_options["enactment"] <= cadence_options["end"]: - raise ValueError( - f"Failed to uprate {parameter.name} using {uprating_parameter.name}: enactment must come after end" - ) - - # Determine the first date from which to start uprating - - # this should be the first application date (month, day) - # following the last defined param value (not including the - # final value) - uprating_first_date: datetime = find_cadence_first( - parameter, cadence_options - ) - uprating_last_date: datetime = find_cadence_last( - uprating_parameter, cadence_options - ) - - # Uprate data - uprated_data = uprate_by_cadence( - parameter, - uprating_parameter, - cadence_options, - uprating_first_date, - uprating_last_date, - meta, - ) + parameters = [ + parameter + for parameter in root.get_descendants() + if isinstance(parameter, Parameter) + ] - # Append uprated data to parameter values list - parameter.values_list.extend(uprated_data) - - else: - # Start from the latest value - if "start_instant" in meta: - last_instant = instant(meta["start_instant"]) - else: - last_instant = instant(parameter.values_list[0].instant_str) - - # Pre-compute values that don't change in the loop - last_instant_str = str(last_instant) - value_at_start = parameter(last_instant) - uprater_at_start = uprating_parameter(last_instant) - - if uprater_at_start is None: - raise ValueError( - f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} because the uprating parameter is not defined at {last_instant}." - ) - - # Pre-compute uprater values for all entries to avoid repeated lookups - has_rounding = "rounding" in meta - - # For each defined instant in the uprating parameter - for entry in uprating_parameter.values_list[::-1]: - entry_instant = instant(entry.instant_str) - # If the uprater instant is defined after the last parameter instant - if entry_instant > last_instant: - # Apply the uprater and add to the parameter - uprater_at_entry = uprating_parameter(entry_instant) - uprater_change = uprater_at_entry / uprater_at_start - uprated_value = value_at_start * uprater_change - if has_rounding: - uprated_value = round_uprated_value(meta, uprated_value) - parameter.values_list.append( - ParameterAtInstant( - parameter.name, - entry.instant_str, - data=uprated_value, - ) - ) - # Whether using cadence or not, sort the parameter values_list - parameter.values_list.sort(key=lambda x: x.instant_str, reverse=True) + for parameter in sort_parameters_by_uprating_dependencies(parameters): + uprate_parameter(parameter, root) return root +def normalize_uprating_metadata(meta: Union[dict, str]) -> dict: + if meta == "self": + return dict(parameter="self") + if isinstance(meta, str): + return dict(parameter=meta) + return meta + + +def get_uprating_dependency_name(parameter: Parameter) -> Optional[str]: + meta = parameter.metadata.get("uprating") + if meta is None: + return None + meta = normalize_uprating_metadata(meta) + dependency_name = meta["parameter"] + if dependency_name == "self": + return None + return dependency_name + + +def sort_parameters_by_uprating_dependencies( + parameters: list[Parameter], +) -> list[Parameter]: + parameters_to_uprate = [ + parameter + for parameter in parameters + if parameter.metadata.get("uprating") is not None + ] + parameter_by_name = { + parameter.name: parameter for parameter in parameters_to_uprate + } + ordered_parameters = [] + visited = set() + visiting = [] + + def visit(parameter: Parameter): + if parameter.name in visited: + return + if parameter.name in visiting: + cycle = visiting[visiting.index(parameter.name) :] + [parameter.name] + raise ValueError( + "Cyclic uprating dependency detected: " + " -> ".join(cycle) + ) + visiting.append(parameter.name) + dependency_name = get_uprating_dependency_name(parameter) + if dependency_name in parameter_by_name: + visit(parameter_by_name[dependency_name]) + visiting.pop() + visited.add(parameter.name) + ordered_parameters.append(parameter) + + for parameter in sorted(parameters_to_uprate, key=lambda p: p.name): + visit(parameter) + + return ordered_parameters + + +def uprate_parameter(parameter: Parameter, root: ParameterNode) -> None: + # Pull the uprating definition dict + meta = normalize_uprating_metadata(parameter.metadata["uprating"]) + + # If param is "self", construct the uprating table + if meta["parameter"] == "self": + uprating_parameter = construct_uprater_self( + parameter, + meta, + ) + # Otherwise, pull uprating table from YAML + else: + uprating_parameter = get_parameter(root, meta["parameter"]) + + # If uprating with a set candence, ensure that all + # required values are present + cadence_meta = meta.get("at_defined_interval") + cadence_options = {} + if cadence_meta: + cadence_options_test = [ + cadence_meta.get("start"), + cadence_meta.get("end"), + cadence_meta.get("enactment"), + ] + + # Ensure that all options are properly defined + if not all(cadence_options_test): + raise SyntaxError( + f"Failed to uprate {parameter.name} using cadence; start, end, and enactment must all be provided" + ) + + # Construct cadence options object + cadence_options = construct_cadence_options(cadence_meta, parameter) + + # Ensure that end comes after start and enactment comes after end + if cadence_options["end"] <= cadence_options["start"]: + raise ValueError( + f"Failed to uprate {parameter.name} using {uprating_parameter.name}: end must come after start" + ) + if cadence_options["enactment"] <= cadence_options["end"]: + raise ValueError( + f"Failed to uprate {parameter.name} using {uprating_parameter.name}: enactment must come after end" + ) + + # Determine the first date from which to start uprating - + # this should be the first application date (month, day) + # following the last defined param value (not including the + # final value) + uprating_first_date: datetime = find_cadence_first(parameter, cadence_options) + uprating_last_date: datetime = find_cadence_last( + uprating_parameter, cadence_options + ) + + # Uprate data + uprated_data = uprate_by_cadence( + parameter, + uprating_parameter, + cadence_options, + uprating_first_date, + uprating_last_date, + meta, + ) + + # Append uprated data to parameter values list + parameter.values_list.extend(uprated_data) + + else: + # Start from the latest value + if "start_instant" in meta: + last_instant = instant(meta["start_instant"]) + else: + last_instant = instant(parameter.values_list[0].instant_str) + + # Pre-compute values that don't change in the loop + value_at_start = parameter(last_instant) + uprater_at_start = uprating_parameter(last_instant) + + if uprater_at_start is None: + raise ValueError( + f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} because the uprating parameter is not defined at {last_instant}." + ) + + has_rounding = "rounding" in meta + + # For each defined instant in the uprating parameter + for entry in uprating_parameter.values_list[::-1]: + entry_instant = instant(entry.instant_str) + # If the uprater instant is defined after the last parameter instant + if entry_instant > last_instant: + # Apply the uprater and add to the parameter + uprater_at_entry = uprating_parameter(entry_instant) + uprater_change = uprater_at_entry / uprater_at_start + uprated_value = value_at_start * uprater_change + if has_rounding: + uprated_value = round_uprated_value(meta, uprated_value) + parameter.values_list.append( + ParameterAtInstant( + parameter.name, + entry.instant_str, + data=uprated_value, + ) + ) + # Whether using cadence or not, sort the parameter values_list + parameter.values_list.sort(key=lambda x: x.instant_str, reverse=True) + + def round_uprated_value(meta: dict, uprated_value: float) -> float: rounding_config = meta["rounding"] if isinstance(rounding_config, float): diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py index ce837a2d..0df0e387 100644 --- a/policyengine_core/parameters/parameter_node.py +++ b/policyengine_core/parameters/parameter_node.py @@ -89,7 +89,7 @@ def __init__( if directory_path: self.file_path = directory_path - for child_name in os.listdir(directory_path): + for child_name in sorted(os.listdir(directory_path)): child_path = os.path.join(directory_path, child_name) if os.path.isfile(child_path): child_name, ext = os.path.splitext(child_name) diff --git a/policyengine_core/parameters/parameter_scale_bracket.py b/policyengine_core/parameters/parameter_scale_bracket.py index 703b25a7..8d80e114 100644 --- a/policyengine_core/parameters/parameter_scale_bracket.py +++ b/policyengine_core/parameters/parameter_scale_bracket.py @@ -7,7 +7,7 @@ class ParameterScaleBracket(ParameterNode): A parameter scale bracket. """ - _allowed_keys = set(["amount", "threshold", "rate", "average_rate", "base"]) + _allowed_keys = ("amount", "threshold", "rate", "average_rate", "base") @staticmethod def allowed_unit_keys(): diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index eacf415c..2bdfe045 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1,7 +1,7 @@ import hashlib import tempfile from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import numpy as np import pandas as pd @@ -64,6 +64,10 @@ class PreservedUserInput: branch_name: str period: Period value: object + storage: str + disk_key: Optional[str] = None + disk_file: Optional[str] = None + disk_enum: object = None class Simulation: @@ -297,8 +301,31 @@ def _invalidate_all_caches(self) -> None: branch_name=branch_name, period=period, value=stored_value, + storage="memory", ) ) + continue + if holder._disk_storage is not None: + disk_period = ( + periods.period(periods.ETERNITY) + if holder._disk_storage.is_eternal + else periods.period(period) + ) + disk_key = f"{branch_name}_{disk_period}" + disk_file = holder._disk_storage._files.get(disk_key) + if disk_file is not None: + preserved.append( + PreservedUserInput( + variable_name=variable_name, + branch_name=branch_name, + period=period, + value=None, + storage="disk", + disk_key=disk_key, + disk_file=disk_file, + disk_enum=holder._disk_storage._enums.get(disk_file), + ) + ) # Iterate only over holders that already exist on each population — # lazy-creating a holder for every variable in the tax-benefit # system (thousands in policyengine-us) inflated the cost of @@ -313,11 +340,18 @@ def _invalidate_all_caches(self) -> None: # Replay preserved user inputs so ``calculate`` still sees them. for user_input in preserved: holder = self.get_holder(user_input.variable_name) - holder._memory_storage.put( - user_input.value, - user_input.period, - user_input.branch_name, - ) + if user_input.storage == "disk" and holder._disk_storage is not None: + holder._disk_storage._files[user_input.disk_key] = user_input.disk_file + if user_input.disk_enum is not None: + holder._disk_storage._enums[user_input.disk_file] = ( + user_input.disk_enum + ) + else: + holder._memory_storage.put( + user_input.value, + user_input.period, + user_input.branch_name, + ) for branch in self.branches.values(): branch._invalidate_all_caches() @@ -1302,12 +1336,6 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value, self.branch_name) - # Lazy-init ``_user_input_keys`` so country-package subclasses that - # override ``__init__`` without calling ``super().__init__`` still - # benefit from the set-input preservation across ``apply_reform``. - if not hasattr(self, "_user_input_keys"): - self._user_input_keys = set() - self._user_input_keys.add((variable_name, self.branch_name, period)) _fast_cache = getattr(self, "_fast_cache", None) if _fast_cache is not None: _fast_cache.pop((variable_name, period), None) diff --git a/tests/core/parameters/operations/test_uprating.py b/tests/core/parameters/operations/test_uprating.py index 096c1d8d..af99eac3 100644 --- a/tests/core/parameters/operations/test_uprating.py +++ b/tests/core/parameters/operations/test_uprating.py @@ -1,6 +1,106 @@ import pytest +def test_parameter_uprating_processes_dependencies_before_dependents(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "target": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "middle"}, + }, + "middle": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "base"}, + }, + "base": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + } + ) + + uprated = uprate_parameters(root) + + assert uprated.middle("2026-01-01") == pytest.approx(110) + assert uprated.target("2026-01-01") == pytest.approx(110) + + +def test_scale_bracket_uprating_processes_dependencies_before_dependents(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "scale": { + "metadata": { + "uprating": "middle", + "uprate_thresholds": True, + }, + "brackets": [ + { + "threshold": {"values": {"2025-01-01": 100}}, + "rate": {"values": {"2025-01-01": 0.1}}, + }, + ], + }, + "middle": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "base"}, + }, + "base": { + "values": {"2025-01-01": 100, "2026-01-01": 110}, + }, + } + ) + + uprated = uprate_parameters(root) + + assert uprated.middle("2026-01-01") == pytest.approx(110) + assert uprated.scale.brackets[0].threshold("2026-01-01") == pytest.approx(110) + + +def test_parameter_uprating_rejects_cyclic_dependencies(): + from policyengine_core.parameters import ParameterNode, uprate_parameters + + root = ParameterNode( + data={ + "a": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "b"}, + }, + "b": { + "values": {"2025-01-01": 100}, + "metadata": {"uprating": "a"}, + }, + } + ) + + with pytest.raises(ValueError, match="Cyclic uprating dependency"): + uprate_parameters(root) + + +def test_parameter_node_loads_directory_children_deterministically( + tmp_path, + monkeypatch, +): + from policyengine_core.parameters import ParameterNode + + for child_name in ("zeta.yaml", "alpha.yaml"): + (tmp_path / child_name).write_text( + "values:\n 2025-01-01:\n value: 1\n", + encoding="utf-8", + ) + + monkeypatch.setattr( + "os.listdir", + lambda directory_path: ["zeta.yaml", "alpha.yaml"], + ) + + root = ParameterNode(directory_path=str(tmp_path)) + + assert list(root.children) == ["alpha", "zeta"] + + def test_parameter_uprating(): from policyengine_core.parameters import ParameterNode diff --git a/tests/core/test_apply_reform_preserves_user_inputs.py b/tests/core/test_apply_reform_preserves_user_inputs.py index f11c5c57..f678b7e3 100644 --- a/tests/core/test_apply_reform_preserves_user_inputs.py +++ b/tests/core/test_apply_reform_preserves_user_inputs.py @@ -27,6 +27,7 @@ from policyengine_core.model_api import Reform from policyengine_core.country_template import situation_examples +from policyengine_core.experimental import MemoryConfig from policyengine_core.simulations import Simulation, SimulationBuilder @@ -143,6 +144,116 @@ def apply(self): ) +def test_apply_reform_preserves_dispatched_set_input_values( + tax_benefit_system, +): + """Inputs dispatched into subperiod storage must survive reform apply.""" + sim = SimulationBuilder().build_from_entities( + tax_benefit_system, situation_examples.single + ) + input_period = "2017" + calculated_period = "2017-01" + yearly_salary = np.array([12_000.0]) + expected_monthly_salary = np.array([1_000.0]) + + sim.set_input("salary", input_period, yearly_salary) + assert np.allclose( + sim.calculate("salary", calculated_period), expected_monthly_salary + ) + + class NoOpReform(Reform): + def apply(self): + pass + + sim.apply_reform(NoOpReform) + + result = sim.calculate("salary", calculated_period) + assert np.allclose(result, expected_monthly_salary), ( + "apply_reform lost a salary input that was dispatched from a yearly " + f"input to monthly storage; got {result} instead of " + f"{expected_monthly_salary}." + ) + + +def test_apply_reform_preserves_dispatched_inputs_on_branch( + tax_benefit_system, +): + """Dispatched inputs must keep the branch they were set on.""" + sim = SimulationBuilder().build_from_entities( + tax_benefit_system, situation_examples.single + ) + default_salary = np.array([24_000.0]) + expected_default_monthly_salary = np.array([2_000.0]) + sim.set_input("salary", "2017", default_salary) + + branch = sim.get_branch("reform") + input_period = "2017" + calculated_period = "2017-01" + yearly_salary = np.array([12_000.0]) + expected_monthly_salary = np.array([1_000.0]) + + branch.set_input("salary", input_period, yearly_salary) + salary_holder = branch.get_holder("salary") + + assert np.allclose( + salary_holder._memory_storage.get(calculated_period, "default"), + expected_default_monthly_salary, + ) + assert np.allclose( + salary_holder._memory_storage.get(calculated_period, "reform"), + expected_monthly_salary, + ) + + class NoOpReform(Reform): + def apply(self): + pass + + branch.apply_reform(NoOpReform) + + assert np.allclose( + branch.calculate("salary", calculated_period), + expected_monthly_salary, + ) + assert np.allclose( + sim.calculate("salary", calculated_period), + expected_default_monthly_salary, + ) + + +def test_apply_reform_preserves_on_disk_inputs_on_disk(tax_benefit_system): + """Disk-backed user inputs should not be replayed into memory.""" + sim = SimulationBuilder().build_from_entities( + tax_benefit_system, situation_examples.single + ) + sim.memory_config = MemoryConfig(max_memory_occupation=0) + period = "2017-01" + expected_salary = np.array([5_000.0]) + salary_holder = sim.get_holder("salary") + salary_holder._disk_storage = salary_holder.create_disk_storage() + salary_holder._on_disk_storable = True + + sim.set_input("salary", period, expected_salary) + + assert salary_holder._memory_storage.get(period, "default") is None + assert np.allclose( + salary_holder._disk_storage.get(period, "default"), + expected_salary, + ) + + class NoOpReform(Reform): + def apply(self): + pass + + sim.apply_reform(NoOpReform) + + assert salary_holder._memory_storage.get(period, "default") is None + assert np.allclose( + salary_holder._disk_storage.get(period, "default"), + expected_salary, + ) + assert np.allclose(sim.calculate("salary", period), expected_salary) + + def test_apply_reform_preserves_situation_dict_inputs(tax_benefit_system): """Situation-dict inputs must survive ``apply_reform`` too. diff --git a/tests/core/test_holder_branch_fallback.py b/tests/core/test_holder_branch_fallback.py index df674bbe..a1dcc864 100644 --- a/tests/core/test_holder_branch_fallback.py +++ b/tests/core/test_holder_branch_fallback.py @@ -62,6 +62,21 @@ def test_get_array_falls_back_to_default_branch(tax_benefit_system): assert result[0] == 42.0 +def test_get_array_falls_back_to_default_branch_on_disk(tax_benefit_system): + """Branch fallback must work for disk-backed default values too.""" + sim = _build_single(tax_benefit_system) + holder = sim.person.get_holder("salary") + period = periods.period("2017-01") + holder._disk_storage = holder.create_disk_storage() + + holder._disk_storage.put(np.asarray([5_000.0]), period, "default") + + result = holder.get_array(period, "reform") + + assert result is not None + assert result[0] == 5_000.0 + + def test_get_array_falls_back_through_parent_branch_chain(tax_benefit_system): """Nested branches must inherit values from their parent branch.