Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/484.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve values passed through dispatched `set_input` handlers across cache invalidation.
1 change: 1 addition & 0 deletions changelog.d/490.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Apply chained parameter uprating dependencies in deterministic dependency order.
24 changes: 19 additions & 5 deletions policyengine_core/holders/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
log = logging.getLogger(__name__)


def get_input_branch(holder: Holder) -> str:
simulation = getattr(holder, "simulation", None)
user_input_contexts = getattr(simulation, "_user_input_contexts", None)
if user_input_contexts:
return user_input_contexts[-1]
return "default"


def get_stored_array(holder: Holder, period: Period, branch_name: str) -> ArrayLike:
return holder._get_array_from_storage(period, branch_name)


def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLike):
"""
This function can be declared as a ``set_input`` attribute of a variable.
Expand All @@ -35,11 +47,12 @@ def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLik
after_instant = period.start.offset(period_size, period_unit)

# Cache the input data, skipping the existing cached months
branch_name = get_input_branch(holder)
sub_period = period.start.period(cached_period_unit)
while sub_period.start < after_instant:
existing_array = holder.get_array(sub_period)
existing_array = get_stored_array(holder, sub_period, branch_name)
if existing_array is None:
holder._set(sub_period, array)
holder._set(sub_period, array, branch_name)
else:
# The array of the current sub-period is reused for the next ones.
# TODO: refactor or document this behavior
Expand Down Expand Up @@ -72,11 +85,12 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike)
after_instant = period.start.offset(period_size, period_unit)

# Count the number of elementary periods to change, and the difference with what is already known.
branch_name = get_input_branch(holder)
remaining_array = array.copy()
sub_period = period.start.period(cached_period_unit)
sub_periods_count = 0
while sub_period.start < after_instant:
existing_array = holder.get_array(sub_period)
existing_array = get_stored_array(holder, sub_period, branch_name)
if existing_array is not None:
remaining_array -= existing_array
else:
Expand All @@ -88,8 +102,8 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike)
divided_array = remaining_array / sub_periods_count
sub_period = period.start.period(cached_period_unit)
while sub_period.start < after_instant:
if holder.get_array(sub_period) is None:
holder._set(sub_period, divided_array)
if get_stored_array(holder, sub_period, branch_name) is None:
holder._set(sub_period, divided_array, branch_name)
sub_period = sub_period.offset(1)
elif not (remaining_array == 0).all():
raise ValueError(
Expand Down
55 changes: 37 additions & 18 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def delete_arrays(
if self._disk_storage:
self._disk_storage.delete(period, branch_name)

def _get_array_from_storage(
self, period: Period, branch_name: str = "default"
) -> ArrayLike:
value = self._memory_storage.get(period, branch_name)
if value is None and self._disk_storage:
value = self._disk_storage.get(period, branch_name)
return value

def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike:
"""
Get the value of the variable for the given period.
Expand All @@ -102,7 +110,9 @@ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike:
"""
if self.variable.is_neutralized:
return self.default_array()
value = self._memory_storage.get(period, branch_name)
value = self._get_array_from_storage(period, branch_name)
if value is not None:
return value
if value is None and branch_name != "default":
# Walk up ``simulation.parent_branch`` so nested branches inherit
# values from their parent (e.g. a ``no_salt`` branch cloned
Expand All @@ -121,17 +131,16 @@ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike:
else None
)
while parent is not None:
ancestor_value = self._memory_storage.get(period, parent.branch_name)
ancestor_value = self._get_array_from_storage(
period,
parent.branch_name,
)
if ancestor_value is not None:
return ancestor_value
parent = getattr(parent, "parent_branch", None)
default_value = self._memory_storage.get(period, "default")
default_value = self._get_array_from_storage(period, "default")
if default_value is not None:
return default_value
if value is not None:
return value
if self._disk_storage:
return self._disk_storage.get(period, branch_name)

def get_memory_usage(self) -> dict:
"""
Expand Down Expand Up @@ -241,21 +250,23 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
# Track user-provided inputs on the simulation so
# ``Simulation._invalidate_all_caches`` can preserve them across
# ``apply_reform``. ``Simulation.set_input`` also records this, but
# ``SimulationBuilder.finalize_variables_init`` (the situation-dict
# path) and country-package dataset loaders call
# ``holder.set_input`` directly, bypassing the simulation-level hook.
# Recording here covers both paths.
simulation = getattr(self, "simulation", None)
if simulation is not None:
if not hasattr(simulation, "_user_input_keys"):
simulation._user_input_keys = set()
simulation._user_input_keys.add((self.variable.name, branch_name, period))
if self.variable.set_input and period.unit != self.variable.definition_period:
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)
if not hasattr(simulation, "_user_input_contexts"):
simulation._user_input_contexts = []
simulation._user_input_contexts.append(branch_name)
try:
if (
self.variable.set_input
and period.unit != self.variable.definition_period
):
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)
finally:
if simulation is not None:
simulation._user_input_contexts.pop()

def _to_array(self, value: Any) -> ArrayLike:
if not isinstance(value, numpy.ndarray):
Expand Down Expand Up @@ -295,6 +306,10 @@ def _to_array(self, value: Any) -> ArrayLike:
def _set(
self, period: Period, value: ArrayLike, branch_name: str = "default"
) -> None:
simulation = getattr(self, "simulation", None)
user_input_contexts = getattr(simulation, "_user_input_contexts", None)
if user_input_contexts and branch_name == "default":
branch_name = user_input_contexts[-1]
value = self._to_array(value)
if self.variable.definition_period != periods.ETERNITY:
if period is None:
Expand All @@ -313,6 +328,10 @@ def _set(
self._disk_storage.put(value, period, branch_name)
else:
self._memory_storage.put(value, period, branch_name)
if user_input_contexts:
if not hasattr(simulation, "_user_input_keys"):
simulation._user_input_keys = set()
simulation._user_input_keys.add((self.variable.name, branch_name, period))

def put_in_cache(
self, value: ArrayLike, period: Period, branch_name: str = "default"
Expand Down
Loading
Loading