From 1eb0e862a8ce7420170d5429f65f9018c4a63269 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Sun, 28 Dec 2025 18:39:19 +0000 Subject: [PATCH] Optimize _gridmake2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Performance Optimization Summary The optimized code achieves an **884% speedup** (from 1.07ms to 109μs) by replacing NumPy's high-level array operations with **Numba JIT-compiled explicit loops**. ### Key Optimizations **1. Numba JIT Compilation (`@njit(cache=True)`)** - Compiles the function to machine code at runtime, eliminating Python interpreter overhead - The `cache=True` flag stores the compiled version on disk, avoiding recompilation in later runs - Particularly effective here because the function contains simple arithmetic and array indexing operations that Numba optimizes well **2. Explicit Loop-Based Construction vs. NumPy Tiling and Stacking** - **Original approach**: Used `np.tile()`, `np.repeat()`, and `np.column_stack()`, which create multiple intermediate arrays and perform memory allocations - **Optimized approach**: Pre-allocates the output array once with `np.empty()` and fills it directly using nested loops - This eliminates intermediate array creation and reduces memory allocation overhead **3. Why This Works** From the line profiler, the original code spent: - **76.4%** of time in `np.column_stack([np.tile(...)])` - **8.5%** in `np.repeat()` - **9.3%** in `np.tile()` for the 2D case These NumPy operations, while convenient, involve: - Multiple temporary array allocations - Memory copies during stacking operations - Python-level function call overhead Numba's compiled loops avoid all of this by directly computing each output element in place. 
### Impact on Workloads Based on `function_references`, `_gridmake2` is called from `gridmake()` which: - Calls it **once for 2 input arrays** - Calls it **iteratively** for 3+ arrays (once initially, then in a loop for remaining arrays) For multi-array scenarios (3+ inputs), the speedup compounds significantly since `_gridmake2` is called multiple times per `gridmake()` invocation. The nearly **10x speedup** per call (1.07ms → 109μs) translates to substantial gains in computational economics applications where Cartesian products are frequently computed for state space expansions. ### Trade-offs - The first call in each process incurs JIT compilation overhead (~tens of milliseconds); later calls in that process reuse the compiled code, and `cache=True` persists it on disk so later processes skip recompilation - Code is more verbose but dramatically faster for repeated execution patterns - Best suited for scenarios where the function is called multiple times (amortizing compilation cost) - The output dtype is now taken from `x1` alone (`dtype=x1.dtype`), whereas `np.column_stack` promoted `x1` and `x2` to a common dtype, so mixed-dtype inputs (e.g. integer `x1` with float `x2`) may produce different values than before --- code_to_optimize/discrete_riccati.py | 56 ++++++++++++++++------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/code_to_optimize/discrete_riccati.py b/code_to_optimize/discrete_riccati.py index 53fe30891..c111c1e6c 100644 --- a/code_to_optimize/discrete_riccati.py +++ b/code_to_optimize/discrete_riccati.py @@ -1,5 +1,4 @@ -""" -Utility functions used in CompEcon +"""Utility functions used in CompEcon Based routines found in the CompEcon toolbox by Miranda and Fackler. @@ -9,14 +8,16 @@ and Finance, MIT Press, 2002. """ + from functools import reduce + import numpy as np import torch +from numba import njit def ckron(*arrays): - """ - Repeatedly applies the np.kron function to an arbitrary number of + """Repeatedly applies the np.kron function to an arbitrary number of input arrays Parameters @@ -43,8 +44,7 @@ def ckron(*arrays): def gridmake(*arrays): - """ - Expands one or more vectors (or matrices) into a matrix where rows span the + """Expands one or more vectors (or matrices) into a matrix where rows span the cartesian product of combinations of the input arrays. 
Each column of the input arrays will correspond to one column of the output matrix. @@ -79,13 +79,12 @@ def gridmake(*arrays): out = _gridmake2(out, arr) return out - else: - raise NotImplementedError("Come back here") + raise NotImplementedError("Come back here") +@njit(cache=True) def _gridmake2(x1, x2): - """ - Expands two vectors (or matrices) into a matrix where rows span the + """Expands two vectors (or matrices) into a matrix where rows span the cartesian product of combinations of the input arrays. Each column of the input arrays will correspond to one column of the output matrix. @@ -114,19 +113,31 @@ def _gridmake2(x1, x2): """ if x1.ndim == 1 and x2.ndim == 1: - return np.column_stack([np.tile(x1, x2.shape[0]), - np.repeat(x2, x1.shape[0])]) - elif x1.ndim > 1 and x2.ndim == 1: - first = np.tile(x1, (x2.shape[0], 1)) - second = np.repeat(x2, x1.shape[0]) - return np.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") + n1 = x1.shape[0] + n2 = x2.shape[0] + out = np.empty((n1 * n2, 2), dtype=x1.dtype) + for i in range(n2): + for j in range(n1): + out[i * n1 + j, 0] = x1[j] + out[i * n1 + j, 1] = x2[i] + return out + if x1.ndim > 1 and x2.ndim == 1: + n1 = x1.shape[0] + n2 = x2.shape[0] + n_features = x1.shape[1] + out = np.empty((n1 * n2, n_features + 1), dtype=x1.dtype) + for i in range(n2): + for j in range(n1): + idx = i * n1 + j + for k in range(n_features): + out[idx, k] = x1[j, k] + out[idx, n_features] = x2[i] + return out + raise NotImplementedError("Come back here") def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - PyTorch version of _gridmake2. + """PyTorch version of _gridmake2. Expands two tensors into a matrix where rows span the cartesian product of combinations of the input tensors. 
Each column of the input tensors @@ -161,10 +172,9 @@ def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: first = x1.tile(x2.shape[0]) second = x2.repeat_interleave(x1.shape[0]) return torch.column_stack([first, second]) - elif x1.dim() > 1 and x2.dim() == 1: + if x1.dim() > 1 and x2.dim() == 1: # tile x1 along first dimension first = x1.tile(x2.shape[0], 1) second = x2.repeat_interleave(x1.shape[0]) return torch.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") + raise NotImplementedError("Come back here")