diff --git a/changelog.d/fix-fast-cache-bugs.fixed.md b/changelog.d/fix-fast-cache-bugs.fixed.md new file mode 100644 index 00000000..2b9eda51 --- /dev/null +++ b/changelog.d/fix-fast-cache-bugs.fixed.md @@ -0,0 +1 @@ +Fix _fast_cache invalidation bug in set_input and add cache tests. diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 6a49b5ba..fcc562ef 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -12,6 +12,45 @@ from policyengine_core.parameters.parameter_node import ParameterNode +def _build_enum_lut(enum, name_to_child_idx, sentinel, stringify_names=False): + """Build a lookup table mapping enum int codes to child indices.""" + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, sentinel, dtype=numpy.intp) + for item in enum_items: + name = str(item.name) if stringify_names else item.name + child_idx = name_to_child_idx.get(name) + if child_idx is not None: + lut[item.index] = child_idx + return lut + + +def _unify_structured_dtypes(values): + """Compute a unified dtype across structured arrays with potentially + different fields, and cast all values to that dtype. + + Returns (unified_dtype, all_fields, casted_values). + """ + all_fields = [] + seen = set() + for val in values: + for field in val.dtype.names: + if field not in seen: + all_fields.append(field) + seen.add(field) + + unified_dtype = numpy.dtype([(f, " Any: self._enum_lut_cache = {} lut = self._enum_lut_cache.get(cache_key) if lut is None: - enum_items = list(enum) - max_code = max(item.index for item in enum_items) + 1 - lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) - for item in enum_items: - child_idx = name_to_child_idx.get(item.name) - if child_idx is not None: - lut[item.index] = child_idx + lut = _build_enum_lut(enum, name_to_child_idx, SENTINEL) self._enum_lut_cache[cache_key] = lut idx = lut[numpy.asarray(key)] elif ( @@ -224,13 +257,9 @@ def __getitem__(self, key: str) -> Any: self._enum_lut_cache = {} lut = self._enum_lut_cache.get(cache_key) if lut is None: - enum_items = list(enum) - max_code = max(item.index for item in enum_items) + 1 - lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) - for item in enum_items: - child_idx = name_to_child_idx.get(str(item.name)) - if child_idx is not None: - lut[item.index] = child_idx + lut = _build_enum_lut( + enum, name_to_child_idx, SENTINEL, stringify_names=True + ) self._enum_lut_cache[cache_key] = lut codes = numpy.array([v.index for v in key], dtype=numpy.intp) idx = lut[codes] @@ -262,23 +291,9 @@ def __getitem__(self, key: str) -> Any: if v0_len <= 1: # 1-element structured arrays: simple concat + index if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - - unified_dtype = numpy.dtype([(f, " Any: # Nested structured: fall back to numpy.select conditions = [idx == i for i in range(len(values))] if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - unified_dtype = numpy.dtype( - [(f, " Any: else: # Flat structured: fast per-field indexing if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - unified_dtype = numpy.dtype( - [(f, " ArrayLike: smc.set_cache_value(cache_path, array) if hasattr(self, "_fast_cache"): - self._fast_cache[(variable_name, str(period))] = array + self._fast_cache[(variable_name, period)] = array return array @@ -776,7 +776,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) - self._fast_cache.pop((_name, str(_period)), None) + self._fast_cache.pop((_name, _period), None) self.invalidated_caches = set() def calculate_add( @@ -1142,7 +1142,9 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: k: v for k, v in self._fast_cache.items() if k[0] != variable } else: - self._fast_cache.pop((variable, str(period)), None) + if not isinstance(period, Period): + period = periods.period(period) + self._fast_cache.pop((variable, period), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1187,6 +1189,7 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value, self.branch_name) + self._fast_cache.pop((variable_name, period), None) def get_variable_population(self, variable_name: str) -> Population: variable = self.tax_benefit_system.get_variable( diff --git a/tests/core/test_fast_cache.py b/tests/core/test_fast_cache.py new file mode 100644 index 00000000..a60a3d9e --- /dev/null +++ b/tests/core/test_fast_cache.py @@ -0,0 +1,132 @@ +"""Tests for the _fast_cache mechanism in Simulation.""" + +import numpy as np +from policyengine_core.simulations import SimulationBuilder + + +def _make_simulation(tax_benefit_system, salary=3000): + """Build a simple simulation with one person and a salary.""" + return SimulationBuilder().build_from_entities( + tax_benefit_system, + { + "persons": { + "bill": {"salary": {"2017-01": salary}}, + }, + "households": {"household": {"parents": ["bill"]}}, + }, + ) + + +def test_fast_cache_returns_cached_value(tax_benefit_system): + """Second calculate() for a computed variable should return + the cached value from _fast_cache without recomputing.""" + sim = _make_simulation(tax_benefit_system) + + # income_tax is computed via formula, so it enters _fast_cache + result1 = sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + result2 = sim.calculate("income_tax", "2017-01") + # Must be the exact same object (identity check proves cache hit) + assert result1 is result2 + + +def test_fast_cache_invalidated_after_set_input(tax_benefit_system): + """set_input() must evict the stale _fast_cache entry so the next + calculate() returns the new value.""" + sim = _make_simulation(tax_benefit_system) + + # Populate the cache with a computed variable + result1 = sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + old_val = result1[0] + + # Overwrite income_tax with a direct value + sim.set_input("income_tax", "2017-01", np.array([9999.0])) + + # The cache entry for income_tax must be gone + result2 = sim.calculate("income_tax", "2017-01") + assert np.isclose(result2[0], 9999.0), ( + f"Expected 9999.0 after set_input, got {result2[0]} (stale cache bug)" + ) + + +def test_fast_cache_invalidated_after_delete_arrays_with_period( + tax_benefit_system, +): + """delete_arrays(variable, period) must evict that specific + _fast_cache entry.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + sim.delete_arrays("income_tax", "2017-01") + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_invalidated_after_delete_arrays_all_periods( + tax_benefit_system, +): + """delete_arrays(variable) with no period must evict ALL + _fast_cache entries for that variable.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + sim.delete_arrays("income_tax") + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_empty_after_clone(tax_benefit_system): + """clone() must produce a simulation with an empty _fast_cache.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + cloned = sim.clone() + assert len(cloned._fast_cache) == 0 + + +def test_fast_cache_invalidated_after_purge_cache(tax_benefit_system): + """purge_cache_of_invalid_values() must remove entries listed in + invalidated_caches from _fast_cache.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + # Manually mark the entry as invalidated (simulating what the + # framework does during dependency tracking) + from policyengine_core.periods import period as make_period + + sim.invalidated_caches.add(("income_tax", make_period("2017-01"))) + # The stack must be empty for purge to fire + sim.tracer._stack.clear() + sim.purge_cache_of_invalid_values() + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_uses_period_not_str_as_key(tax_benefit_system): + """_fast_cache keys should use Period objects, not str(period), + to avoid unnecessary string conversions.""" + sim = _make_simulation(tax_benefit_system) + sim.calculate("income_tax", "2017-01") + + # All keys should have Period as second element, not str + for key in sim._fast_cache: + variable_name, period_key = key + assert not isinstance(period_key, str), ( + f"Expected Period as cache key, got str: {period_key!r}" + ) + assert isinstance(period_key, tuple), ( + f"Period should be a tuple subclass, got {type(period_key)}" + )