Skip to content
Merged
86 changes: 40 additions & 46 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,54 +1512,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
out.flags.writeable = False
return out

# Optimization: Remove singleton dimensions to enable magic number usage
# for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
if singleton_dims:
squeezed_shape = tuple(s for s in chunk_shape if s != 1)
if squeezed_shape:
# Compute Morton order on squeezed shape, then expand singleton dims (always 0)
squeezed_order = np.asarray(_morton_order(squeezed_shape))
out = np.zeros((n_total, n_dims), dtype=np.intp)
squeezed_col = 0
for full_col in range(n_dims):
if chunk_shape[full_col] != 1:
out[:, full_col] = squeezed_order[:, squeezed_col]
squeezed_col += 1
else:
# All dimensions are singletons, just return the single point
out = np.zeros((1, n_dims), dtype=np.intp)
out.flags.writeable = False
return out

# Find the largest power-of-2 hypercube that fits within chunk_shape.
# Within this hypercube, Morton codes are guaranteed to be in bounds.
min_dim = min(chunk_shape)
if min_dim >= 1:
power = min_dim.bit_length() - 1 # floor(log2(min_dim))
hypercube_size = 1 << power # 2^power
n_hypercube = hypercube_size**n_dims
# Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
# all valid coordinates in chunk_shape. (c-1).bit_length() gives the number
# of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits
# is the size of this hypercube.
total_bits = sum((c - 1).bit_length() for c in chunk_shape)
n_z = 1 << total_bits if total_bits > 0 else 1

# Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
# This is fully vectorized, but wasteful when n_z >> n_total (e.g. (33,33,33):
# n_z=262144, n_total=35937), so such shapes use the argsort strategy below instead.
order: npt.NDArray[np.intp]
if n_z <= 4 * n_total:
# Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
# Works well when the overgeneration ratio n_z/n_total is small (≤4).
z_values = np.arange(n_z, dtype=np.intp)
all_coords = decode_morton_vectorized(z_values, chunk_shape)
shape_arr = np.array(chunk_shape, dtype=np.intp)
valid_mask = np.all(all_coords < shape_arr, axis=1)
order = all_coords[valid_mask]
else:
n_hypercube = 0
# Argsort strategy: enumerate all n_total valid coordinates directly,
# encode each to a Morton code, then sort by code. Avoids the more-than-4x
# overgeneration penalty for near-miss shapes like (33,33,33) (about 7.3x).
# Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
# vs O(n_z * bits) for ceiling, where n_z > 4 * n_total in this branch.
grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij")
all_coords = np.stack([g.ravel() for g in grids], axis=1)

# Encode all coordinates to Morton codes (vectorized).
bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
max_coord_bits = max(bits_per_dim)
z_codes = np.zeros(n_total, dtype=np.intp)
output_bit = 0
for coord_bit in range(max_coord_bits):
for dim in range(n_dims):
if coord_bit < bits_per_dim[dim]:
z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit
output_bit += 1

sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable")
order = np.asarray(all_coords[sort_idx], dtype=np.intp)

# Within the hypercube, no bounds checking needed - use vectorized decoding
if n_hypercube > 0:
z_values = np.arange(n_hypercube, dtype=np.intp)
order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
else:
order = np.empty((0, n_dims), dtype=np.intp)

# For remaining elements outside the hypercube, bounds checking is needed
remaining: list[tuple[int, ...]] = []
i = n_hypercube
while len(order) + len(remaining) < n_total:
m = decode_morton(i, chunk_shape)
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
remaining.append(m)
i += 1

if remaining:
order = np.vstack([order, np.array(remaining, dtype=np.intp)])
order.flags.writeable = False
return order

Expand Down