From 647f3df22e6cd4656d589177b178631b55380ea2 Mon Sep 17 00:00:00 2001 From: "nikhil@policyengine.org" Date: Wed, 18 Feb 2026 09:49:05 +0000 Subject: [PATCH 1/6] Add _fast_cache to Simulation for O(1) repeated variable lookups Adds a flat dict[tuple[str,str], array] at the Simulation level, checked at the top of calculate() before tracer, random seed and _calculate() machinery. Only active when map_to=None and decode_enums=False (the inner-loop hot path). Invalidation mirrors the existing holder cache: - purge_cache_of_invalid_values() removes invalidated entries - delete_arrays() removes the relevant key(s) - clone() gets a fresh empty cache to prevent cross-simulation sharing Uses getattr/hasattr guards so StubSimulation and other test subclasses that bypass __init__ work without modification. Co-Authored-By: Claude --- policyengine_core/simulations/simulation.py | 25 ++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 8f79c3b7..9ca5c167 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -118,6 +118,7 @@ def __init__( self.is_over_dataset = dataset is not None self.invalidated_caches = set() + self._fast_cache: dict = {} self.debug: bool = False self.trace: bool = trace self.tracer: SimpleTracer = SimpleTracer() if not trace else FullTracer() @@ -448,6 +449,17 @@ def calculate( elif period is None and self.default_calculation_period is not None: period = periods.period(self.default_calculation_period) + # Fast path: skip tracer, random seed and all _calculate() machinery for + # already-computed values. map_to and decode_enums are NOT cached here — + # they are post-processing steps that vary per call site. + if map_to is None and not decode_enums: + _fast_key = (variable_name, str(period)) + _fast_cache = getattr(self, "_fast_cache", None) + if _fast_cache is not None: + _cached = _fast_cache.get(_fast_key) + if _cached is not None: + return _cached + self.tracer.record_calculation_start(variable_name, period, self.branch_name) np.random.seed(hash(variable_name + str(period)) % 1000000) @@ -752,6 +764,9 @@ def _calculate(self, variable_name: str, period: Period = None) -> ArrayLike: if is_cache_available: smc.set_cache_value(cache_path, array) + if hasattr(self, "_fast_cache"): + self._fast_cache[(variable_name, str(period))] = array + return array def purge_cache_of_invalid_values(self) -> None: @@ -761,6 +776,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) + self._fast_cache.pop((_name, str(_period)), None) self.invalidated_caches = set() def calculate_add( @@ -1121,6 +1137,12 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: True """ self.get_holder(variable).delete_arrays(period) + if period is None: + self._fast_cache = { + k: v for k, v in self._fast_cache.items() if k[0] != variable + } + else: + self._fast_cache.pop((variable, str(period)), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1205,8 +1227,9 @@ def clone( new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ("debug", "trace", "tracer", "branches"): + if key not in ("debug", "trace", "tracer", "branches", "_fast_cache"): new_dict[key] = value + new._fast_cache = {} new.persons = self.persons.clone(new) setattr(new, new.persons.entity.key, new.persons) From 9961d422408763835d19c1b83d8e3aa56de9138b Mon Sep 17 00:00:00 2001 From: "nikhil@policyengine.org" Date: Wed, 18 Feb 2026 09:59:24 +0000 Subject: [PATCH 2/6] Fix black formatting --- policyengine_core/simulations/simulation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 9ca5c167..5afe0c03 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1227,7 +1227,13 @@ def clone( new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ("debug", "trace", "tracer", "branches", "_fast_cache"): + if key not in ( + "debug", + "trace", + "tracer", + "branches", + "_fast_cache", + ): new_dict[key] = value new._fast_cache = {} From 4394e09d82da238e568243af0d51af81138cac96 Mon Sep 17 00:00:00 2001 From: "nikhil@policyengine.org" Date: Wed, 18 Feb 2026 10:08:24 +0000 Subject: [PATCH 3/6] Fix _fast_cache bypassing tracer when trace=True Skip the fast path when tracing is enabled, so FullTracer records all calculations correctly. Co-Authored-By: Claude --- policyengine_core/simulations/simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 5afe0c03..51bd45f1 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -452,7 +452,7 @@ def calculate( # Fast path: skip tracer, random seed and all _calculate() machinery for # already-computed values. map_to and decode_enums are NOT cached here — # they are post-processing steps that vary per call site. - if map_to is None and not decode_enums: + if map_to is None and not decode_enums and not self.trace: _fast_key = (variable_name, str(period)) _fast_cache = getattr(self, "_fast_cache", None) if _fast_cache is not None: From 450f89ece886f3852f9b4a293e9b64a07cdf66ff Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 8 Mar 2026 23:08:09 +0000 Subject: [PATCH 4/6] Optimise vectorial parameter lookups for US simulation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace O(N×K) numpy.select with O(N) index-based selection in VectorialParameterNodeAtInstant.__getitem__. For enum/EnumArray keys, build a lookup table mapping integer codes directly to child indices, avoiding the intermediate string conversion entirely. For string keys, use numpy.unique to reduce N×K string comparisons to U dict lookups (where U = unique keys, typically ≪ N). Also cache build_from_node results on ParameterNodeAtInstant to avoid rebuilding the recarray on every vectorial access. US household_net_income compute: 12.8s → 9.0s (-30%). Co-Authored-By: Claude --- .../parameters/parameter_node_at_instant.py | 12 +- .../vectorial_parameter_node_at_instant.py | 265 +++++++++++++----- 2 files changed, 213 insertions(+), 64 deletions(-) diff --git a/policyengine_core/parameters/parameter_node_at_instant.py b/policyengine_core/parameters/parameter_node_at_instant.py index 343e152d..b55889eb 100644 --- a/policyengine_core/parameters/parameter_node_at_instant.py +++ b/policyengine_core/parameters/parameter_node_at_instant.py @@ -57,7 +57,17 @@ def __getitem__( if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray): key = numpy.asarray(key) if isinstance(key, numpy.ndarray): - return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key] + # Cache the vectorial node to avoid rebuilding the recarray on + # every call -- build_from_node is expensive (walks the full + # parameter subtree each time). + try: + vectorial = self._vectorial_node + except AttributeError: + vectorial = parameters.VectorialParameterNodeAtInstant.build_from_node( + self + ) + self._vectorial_node = vectorial + return vectorial[key] return self._children[key] def __iter__(self) -> Iterable: diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 514b1714..70f52d60 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -181,84 +181,223 @@ def __getitem__(self, key: str) -> Any: return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] # Convert pandas arrays (e.g., StringArray from pandas 3) to numpy + # before checking, since StringArray has __array__ but is not hashable if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray): key = numpy.asarray(key) if isinstance(key, numpy.ndarray): - if not numpy.issubdtype(key.dtype, numpy.str_): - # In case the key is not a string vector, stringify it - if key.dtype == object and issubclass(type(key[0]), Enum): - enum = type(key[0]) - key = numpy.select( - [key == item for item in enum], - [str(item.name) for item in enum], - default="unknown", - ) - elif isinstance(key, EnumArray): - enum = key.possible_values - key = numpy.select( - [key == item.index for item in enum], - [item.name for item in enum], - default="unknown", - ) - else: + names = self.dtype.names + # Build name->child-index mapping (cached on instance) + if not hasattr(self, "_name_to_child_idx"): + self._name_to_child_idx = {name: i for i, name in enumerate(names)} + + name_to_child_idx = self._name_to_child_idx + n = len(key) + SENTINEL = len(names) + + # Convert key to integer indices directly, avoiding + # expensive intermediate string arrays where possible. + if isinstance(key, EnumArray): + # EnumArray: map enum int codes -> child indices via + # a pre-built lookup table (O(N), no string comparison). + enum = key.possible_values + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(item.name) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + idx = lut[numpy.asarray(key)] + elif ( + key.dtype == object + and len(key) > 0 + and issubclass(type(key[0]), Enum) + ): + # Object array of Enum instances + enum = type(key[0]) + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(str(item.name)) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + codes = numpy.array([v.index for v in key], dtype=numpy.intp) + idx = lut[codes] + else: + # String keys: map via dict lookup + if not numpy.issubdtype(key.dtype, numpy.str_): key = key.astype("str") - names = list( - self.dtype.names - ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] - conditions = [key == name for name in names] + # Vectorised dict lookup using numpy unique + scatter + uniq, inverse = numpy.unique(key, return_inverse=True) + uniq_idx = numpy.array( + [name_to_child_idx.get(u, SENTINEL) for u in uniq], + dtype=numpy.intp, + ) + idx = uniq_idx[inverse] + + # Gather values by child index using take on a stacked array. values = [self.vector[name] for name in names] - # NumPy 2.x requires all arrays in numpy.select to have identical dtypes - # For structured arrays with different field sets, we need to normalize them - if ( + is_structured = ( len(values) > 0 and hasattr(values[0].dtype, "names") and values[0].dtype.names - ): - # Check if all values have the same dtype + ) + + if is_structured: dtypes_match = all(val.dtype == values[0].dtype for val in values) + v0_len = len(values[0]) + + if v0_len <= 1: + # 1-element structured arrays: simple concat + index + if not dtypes_match: + all_fields = [] + seen = set() + for val in values: + for field in val.dtype.names: + if field not in seen: + all_fields.append(field) + seen.add(field) - if not dtypes_match: - # Find the union of all field names across all values, preserving first seen order - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - - # Create unified dtype with all fields - unified_dtype = numpy.dtype([(f, " Date: Sat, 14 Mar 2026 07:19:46 -0700 Subject: [PATCH 5/6] Apply ruff formatting and add changelog fragment Co-Authored-By: Claude Opus 4.6 --- changelog.d/perf-fast-cache.changed.md | 1 + .../vectorial_parameter_node_at_instant.py | 12 +++--------- 2 files changed, 4 insertions(+), 9 deletions(-) create mode 100644 changelog.d/perf-fast-cache.changed.md diff --git a/changelog.d/perf-fast-cache.changed.md b/changelog.d/perf-fast-cache.changed.md new file mode 100644 index 00000000..d5db5557 --- /dev/null +++ b/changelog.d/perf-fast-cache.changed.md @@ -0,0 +1 @@ +Optimize simulation calculate with fast cache and vectorial parameter lookups. diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 70f52d60..1b67b908 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -215,9 +215,7 @@ def __getitem__(self, key: str) -> Any: self._enum_lut_cache[cache_key] = lut idx = lut[numpy.asarray(key)] elif ( - key.dtype == object - and len(key) > 0 - and issubclass(type(key[0]), Enum) + key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum) ): # Object array of Enum instances enum = type(key[0]) @@ -371,9 +369,7 @@ def __getitem__(self, key: str) -> Any: v0 = numpy.asarray(values[0]) if v0.ndim == 0 or v0.shape[0] <= 1: # Scalar per child: 1D lookup - scalar_vals = numpy.empty( - len(values) + 1, dtype=numpy.float64 - ) + scalar_vals = numpy.empty(len(values) + 1, dtype=numpy.float64) for i, v in enumerate(values): scalar_vals[i] = float(v) scalar_vals[-1] = numpy.nan @@ -381,9 +377,7 @@ def __getitem__(self, key: str) -> Any: else: # N-element vectors: stack into (K+1, N) matrix m = v0.shape[0] - stacked = numpy.empty( - (len(values) + 1, m), dtype=numpy.float64 - ) + stacked = numpy.empty((len(values) + 1, m), dtype=numpy.float64) for i, v in enumerate(values): stacked[i] = v stacked[-1] = numpy.nan From 4a79aefaf9671c715c5e93b5cd2000985ab2f9a4 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 14 Mar 2026 07:28:26 -0700 Subject: [PATCH 6/6] Fix test failures from fast cache and vectorial optimizations - Use getattr for _trace in fast cache guard to handle StubSimulation subclasses that bypass Simulation.__init__ - Fix float() conversion for recarray elements in scalar lookup path Co-Authored-By: Claude Opus 4.6 --- .../parameters/vectorial_parameter_node_at_instant.py | 2 +- policyengine_core/simulations/simulation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 1b67b908..6a49b5ba 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -371,7 +371,7 @@ def __getitem__(self, key: str) -> Any: # Scalar per child: 1D lookup scalar_vals = numpy.empty(len(values) + 1, dtype=numpy.float64) for i, v in enumerate(values): - scalar_vals[i] = float(v) + scalar_vals[i] = float(numpy.asarray(v).flat[0]) scalar_vals[-1] = numpy.nan result = scalar_vals[idx] else: diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 51bd45f1..b41f2fcb 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -452,7 +452,7 @@ def calculate( # Fast path: skip tracer, random seed and all _calculate() machinery for # already-computed values. map_to and decode_enums are NOT cached here — # they are post-processing steps that vary per call site. - if map_to is None and not decode_enums and not self.trace: + if map_to is None and not decode_enums and not getattr(self, "_trace", False): _fast_key = (variable_name, str(period)) _fast_cache = getattr(self, "_fast_cache", None) if _fast_cache is not None: