Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-fast-cache-bugs.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix stale `_fast_cache` entries after `set_input`, key the cache on `Period` objects instead of strings, and add cache tests.
111 changes: 50 additions & 61 deletions policyengine_core/parameters/vectorial_parameter_node_at_instant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,45 @@
from policyengine_core.parameters.parameter_node import ParameterNode


def _build_enum_lut(enum, name_to_child_idx, sentinel, stringify_names=False):
    """Build a lookup table mapping enum int codes to child indices.

    Args:
        enum: Iterable of enum items; each item must expose ``index`` (its
            integer code) and ``name``.
        name_to_child_idx: Mapping from item name to child index, queried
            with ``.get``.
        sentinel: Value stored for every code whose name has no entry in
            ``name_to_child_idx``.
        stringify_names: When true, names are looked up as
            ``str(item.name)`` instead of ``item.name``.

    Returns:
        A 1-D ``numpy.intp`` array of length ``max(item.index) + 1`` where
        position ``code`` holds the matching child index or ``sentinel``.
        An enum with no items yields a zero-length table.
    """
    enum_items = list(enum)
    # ``default=-1`` makes an empty enum produce an empty table rather than
    # letting ``max()`` raise ValueError on an empty sequence.
    max_code = max((item.index for item in enum_items), default=-1) + 1
    lut = numpy.full(max_code, sentinel, dtype=numpy.intp)
    for item in enum_items:
        name = str(item.name) if stringify_names else item.name
        child_idx = name_to_child_idx.get(name)
        if child_idx is not None:
            lut[item.index] = child_idx
    return lut


def _unify_structured_dtypes(values):
    """Widen structured arrays with differing fields to one shared dtype.

    The shared dtype contains every field name found in ``values``, in
    first-seen order, each stored as ``"<f8"``. Every input array is copied
    into a zero-initialised array of that dtype, so fields an input lacks
    stay at zero.

    Returns (unified_dtype, all_fields, casted_values).
    """
    # A dict keeps insertion order, giving de-duplicated first-seen ordering.
    ordered_names = {}
    for arr in values:
        for name in arr.dtype.names:
            ordered_names.setdefault(name, None)
    all_fields = list(ordered_names)

    unified_dtype = numpy.dtype([(name, "<f8") for name in all_fields])

    def _widen(arr):
        # Copy only the fields this array actually has.
        widened = numpy.zeros(len(arr), dtype=unified_dtype)
        for name in arr.dtype.names:
            widened[name] = arr[name]
        return widened

    casted_values = [_widen(arr) for arr in values]
    return unified_dtype, all_fields, casted_values


class VectorialParameterNodeAtInstant:
"""
Parameter node of the legislation at a given instant which has been vectorized.
Expand Down Expand Up @@ -205,13 +244,7 @@ def __getitem__(self, key: str) -> Any:
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(item.name)
if child_idx is not None:
lut[item.index] = child_idx
lut = _build_enum_lut(enum, name_to_child_idx, SENTINEL)
self._enum_lut_cache[cache_key] = lut
idx = lut[numpy.asarray(key)]
elif (
Expand All @@ -224,13 +257,9 @@ def __getitem__(self, key: str) -> Any:
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(str(item.name))
if child_idx is not None:
lut[item.index] = child_idx
lut = _build_enum_lut(
enum, name_to_child_idx, SENTINEL, stringify_names=True
)
self._enum_lut_cache[cache_key] = lut
codes = numpy.array([v.index for v in key], dtype=numpy.intp)
idx = lut[codes]
Expand Down Expand Up @@ -262,23 +291,9 @@ def __getitem__(self, key: str) -> Any:
if v0_len <= 1:
# 1-element structured arrays: simple concat + index
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)

unified_dtype = numpy.dtype([(f, "<f8") for f in all_fields])

values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)

unified_dtype, all_fields, values_cast = (
_unify_structured_dtypes(values)
)
default = numpy.zeros(1, dtype=unified_dtype)
for field in unified_dtype.names:
default[field] = numpy.nan
Expand All @@ -302,22 +317,9 @@ def __getitem__(self, key: str) -> Any:
# Nested structured: fall back to numpy.select
conditions = [idx == i for i in range(len(values))]
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
unified_dtype, all_fields, values_cast = (
_unify_structured_dtypes(values)
)
values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)
default = numpy.zeros(v0_len, dtype=unified_dtype)
for field in unified_dtype.names:
default[field] = numpy.nan
Expand All @@ -328,22 +330,9 @@ def __getitem__(self, key: str) -> Any:
else:
# Flat structured: fast per-field indexing
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
unified_dtype, all_fields, values_unified = (
_unify_structured_dtypes(values)
)
values_unified = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_unified.append(casted)
field_names = all_fields
result_dtype = unified_dtype
else:
Expand Down
13 changes: 8 additions & 5 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ def calculate(
# Fast path: skip tracer, random seed and all _calculate() machinery for
# already-computed values. map_to and decode_enums are NOT cached here —
# they are post-processing steps that vary per call site.
if map_to is None and not decode_enums and not getattr(self, "_trace", False):
_fast_key = (variable_name, str(period))
if map_to is None and not decode_enums and not getattr(self, "trace", False):
_fast_key = (variable_name, period)
_fast_cache = getattr(self, "_fast_cache", None)
if _fast_cache is not None:
_cached = _fast_cache.get(_fast_key)
Expand Down Expand Up @@ -765,7 +765,7 @@ def _calculate(self, variable_name: str, period: Period = None) -> ArrayLike:
smc.set_cache_value(cache_path, array)

if hasattr(self, "_fast_cache"):
self._fast_cache[(variable_name, str(period))] = array
self._fast_cache[(variable_name, period)] = array

return array

Expand All @@ -776,7 +776,7 @@ def purge_cache_of_invalid_values(self) -> None:
for _name, _period in self.invalidated_caches:
holder = self.get_holder(_name)
holder.delete_arrays(_period)
self._fast_cache.pop((_name, str(_period)), None)
self._fast_cache.pop((_name, _period), None)
self.invalidated_caches = set()

def calculate_add(
Expand Down Expand Up @@ -1142,7 +1142,9 @@ def delete_arrays(self, variable: str, period: Period = None) -> None:
k: v for k, v in self._fast_cache.items() if k[0] != variable
}
else:
self._fast_cache.pop((variable, str(period)), None)
if not isinstance(period, Period):
period = periods.period(period)
self._fast_cache.pop((variable, period), None)

def get_known_periods(self, variable: str) -> List[Period]:
"""
Expand Down Expand Up @@ -1187,6 +1189,7 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non
if (variable.end is not None) and (period.start.date > variable.end):
return
self.get_holder(variable_name).set_input(period, value, self.branch_name)
self._fast_cache.pop((variable_name, period), None)

def get_variable_population(self, variable_name: str) -> Population:
variable = self.tax_benefit_system.get_variable(
Expand Down
132 changes: 132 additions & 0 deletions tests/core/test_fast_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Tests for the _fast_cache mechanism in Simulation."""

import numpy as np
from policyengine_core.simulations import SimulationBuilder


def _make_simulation(tax_benefit_system, salary=3000):
    """Return a simulation with one person ("bill") in one household.

    ``salary`` is set as bill's "salary" input for the period "2017-01".
    """
    persons = {"bill": {"salary": {"2017-01": salary}}}
    households = {"household": {"parents": ["bill"]}}
    situation = {"persons": persons, "households": households}
    builder = SimulationBuilder()
    return builder.build_from_entities(tax_benefit_system, situation)


def test_fast_cache_returns_cached_value(tax_benefit_system):
    """Calling calculate() twice for the same variable and period must
    hand back the array already stored in _fast_cache."""
    simulation = _make_simulation(tax_benefit_system)

    # The first call computes income_tax and should populate _fast_cache.
    first = simulation.calculate("income_tax", "2017-01")
    assert len(simulation._fast_cache) > 0

    # Object identity (not just equality) shows the second call was a cache hit.
    second = simulation.calculate("income_tax", "2017-01")
    assert first is second


def test_fast_cache_invalidated_after_set_input(tax_benefit_system):
    """set_input() must evict the stale _fast_cache entry so the next
    calculate() returns the new value."""
    sim = _make_simulation(tax_benefit_system)
    new_value = 9999.0

    # Populate the cache with a computed variable
    result1 = sim.calculate("income_tax", "2017-01")
    assert len(sim._fast_cache) > 0
    old_val = result1[0]
    # If the computed value already equalled the override, the final
    # assertion could not tell a fresh value from a stale cached one.
    assert not np.isclose(old_val, new_value), (
        f"Computed income_tax {old_val} equals the override; test is vacuous"
    )

    # Overwrite income_tax with a direct value
    sim.set_input("income_tax", "2017-01", np.array([new_value]))

    # The cache entry for income_tax must be gone
    result2 = sim.calculate("income_tax", "2017-01")
    assert np.isclose(result2[0], new_value), (
        f"Expected 9999.0 after set_input, got {result2[0]} (stale cache bug)"
    )


def test_fast_cache_invalidated_after_delete_arrays_with_period(
    tax_benefit_system,
):
    """After delete_arrays(variable, period), no _fast_cache key for that
    variable may remain (only one period was cached here)."""
    simulation = _make_simulation(tax_benefit_system)

    # Populate the cache first so the eviction has something to remove.
    simulation.calculate("income_tax", "2017-01")
    assert len(simulation._fast_cache) > 0

    simulation.delete_arrays("income_tax", "2017-01")

    leftover = [
        cache_key
        for cache_key in simulation._fast_cache
        if cache_key[0] == "income_tax"
    ]
    assert len(leftover) == 0


def test_fast_cache_invalidated_after_delete_arrays_all_periods(
    tax_benefit_system,
):
    """After delete_arrays(variable) without a period, every _fast_cache
    key for that variable must be gone."""
    simulation = _make_simulation(tax_benefit_system)

    # Populate the cache first so the eviction has something to remove.
    simulation.calculate("income_tax", "2017-01")
    assert len(simulation._fast_cache) > 0

    simulation.delete_arrays("income_tax")

    leftover = [
        cache_key
        for cache_key in simulation._fast_cache
        if cache_key[0] == "income_tax"
    ]
    assert len(leftover) == 0


def test_fast_cache_empty_after_clone(tax_benefit_system):
    """A simulation returned by clone() must start with no _fast_cache
    entries, even when the source simulation has some."""
    original = _make_simulation(tax_benefit_system)

    # Make sure the source simulation's cache is populated before cloning.
    original.calculate("income_tax", "2017-01")
    assert len(original._fast_cache) > 0

    duplicate = original.clone()
    assert len(duplicate._fast_cache) == 0


def test_fast_cache_invalidated_after_purge_cache(tax_benefit_system):
    """Entries named in invalidated_caches must be dropped from
    _fast_cache by purge_cache_of_invalid_values()."""
    from policyengine_core.periods import period as to_period

    simulation = _make_simulation(tax_benefit_system)

    simulation.calculate("income_tax", "2017-01")
    assert len(simulation._fast_cache) > 0

    # Flag the entry by hand, standing in for the framework's own
    # dependency tracking.
    stale_entry = ("income_tax", to_period("2017-01"))
    simulation.invalidated_caches.add(stale_entry)
    # The purge only fires when the tracer stack is empty.
    simulation.tracer._stack.clear()
    simulation.purge_cache_of_invalid_values()

    leftover = [
        cache_key
        for cache_key in simulation._fast_cache
        if cache_key[0] == "income_tax"
    ]
    assert len(leftover) == 0


def test_fast_cache_uses_period_not_str_as_key(tax_benefit_system):
    """_fast_cache keys should use Period objects, not str(period),
    to avoid unnecessary string conversions."""
    sim = _make_simulation(tax_benefit_system)
    sim.calculate("income_tax", "2017-01")

    # Without at least one entry the loop below would pass without
    # checking anything.
    assert len(sim._fast_cache) > 0

    # All keys should have Period as second element, not str
    for _, period_key in sim._fast_cache:
        assert not isinstance(period_key, str), (
            f"Expected Period as cache key, got str: {period_key!r}"
        )
        assert isinstance(period_key, tuple), (
            f"Period should be a tuple subclass, got {type(period_key)}"
        )
Loading