
Performance: fetch inner chunks concurrently within a shard #3953

@maxrjones

Description

TL;DR

When a read selection touches multiple inner chunks inside a single shard,
ShardingCodec._decode_partial_single fetches those chunks sequentially
in a for loop with await inside the body. They should be fetched
concurrently. Every other place in the codebase that fetches multiple
chunks already does this via zarr.core.common.concurrent_map.

The fix is to replace the loop with a single concurrent_map call.

Background: what is "sharding"?

Sharding is a zarr v3 codec that packs many small "inner chunks" into one
larger "shard" object in storage. A user reads or writes at the inner-chunk
granularity, but storage I/O happens at the shard granularity. The shard
contains, in order:

  1. The compressed bytes of each inner chunk, concatenated.
  2. An index at the end of the shard giving the (offset, length) of each
    inner chunk within the shard.

To read a region that spans N inner chunks inside one shard, the codec:

  1. Fetches the shard index (one byte-range request).
  2. Looks up the (offset, length) for each of the N inner chunks.
  3. Fetches the bytes for each inner chunk (currently: N sequential
    range requests; should be: N concurrent range requests).
  4. Decodes each inner chunk and assembles the output buffer.

Step 3 is the bug.
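
As a mental model only (my sketch, not the real zarr classes, and
ignoring index-encoding details such as checksums), steps 2 and 3 for a
single inner chunk look like this, where get_range stands in for one
byte-range request against storage:

# Schematic shard layout: [chunk 0 bytes][chunk 1 bytes]...[index at the end]
# `index` maps inner-chunk coordinates to (offset, length) within the shard.
Index = dict[tuple[int, ...], tuple[int, int]]

async def read_inner_chunk(get_range, index: Index, coords: tuple[int, ...]) -> bytes:
    offset, length = index[coords]          # step 2: look up the byte range
    return await get_range(offset, length)  # step 3: one range request per chunk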

Where the bug lives

File: src/zarr/codecs/sharding.py
Function: ShardingCodec._decode_partial_single
Lines (approximately): 500–513

else:
    # read some chunks within the shard
    shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
    if shard_index is None:
        return None
    shard_dict = {}
    for chunk_coords in all_chunk_coords:
        chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
        if chunk_byte_slice:
            chunk_bytes = await byte_getter.get(
                prototype=chunk_spec.prototype,
                byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
            )
            if chunk_bytes:
                shard_dict[chunk_coords] = chunk_bytes

Each await byte_getter.get(...) must finish before the loop issues the
next request. Against a remote store (S3, GCS, HTTP) this serializes one
round-trip per chunk, which can dominate read latency for shard-heavy
datasets.

The matching "full shard" branch a few lines above (_load_full_shard_maybe)
is fine — it issues a single byte-range request for the whole shard.

How concurrency is done elsewhere — copy this pattern

src/zarr/core/common.py defines concurrent_map:

async def concurrent_map[T: tuple[Any, ...], V](
    items: Iterable[T],
    func: Callable[..., Awaitable[V]],
    limit: int | None = None,
) -> list[V]: ...

It runs func(*item) for every item with asyncio.gather, optionally
capped at limit in flight; a minimal standalone model follows the list
below. Examples already in the tree:

  • src/zarr/core/codec_pipeline.py (the array-level read path) uses it to
    fetch chunks concurrently across the array.
  • src/zarr/abc/codec.py (Codec.decode_partial default impl) uses it to
    call _decode_partial_single concurrently across multiple shards.
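
For intuition, here is a minimal standalone model of that pattern (my
sketch, depending only on asyncio; the real implementation in
zarr.core.common differs in details):

import asyncio
from collections.abc import Awaitable, Callable, Iterable
from typing import Any, TypeVar

V = TypeVar("V")

async def concurrent_map_model(
    items: Iterable[tuple[Any, ...]],
    func: Callable[..., Awaitable[V]],
    limit: int | None = None,
) -> list[V]:
    # Run func(*item) for every item, optionally capping in-flight tasks.
    if limit is None:
        return await asyncio.gather(*(func(*item) for item in items))
    sem = asyncio.Semaphore(limit)

    async def run(item: tuple[Any, ...]) -> V:
        async with sem:
            return await func(*item)

    return await asyncio.gather(*(run(item) for item in items))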

So today, across shards we are concurrent; within a shard we are
not. This issue closes that gap.

What the fix looks like

Roughly:

else:
    shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
    if shard_index is None:
        return None

    # (coords, byte_slice) pairs for the inner chunks present in this shard
    chunks_to_fetch = [
        (chunk_coords, byte_slice)
        for chunk_coords in all_chunk_coords
        if (byte_slice := shard_index.get_chunk_slice(chunk_coords)) is not None
    ]

    async def _fetch(coords, byte_slice):
        chunk_bytes = await byte_getter.get(
            prototype=chunk_spec.prototype,
            byte_range=RangeByteRequest(byte_slice[0], byte_slice[1]),
        )
        return coords, chunk_bytes

    results = await concurrent_map(chunks_to_fetch, _fetch)
    shard_dict = {coords: b for coords, b in results if b is not None}

Treat that as a sketch, not a copy-paste. Things to decide / check while
implementing:

  • Concurrency limit. Look at how codec_pipeline.read picks its
    concurrency value (it threads through from config). The same value, or
    an explicit zarr.config knob, should govern in-shard concurrency; see
    the sketch after this list. Pick the simplest thing that matches
    existing convention and flag the decision in the PR description so
    reviewers can weigh in.
  • Imports. concurrent_map lives in zarr.core.common. Check whether
    it's already imported at the top of sharding.py; add the import if not.
  • Type annotations. The repo uses strict typing. Annotate _fetch's
    parameters and return type. Run pre-commit run --all-files (or at least
    mypy) before pushing.
  • Don't change the full-shard branch. That one is already a single
    request and is correct.
  • Don't change _decode_partial_single's signature or return type.
    This is a pure-internal performance fix.
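
For the concurrency-limit bullet above, following the pipeline's
convention would look roughly like this (a sketch; verify the exact
config key against codec_pipeline.py, where async.concurrency is the
knob the array-level read path uses at the time of writing):

from zarr.core.common import concurrent_map
from zarr.core.config import config

# Cap in-shard fetches with the same knob the array-level read path uses.
results = await concurrent_map(
    chunks_to_fetch,
    _fetch,
    limit=config.get("async.concurrency"),
)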

How to verify the fix

1. Existing tests must still pass

hatch run test:run-coverage tests/test_codecs/test_sharding.py

(Or equivalent — see CONTRIBUTING.md for the project's preferred runner.
Per project convention, prefer hatch or a PEP 723 uv script over
hand-rolled venvs.)

2. Add a test that demonstrates concurrency

A behavioral test is more valuable than a coverage test. Use the existing
LatencyStore wrapper from zarr.testing.store — it adds a fixed
asyncio.sleep to every .get() and .set() call, so concurrent vs.
sequential fetches show up directly in wall-clock time. Suggested shape:

  • Wrap a MemoryStore in LatencyStore(..., get_latency=T) with T
    large enough to dwarf overhead (e.g. 50 ms).
  • Build a sharded array on top, write some data, then read a selection
    that spans ≥ 2 inner chunks within one shard.
  • Assert total wall time is bounded by ~2 * T (one sleep for the index,
    one for the concurrent batch of chunks) rather than scaling with the
    number of inner chunks touched.

For a working model of the timing-assertion style, see
tests/test_group.py::test_group_members_concurrency (search for
LatencyStore in that file). It uses essentially the inequality
elapsed < num_items * get_latency to assert that work was concurrent —
the same shape of assertion you want here, at the chunk level.
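
Putting those pieces together, one possible shape for the new test (my
sketch, not existing code; it reuses the repro below and assumes the
suite runs async tests via pytest-asyncio; drop the decorator if the
suite already handles async tests automatically):

import time

import numpy as np
import pytest

import zarr
from zarr.storage import MemoryStore
from zarr.testing.store import LatencyStore


@pytest.mark.asyncio
async def test_partial_shard_read_is_concurrent() -> None:
    get_latency = 0.05
    store = LatencyStore(MemoryStore(), get_latency=get_latency, set_latency=0.0)
    arr = await zarr.api.asynchronous.create_array(
        store=store, shape=(64,), shards=(64,), chunks=(8,), dtype="i4"
    )
    await arr.setitem(slice(None), np.arange(64, dtype="i4"))

    start = time.perf_counter()
    await arr.getitem(slice(0, 40))  # touches 5 of the 8 inner chunks
    elapsed = time.perf_counter() - start

    # Sequential: ~(5 chunks + 1 index) * get_latency = 0.30 s.
    # Concurrent: ~2 * get_latency = 0.10 s. Leave headroom for overhead.
    assert elapsed < 5 * get_latency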

A weaker but still useful assertion is "all the inner-chunk .get() calls
were issued before any of them returned" — write a small wrapper store
that records start/finish timestamps, and check the intervals overlap.
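
A hypothetical sketch of that wrapper (names are mine; it assumes
zarr.storage.WrapperStore exists with a delegating get(), and the v3
Store.get signature of (key, prototype, byte_range); check both against
your zarr version):

import time

from zarr.storage import WrapperStore


class RecordingStore(WrapperStore):
    """Records (start, end) timestamps for every .get() call so a test
    can assert that the request intervals overlap."""

    def __init__(self, store) -> None:
        super().__init__(store)
        self.intervals: list[tuple[float, float]] = []

    async def get(self, key, prototype, byte_range=None):
        start = time.perf_counter()
        result = await super().get(key, prototype, byte_range=byte_range)
        self.intervals.append((start, time.perf_counter()))
        return result


# Concurrency check: every chunk request starts before the earliest one
# finishes (drop the index fetch from `intervals` before asserting):
#   starts, ends = zip(*store.intervals)
#   assert max(starts) < min(ends)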

3. Optional: micro-benchmark

A before/after timing against a LatencyStore wrapper (sleep N ms per
.get) makes the performance win concrete in the PR description.
The repro at the bottom of this doc is enough — paste the elapsed-time
numbers from before and after the change.

Out of scope (do not do these in the same PR)

  • Reading the index and the chunks in a single round trip. Some stores
    support multi-range GETs; an in-flight branch (feat/get-many) is adding
    a SupportsGetRanges protocol with a get_ranges method on the store.
    Once that lands, the sharding codec is a natural consumer — but that is
    a separate, larger change. Keep this PR to the concurrent_map swap.
  • Refactoring _decode_partial_single more broadly. It is doing a few
    things (indexing, fetching, decoding); leave the structure alone.
  • Changing the full-shard read path.
  • Touching the write path. This issue is only about reads.

Repro you can paste into a script

import asyncio
import time

import numpy as np

import zarr
from zarr.storage import MemoryStore
from zarr.testing.store import LatencyStore


async def main():
    latency_s = 0.05
    store = LatencyStore(MemoryStore(), get_latency=latency_s, set_latency=0.0)

    arr = await zarr.api.asynchronous.create_array(
        store=store,
        shape=(64,),
        shards=(64,),     # one shard covering the whole array
        chunks=(8,),      # 8 inner chunks per shard
        dtype="i4",
    )
    await arr.setitem(slice(None), np.arange(64, dtype="i4"))

    t0 = time.perf_counter()
    # selection that touches several inner chunks but not the whole shard
    _ = await arr.getitem(slice(0, 40))
    print(f"elapsed: {time.perf_counter() - t0:.3f}s")


asyncio.run(main())

Before the fix: elapsed scales with the number of inner chunks touched,
roughly N * latency_s plus the index fetch. Here slice(0, 40) spans 5
inner chunks, so expect about (5 + 1) * 0.05 s = 0.30 s.
After the fix: elapsed should be ~2 * latency_s = 0.10 s (one sleep for
the index, one for the concurrent batch of chunks).

(If the create_array call signature has drifted, the canonical reference
is tests/test_codecs/test_sharding.py.)

PR checklist

  • _decode_partial_single uses concurrent_map for the partial-shard
    branch.
  • concurrent_map is imported at the top of sharding.py.
  • Concurrency-limit decision is called out in the PR body.
  • New test demonstrates concurrent fetching (not just correctness).
  • Existing sharding tests still pass.
  • pre-commit run --all-files is clean.
  • A changes/<PR#>.bugfix.md (or .feature.md — argue for one in the
    PR) towncrier fragment is added.

This issue was flagged to me by @aldenks.
