From bc754a46e8f668d939e7f37a22e71fbc15fca7da Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 14 Mar 2026 07:50:58 -0700 Subject: [PATCH] Fix _fast_cache invalidation bug and improve cache key performance set_input() was not clearing _fast_cache entries, causing stale values to be returned on subsequent calculate() calls. Also switches cache keys from str(period) to Period objects (already hashable tuples) to avoid unnecessary string conversions, deduplicates enum LUT and structured dtype unification logic in vectorial parameter lookups, and uses the public trace property instead of the private _trace attr. Co-Authored-By: Claude Opus 4.6 --- changelog.d/fix-fast-cache-bugs.fixed.md | 1 + .../vectorial_parameter_node_at_instant.py | 111 +++++++-------- policyengine_core/simulations/simulation.py | 13 +- tests/core/test_fast_cache.py | 132 ++++++++++++++++++ 4 files changed, 191 insertions(+), 66 deletions(-) create mode 100644 changelog.d/fix-fast-cache-bugs.fixed.md create mode 100644 tests/core/test_fast_cache.py diff --git a/changelog.d/fix-fast-cache-bugs.fixed.md b/changelog.d/fix-fast-cache-bugs.fixed.md new file mode 100644 index 00000000..2b9eda51 --- /dev/null +++ b/changelog.d/fix-fast-cache-bugs.fixed.md @@ -0,0 +1 @@ +Fix _fast_cache invalidation bug in set_input and add cache tests. diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 6a49b5ba..fcc562ef 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -12,6 +12,45 @@ from policyengine_core.parameters.parameter_node import ParameterNode +def _build_enum_lut(enum, name_to_child_idx, sentinel, stringify_names=False): + """Build a lookup table mapping enum int codes to child indices.""" + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, sentinel, dtype=numpy.intp) + for item in enum_items: + name = str(item.name) if stringify_names else item.name + child_idx = name_to_child_idx.get(name) + if child_idx is not None: + lut[item.index] = child_idx + return lut + + +def _unify_structured_dtypes(values): + """Compute a unified dtype across structured arrays with potentially + different fields, and cast all values to that dtype. + + Returns (unified_dtype, all_fields, casted_values). + """ + all_fields = [] + seen = set() + for val in values: + for field in val.dtype.names: + if field not in seen: + all_fields.append(field) + seen.add(field) + + unified_dtype = numpy.dtype([(f, " Any: self._enum_lut_cache = {} lut = self._enum_lut_cache.get(cache_key) if lut is None: - enum_items = list(enum) - max_code = max(item.index for item in enum_items) + 1 - lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) - for item in enum_items: - child_idx = name_to_child_idx.get(item.name) - if child_idx is not None: - lut[item.index] = child_idx + lut = _build_enum_lut(enum, name_to_child_idx, SENTINEL) self._enum_lut_cache[cache_key] = lut idx = lut[numpy.asarray(key)] elif ( @@ -224,13 +257,9 @@ def __getitem__(self, key: str) -> Any: self._enum_lut_cache = {} lut = self._enum_lut_cache.get(cache_key) if lut is None: - enum_items = list(enum) - max_code = max(item.index for item in enum_items) + 1 - lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) - for item in enum_items: - child_idx = name_to_child_idx.get(str(item.name)) - if child_idx is not None: - lut[item.index] = child_idx + lut = _build_enum_lut( + enum, name_to_child_idx, SENTINEL, stringify_names=True + ) self._enum_lut_cache[cache_key] = lut codes = numpy.array([v.index for v in key], dtype=numpy.intp) idx = lut[codes] @@ -262,23 +291,9 @@ def __getitem__(self, key: str) -> Any: if v0_len <= 1: # 1-element structured arrays: simple concat + index if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - - unified_dtype = numpy.dtype([(f, " Any: # Nested structured: fall back to numpy.select conditions = [idx == i for i in range(len(values))] if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - unified_dtype = numpy.dtype( - [(f, " Any: else: # Flat structured: fast per-field indexing if not dtypes_match: - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - unified_dtype = numpy.dtype( - [(f, " ArrayLike: smc.set_cache_value(cache_path, array) if hasattr(self, "_fast_cache"): - self._fast_cache[(variable_name, str(period))] = array + self._fast_cache[(variable_name, period)] = array return array @@ -776,7 +776,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) - self._fast_cache.pop((_name, str(_period)), None) + self._fast_cache.pop((_name, _period), None) self.invalidated_caches = set() def calculate_add( @@ -1142,7 +1142,9 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: k: v for k, v in self._fast_cache.items() if k[0] != variable } else: - self._fast_cache.pop((variable, str(period)), None) + if not isinstance(period, Period): + period = periods.period(period) + self._fast_cache.pop((variable, period), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1187,6 +1189,7 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value, self.branch_name) + self._fast_cache.pop((variable_name, period), None) def get_variable_population(self, variable_name: str) -> Population: variable = self.tax_benefit_system.get_variable( diff --git a/tests/core/test_fast_cache.py b/tests/core/test_fast_cache.py new file mode 100644 index 00000000..a60a3d9e --- /dev/null +++ b/tests/core/test_fast_cache.py @@ -0,0 +1,132 @@ +"""Tests for the _fast_cache mechanism in Simulation.""" + +import numpy as np +from policyengine_core.simulations import SimulationBuilder + + +def _make_simulation(tax_benefit_system, salary=3000): + """Build a simple simulation with one person and a salary.""" + return SimulationBuilder().build_from_entities( + tax_benefit_system, + { + "persons": { + "bill": {"salary": {"2017-01": salary}}, + }, + "households": {"household": {"parents": ["bill"]}}, + }, + ) + + +def test_fast_cache_returns_cached_value(tax_benefit_system): + """Second calculate() for a computed variable should return + the cached value from _fast_cache without recomputing.""" + sim = _make_simulation(tax_benefit_system) + + # income_tax is computed via formula, so it enters _fast_cache + result1 = sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + result2 = sim.calculate("income_tax", "2017-01") + # Must be the exact same object (identity check proves cache hit) + assert result1 is result2 + + +def test_fast_cache_invalidated_after_set_input(tax_benefit_system): + """set_input() must evict the stale _fast_cache entry so the next + calculate() returns the new value.""" + sim = _make_simulation(tax_benefit_system) + + # Populate the cache with a computed variable + result1 = sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + old_val = result1[0] + + # Overwrite income_tax with a direct value + sim.set_input("income_tax", "2017-01", np.array([9999.0])) + + # The cache entry for income_tax must be gone + result2 = sim.calculate("income_tax", "2017-01") + assert np.isclose(result2[0], 9999.0), ( + f"Expected 9999.0 after set_input, got {result2[0]} (stale cache bug)" + ) + + +def test_fast_cache_invalidated_after_delete_arrays_with_period( + tax_benefit_system, +): + """delete_arrays(variable, period) must evict that specific + _fast_cache entry.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + sim.delete_arrays("income_tax", "2017-01") + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_invalidated_after_delete_arrays_all_periods( + tax_benefit_system, +): + """delete_arrays(variable) with no period must evict ALL + _fast_cache entries for that variable.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + sim.delete_arrays("income_tax") + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_empty_after_clone(tax_benefit_system): + """clone() must produce a simulation with an empty _fast_cache.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + cloned = sim.clone() + assert len(cloned._fast_cache) == 0 + + +def test_fast_cache_invalidated_after_purge_cache(tax_benefit_system): + """purge_cache_of_invalid_values() must remove entries listed in + invalidated_caches from _fast_cache.""" + sim = _make_simulation(tax_benefit_system) + + sim.calculate("income_tax", "2017-01") + assert len(sim._fast_cache) > 0 + + # Manually mark the entry as invalidated (simulating what the + # framework does during dependency tracking) + from policyengine_core.periods import period as make_period + + sim.invalidated_caches.add(("income_tax", make_period("2017-01"))) + # The stack must be empty for purge to fire + sim.tracer._stack.clear() + sim.purge_cache_of_invalid_values() + + matching = [k for k in sim._fast_cache if k[0] == "income_tax"] + assert len(matching) == 0 + + +def test_fast_cache_uses_period_not_str_as_key(tax_benefit_system): + """_fast_cache keys should use Period objects, not str(period), + to avoid unnecessary string conversions.""" + sim = _make_simulation(tax_benefit_system) + sim.calculate("income_tax", "2017-01") + + # All keys should have Period as second element, not str + for key in sim._fast_cache: + variable_name, period_key = key + assert not isinstance(period_key, str), ( + f"Expected Period as cache key, got str: {period_key!r}" + ) + assert isinstance(period_key, tuple), ( + f"Period should be a tuple subclass, got {type(period_key)}" + )