diff --git a/changelog.d/perf-fast-cache.changed.md b/changelog.d/perf-fast-cache.changed.md new file mode 100644 index 00000000..d5db5557 --- /dev/null +++ b/changelog.d/perf-fast-cache.changed.md @@ -0,0 +1 @@ +Optimize simulation calculate with fast cache and vectorial parameter lookups. diff --git a/policyengine_core/parameters/parameter_node_at_instant.py b/policyengine_core/parameters/parameter_node_at_instant.py index 343e152d..b55889eb 100644 --- a/policyengine_core/parameters/parameter_node_at_instant.py +++ b/policyengine_core/parameters/parameter_node_at_instant.py @@ -57,7 +57,17 @@ def __getitem__( if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray): key = numpy.asarray(key) if isinstance(key, numpy.ndarray): - return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key] + # Cache the vectorial node to avoid rebuilding the recarray on + # every call -- build_from_node is expensive (walks the full + # parameter subtree each time). + try: + vectorial = self._vectorial_node + except AttributeError: + vectorial = parameters.VectorialParameterNodeAtInstant.build_from_node( + self + ) + self._vectorial_node = vectorial + return vectorial[key] return self._children[key] def __iter__(self) -> Iterable: diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 514b1714..6a49b5ba 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -181,84 +181,217 @@ def __getitem__(self, key: str) -> Any: return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] # Convert pandas arrays (e.g., StringArray from pandas 3) to numpy + # before checking, since StringArray has __array__ but is not hashable if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray): key = numpy.asarray(key) if isinstance(key, numpy.ndarray): - if not numpy.issubdtype(key.dtype, numpy.str_): - # In case the key is not a string vector, stringify it - if key.dtype == object and issubclass(type(key[0]), Enum): - enum = type(key[0]) - key = numpy.select( - [key == item for item in enum], - [str(item.name) for item in enum], - default="unknown", - ) - elif isinstance(key, EnumArray): - enum = key.possible_values - key = numpy.select( - [key == item.index for item in enum], - [item.name for item in enum], - default="unknown", - ) - else: + names = self.dtype.names + # Build name->child-index mapping (cached on instance) + if not hasattr(self, "_name_to_child_idx"): + self._name_to_child_idx = {name: i for i, name in enumerate(names)} + + name_to_child_idx = self._name_to_child_idx + n = len(key) + SENTINEL = len(names) + + # Convert key to integer indices directly, avoiding + # expensive intermediate string arrays where possible. + if isinstance(key, EnumArray): + # EnumArray: map enum int codes -> child indices via + # a pre-built lookup table (O(N), no string comparison). + enum = key.possible_values + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(item.name) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + idx = lut[numpy.asarray(key)] + elif ( + key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum) + ): + # Object array of Enum instances + enum = type(key[0]) + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(str(item.name)) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + codes = numpy.array([v.index for v in key], dtype=numpy.intp) + idx = lut[codes] + else: + # String keys: map via dict lookup + if not numpy.issubdtype(key.dtype, numpy.str_): key = key.astype("str") - names = list( - self.dtype.names - ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] - conditions = [key == name for name in names] + # Vectorised dict lookup using numpy unique + scatter + uniq, inverse = numpy.unique(key, return_inverse=True) + uniq_idx = numpy.array( + [name_to_child_idx.get(u, SENTINEL) for u in uniq], + dtype=numpy.intp, + ) + idx = uniq_idx[inverse] + + # Gather values by child index using take on a stacked array. values = [self.vector[name] for name in names] - # NumPy 2.x requires all arrays in numpy.select to have identical dtypes - # For structured arrays with different field sets, we need to normalize them - if ( + is_structured = ( len(values) > 0 and hasattr(values[0].dtype, "names") and values[0].dtype.names - ): - # Check if all values have the same dtype + ) + + if is_structured: dtypes_match = all(val.dtype == values[0].dtype for val in values) + v0_len = len(values[0]) + + if v0_len <= 1: + # 1-element structured arrays: simple concat + index + if not dtypes_match: + all_fields = [] + seen = set() + for val in values: + for field in val.dtype.names: + if field not in seen: + all_fields.append(field) + seen.add(field) - if not dtypes_match: - # Find the union of all field names across all values, preserving first seen order - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - - # Create unified dtype with all fields - unified_dtype = numpy.dtype([(f, " ArrayLike: if is_cache_available: smc.set_cache_value(cache_path, array) + if hasattr(self, "_fast_cache"): + self._fast_cache[(variable_name, str(period))] = array + return array def purge_cache_of_invalid_values(self) -> None: @@ -761,6 +776,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) + self._fast_cache.pop((_name, str(_period)), None) self.invalidated_caches = set() def calculate_add( @@ -1121,6 +1137,12 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: True """ self.get_holder(variable).delete_arrays(period) + if period is None: + self._fast_cache = { + k: v for k, v in self._fast_cache.items() if k[0] != variable + } + else: + self._fast_cache.pop((variable, str(period)), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1205,8 +1227,15 @@ def clone( new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ("debug", "trace", "tracer", "branches"): + if key not in ( + "debug", + "trace", + "tracer", + "branches", + "_fast_cache", + ): new_dict[key] = value + new._fast_cache = {} new.persons = self.persons.clone(new) setattr(new, new.persons.entity.key, new.persons)