Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/perf-fast-cache.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimize `Simulation.calculate` with a fast result cache and vectorial parameter lookups.
12 changes: 11 additions & 1 deletion policyengine_core/parameters/parameter_node_at_instant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,17 @@ def __getitem__(
if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray):
key = numpy.asarray(key)
if isinstance(key, numpy.ndarray):
return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key]
# Cache the vectorial node to avoid rebuilding the recarray on
# every call -- build_from_node is expensive (walks the full
# parameter subtree each time).
try:
vectorial = self._vectorial_node
except AttributeError:
vectorial = parameters.VectorialParameterNodeAtInstant.build_from_node(
self
)
self._vectorial_node = vectorial
return vectorial[key]
return self._children[key]

def __iter__(self) -> Iterable:
Expand Down
259 changes: 196 additions & 63 deletions policyengine_core/parameters/vectorial_parameter_node_at_instant.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,84 +181,217 @@ def __getitem__(self, key: str) -> Any:
return self.__getattr__(key)
# If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1']
# Convert pandas arrays (e.g., StringArray from pandas 3) to numpy
# before checking, since StringArray has __array__ but is not hashable
if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray):
key = numpy.asarray(key)
if isinstance(key, numpy.ndarray):
if not numpy.issubdtype(key.dtype, numpy.str_):
# In case the key is not a string vector, stringify it
if key.dtype == object and issubclass(type(key[0]), Enum):
enum = type(key[0])
key = numpy.select(
[key == item for item in enum],
[str(item.name) for item in enum],
default="unknown",
)
elif isinstance(key, EnumArray):
enum = key.possible_values
key = numpy.select(
[key == item.index for item in enum],
[item.name for item in enum],
default="unknown",
)
else:
names = self.dtype.names
# Build name->child-index mapping (cached on instance)
if not hasattr(self, "_name_to_child_idx"):
self._name_to_child_idx = {name: i for i, name in enumerate(names)}

name_to_child_idx = self._name_to_child_idx
n = len(key)
SENTINEL = len(names)

# Convert key to integer indices directly, avoiding
# expensive intermediate string arrays where possible.
if isinstance(key, EnumArray):
# EnumArray: map enum int codes -> child indices via
# a pre-built lookup table (O(N), no string comparison).
enum = key.possible_values
cache_key = id(enum)
if not hasattr(self, "_enum_lut_cache"):
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(item.name)
if child_idx is not None:
lut[item.index] = child_idx
self._enum_lut_cache[cache_key] = lut
idx = lut[numpy.asarray(key)]
elif (
key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum)
):
# Object array of Enum instances
enum = type(key[0])
cache_key = id(enum)
if not hasattr(self, "_enum_lut_cache"):
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(str(item.name))
if child_idx is not None:
lut[item.index] = child_idx
self._enum_lut_cache[cache_key] = lut
codes = numpy.array([v.index for v in key], dtype=numpy.intp)
idx = lut[codes]
else:
# String keys: map via dict lookup
if not numpy.issubdtype(key.dtype, numpy.str_):
key = key.astype("str")
names = list(
self.dtype.names
) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2']
conditions = [key == name for name in names]
# Vectorised dict lookup using numpy unique + scatter
uniq, inverse = numpy.unique(key, return_inverse=True)
uniq_idx = numpy.array(
[name_to_child_idx.get(u, SENTINEL) for u in uniq],
dtype=numpy.intp,
)
idx = uniq_idx[inverse]

# Gather values by child index using take on a stacked array.
values = [self.vector[name] for name in names]

# NumPy 2.x requires all arrays in numpy.select to have identical dtypes
# For structured arrays with different field sets, we need to normalize them
if (
is_structured = (
len(values) > 0
and hasattr(values[0].dtype, "names")
and values[0].dtype.names
):
# Check if all values have the same dtype
)

if is_structured:
dtypes_match = all(val.dtype == values[0].dtype for val in values)
v0_len = len(values[0])

if v0_len <= 1:
# 1-element structured arrays: simple concat + index
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)

if not dtypes_match:
# Find the union of all field names across all values, preserving first seen order
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)

# Create unified dtype with all fields
unified_dtype = numpy.dtype([(f, "<f8") for f in all_fields])

# Cast all values to unified dtype
values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)

default = numpy.zeros(len(values_cast[0]), dtype=unified_dtype)
# Fill with NaN
for field in unified_dtype.names:
default[field] = numpy.nan

result = numpy.select(conditions, values_cast, default)
unified_dtype = numpy.dtype([(f, "<f8") for f in all_fields])

values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)

default = numpy.zeros(1, dtype=unified_dtype)
for field in unified_dtype.names:
default[field] = numpy.nan
stacked = numpy.concatenate(values_cast + [default])
result = stacked[idx]
else:
default = numpy.full(1, numpy.nan, dtype=values[0].dtype)
stacked = numpy.concatenate(values + [default])
result = stacked[idx]
else:
# All dtypes match, use original logic
default = numpy.full_like(values[0], numpy.nan)
result = numpy.select(conditions, values, default)
# N-element structured arrays: check if fields are
# simple scalars (fast path) or nested records
# (fall back to numpy.select).
first_field = values[0].dtype.names[0]
field_dtype = values[0][first_field].dtype
is_nested = (
hasattr(field_dtype, "names") and field_dtype.names is not None
)

if is_nested:
# Nested structured: fall back to numpy.select
conditions = [idx == i for i in range(len(values))]
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)
values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)
default = numpy.zeros(v0_len, dtype=unified_dtype)
for field in unified_dtype.names:
default[field] = numpy.nan
result = numpy.select(conditions, values_cast, default)
else:
default = numpy.full_like(values[0], numpy.nan)
result = numpy.select(conditions, values, default)
else:
# Flat structured: fast per-field indexing
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)
values_unified = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_unified.append(casted)
field_names = all_fields
result_dtype = unified_dtype
else:
values_unified = values
field_names = values[0].dtype.names
result_dtype = values[0].dtype

result = numpy.empty(n, dtype=result_dtype)
arange_n = numpy.arange(v0_len)
for field in field_names:
field_stack = numpy.empty(
(len(values_unified) + 1, v0_len),
dtype=numpy.float64,
)
for i, v in enumerate(values_unified):
field_stack[i] = v[field]
field_stack[-1] = numpy.nan
result[field] = field_stack[idx, arange_n]
else:
# Non-structured array case
default = numpy.full_like(
values[0] if values else self.vector[key[0]], numpy.nan
)
result = numpy.select(conditions, values, default)
# Non-structured: values are either scalars (1-elem arrays)
# or N-element vectors (after prior vectorial indexing).
if values:
v0 = numpy.asarray(values[0])
if v0.ndim == 0 or v0.shape[0] <= 1:
# Scalar per child: 1D lookup
scalar_vals = numpy.empty(len(values) + 1, dtype=numpy.float64)
for i, v in enumerate(values):
scalar_vals[i] = float(numpy.asarray(v).flat[0])
scalar_vals[-1] = numpy.nan
result = scalar_vals[idx]
else:
# N-element vectors: stack into (K+1, N) matrix
m = v0.shape[0]
stacked = numpy.empty((len(values) + 1, m), dtype=numpy.float64)
for i, v in enumerate(values):
stacked[i] = v
stacked[-1] = numpy.nan
result = stacked[idx, numpy.arange(m)]
else:
result = numpy.full(n, numpy.nan)

# Check for unexpected keys (NaN results from missing keys)
# Check for unexpected keys
if helpers.contains_nan(result):
unexpected_keys = set(key).difference(self.vector.dtype.names)
unexpected_keys = set(
numpy.asarray(key, dtype=str)
if not numpy.issubdtype(numpy.asarray(key).dtype, numpy.str_)
else key
).difference(self.vector.dtype.names)
if unexpected_keys:
unexpected_key = unexpected_keys.pop()
raise ParameterNotFoundError(
Expand Down
31 changes: 30 additions & 1 deletion policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
self.is_over_dataset = dataset is not None

self.invalidated_caches = set()
self._fast_cache: dict = {}
self.debug: bool = False
self.trace: bool = trace
self.tracer: SimpleTracer = SimpleTracer() if not trace else FullTracer()
Expand Down Expand Up @@ -448,6 +449,17 @@ def calculate(
elif period is None and self.default_calculation_period is not None:
period = periods.period(self.default_calculation_period)

# Fast path: skip the tracer, random-seed setup and all _calculate() machinery
# for already-computed values. map_to and decode_enums are NOT cached here —
# they are post-processing steps that vary per call site.
# NOTE(review): the guard below reads `getattr(self, "_trace", False)`, but
# __init__ assigns `self.trace` — if `_trace` is never set anywhere, the
# getattr default of False means the fast path is taken even when tracing is
# enabled, silently skipping tracer records for cache hits. Confirm the
# intended attribute name.
if map_to is None and not decode_enums and not getattr(self, "_trace", False):
_fast_key = (variable_name, str(period))
_fast_cache = getattr(self, "_fast_cache", None)
if _fast_cache is not None:
_cached = _fast_cache.get(_fast_key)
if _cached is not None:
return _cached

self.tracer.record_calculation_start(variable_name, period, self.branch_name)

np.random.seed(hash(variable_name + str(period)) % 1000000)
Expand Down Expand Up @@ -752,6 +764,9 @@ def _calculate(self, variable_name: str, period: Period = None) -> ArrayLike:
if is_cache_available:
smc.set_cache_value(cache_path, array)

if hasattr(self, "_fast_cache"):
self._fast_cache[(variable_name, str(period))] = array

return array

def purge_cache_of_invalid_values(self) -> None:
Expand All @@ -761,6 +776,7 @@ def purge_cache_of_invalid_values(self) -> None:
for _name, _period in self.invalidated_caches:
holder = self.get_holder(_name)
holder.delete_arrays(_period)
self._fast_cache.pop((_name, str(_period)), None)
self.invalidated_caches = set()

def calculate_add(
Expand Down Expand Up @@ -1121,6 +1137,12 @@ def delete_arrays(self, variable: str, period: Period = None) -> None:
True
"""
self.get_holder(variable).delete_arrays(period)
if period is None:
self._fast_cache = {
k: v for k, v in self._fast_cache.items() if k[0] != variable
}
else:
self._fast_cache.pop((variable, str(period)), None)

def get_known_periods(self, variable: str) -> List[Period]:
"""
Expand Down Expand Up @@ -1205,8 +1227,15 @@ def clone(
new_dict = new.__dict__

for key, value in self.__dict__.items():
if key not in ("debug", "trace", "tracer", "branches"):
if key not in (
"debug",
"trace",
"tracer",
"branches",
"_fast_cache",
):
new_dict[key] = value
new._fast_cache = {}

new.persons = self.persons.clone(new)
setattr(new, new.persons.entity.key, new.persons)
Expand Down
Loading