Skip to content
1 change: 1 addition & 0 deletions changes/3907.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the `SupportsSetRange` protocol for stores that support byte-range writes, and implement it for `LocalStore` and `MemoryStore`. This is necessary to support in-place writes of sharded arrays.
18 changes: 18 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"Store",
"SupportsDeleteSync",
"SupportsGetSync",
"SupportsSetRange",
"SupportsSetSync",
"SupportsSyncStore",
"set_or_delete",
Expand Down Expand Up @@ -709,6 +710,23 @@ async def delete(self) -> None: ...
async def set_if_not_exists(self, default: Buffer) -> None: ...


@runtime_checkable
class SupportsSetRange(Protocol):
    """Structural type for stores that can overwrite part of an existing value.

    A conforming store replaces ``len(value)`` bytes of the value stored under
    ``key``, beginning at byte offset ``start``. Callers must only target keys
    that already exist, and must keep the write inside the existing value
    (i.e., ``start + len(value) <= len(existing)``).

    What happens when a write runs past the end of the existing value is
    implementation-specific and must not be relied upon.
    """

    async def set_range(self, key: str, value: Buffer, start: int) -> None:
        """Asynchronously overwrite bytes of ``key`` beginning at offset ``start``.

        Parameters
        ----------
        key : str
            Key of the existing value to patch.
        value : Buffer
            Bytes to write over the existing value.
        start : int
            Byte offset within the existing value at which writing begins.
        """

    def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
        """Synchronously overwrite bytes of ``key`` beginning at offset ``start``.

        Parameters
        ----------
        key : str
            Key of the existing value to patch.
        value : Buffer
            Bytes to write over the existing value.
        start : int
            Byte offset within the existing value at which writing begins.
        """


@runtime_checkable
class SupportsGetSync(Protocol):
def get_sync(
Expand Down
23 changes: 22 additions & 1 deletion src/zarr/storage/_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
RangeByteRequest,
Store,
SuffixByteRequest,
SupportsSetRange,
)
from zarr.core.buffer import Buffer
from zarr.core.buffer.core import default_buffer_prototype
Expand Down Expand Up @@ -77,6 +78,13 @@ def _atomic_write(
raise


def _put_range(path: Path, value: Buffer, start: int) -> None:
    """Overwrite bytes at a specific offset within an existing file.

    Parameters
    ----------
    path : Path
        File to patch. It must already exist; it is opened in ``"r+b"`` mode,
        which never creates or truncates the file.
    value : Buffer
        Bytes to write.
    start : int
        Byte offset within the file at which writing begins.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If ``start`` is negative or the write would extend past the current
        end of the file. The file is left untouched in that case.
    """
    data = value.as_numpy_array()
    # ``write`` takes any C-contiguous object supporting the buffer protocol,
    # so only fall back to a copying ``tobytes()`` for non-contiguous arrays.
    payload = memoryview(data) if data.flags.c_contiguous else data.tobytes()
    stop = start + data.nbytes
    with path.open("r+b") as f:
        # Seeking to the end (whence=2) returns the current file size.
        size = f.seek(0, 2)
        # Without this check a write past EOF would silently grow the file
        # (or leave a hole), and a negative offset would surface as a
        # low-level seek error.
        if start < 0 or stop > size:
            raise ValueError(
                f"Byte range [{start}, {stop}) does not fit within the existing "
                f"value of length {size} at {str(path)!r}."
            )
        f.seek(start)
        f.write(payload)


def _put(path: Path, value: Buffer, exclusive: bool = False) -> int:
path.parent.mkdir(parents=True, exist_ok=True)
# write takes any object supporting the buffer protocol
Expand All @@ -85,7 +93,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int:
return f.write(view)


class LocalStore(Store):
class LocalStore(Store, SupportsSetRange):
"""
Store for the local file system.

Expand Down Expand Up @@ -292,6 +300,19 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
path = self.root / key
await asyncio.to_thread(_put, path, value, exclusive=exclusive)

async def set_range(self, key: str, value: Buffer, start: int) -> None:
    """Overwrite part of the file stored under ``key`` without blocking the event loop.

    The store is opened first if it is not open yet, then ``_check_writable``
    is consulted before any bytes are written.

    Parameters
    ----------
    key : str
        Key, relative to the store root, of the existing value to patch.
    value : Buffer
        Bytes to write.
    start : int
        Byte offset within the existing value at which writing begins.
    """
    if not self._is_open:
        await self._open()
    self._check_writable()
    # The file write is blocking I/O, so hand it to a worker thread.
    await asyncio.to_thread(_put_range, self.root / key, value, start)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
    """Synchronous counterpart of ``set_range``: overwrite part of the file for ``key``.

    The store is opened via ``_ensure_open_sync`` and ``_check_writable`` is
    consulted before any bytes are written.

    Parameters
    ----------
    key : str
        Key, relative to the store root, of the existing value to patch.
    value : Buffer
        Bytes to write.
    start : int
        Byte offset within the existing value at which writing begins.
    """
    self._ensure_open_sync()
    self._check_writable()
    _put_range(self.root / key, value, start)

async def delete(self, key: str) -> None:
"""
Remove a key from the store.
Expand Down
24 changes: 22 additions & 2 deletions src/zarr/storage/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any, Self

from zarr.abc.store import ByteRequest, Store
from zarr.abc.store import ByteRequest, Store, SupportsSetRange
from zarr.core.buffer import Buffer, gpu
from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.common import concurrent_map
Expand All @@ -26,7 +26,7 @@
logger = getLogger(__name__)


class MemoryStore(Store):
class MemoryStore(Store, SupportsSetRange):
"""
Store for local memory.

Expand Down Expand Up @@ -194,6 +194,26 @@ async def delete(self, key: str) -> None:
except KeyError:
logger.debug("Key %s does not exist.", key)

def _set_range_impl(self, key: str, value: Buffer, start: int) -> None:
    """Overwrite bytes of the value stored under ``key``, starting at ``start``.

    Shared implementation behind ``set_range`` and ``set_range_sync``.

    Parameters
    ----------
    key : str
        Key of the existing value to patch.
    value : Buffer
        Bytes to write.
    start : int
        Offset within the existing value at which writing begins.

    Raises
    ------
    KeyError
        If ``key`` is not present in the store.
    ValueError
        If ``start`` is negative or the write would extend past the end of
        the existing value. The stored value is left untouched in that case.
    """
    buf = self._store_dict[key]
    target = buf.as_numpy_array()
    source = value.as_numpy_array()
    stop = start + len(source)
    # Validate explicitly: numpy would otherwise treat a negative ``start`` as
    # an index from the end, and would broadcast a 1-element patch into an
    # empty slice (silently dropping the write) when ``start >= len(target)``.
    if start < 0 or stop > len(target):
        raise ValueError(
            f"Byte range [{start}, {stop}) does not fit within the existing "
            f"value of length {len(target)} for key {key!r}."
        )
    if not target.flags.writeable:
        # The stored array cannot be modified in place; replace it with a
        # writable copy wrapped in the same buffer class.
        target = target.copy()
        self._store_dict[key] = buf.__class__(target)
    target[start:stop] = source

async def set_range(self, key: str, value: Buffer, start: int) -> None:
    """Overwrite bytes of the value stored under ``key``, starting at ``start``.

    Consults ``_check_writable``, ensures the store is open, and then
    delegates the actual write to ``_set_range_impl``.

    Parameters
    ----------
    key : str
        Key of the existing value to patch.
    value : Buffer
        Bytes to write.
    start : int
        Offset within the existing value at which writing begins.
    """
    self._check_writable()
    await self._ensure_open()
    # The in-memory write itself is synchronous; no awaiting is needed here.
    self._set_range_impl(key, value, start)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
    """Synchronous counterpart of ``set_range``.

    Consults ``_check_writable``, marks the store as open if it is not
    already, and then delegates the actual write to ``_set_range_impl``.

    Parameters
    ----------
    key : str
        Key of the existing value to patch.
    value : Buffer
        Bytes to write.
    start : int
        Offset within the existing value at which writing begins.
    """
    self._check_writable()
    # Opening here only flips the flag; no other setup is performed.
    if not self._is_open:
        self._is_open = True
    self._set_range_impl(key, value, start)

async def list(self) -> AsyncIterator[str]:
# docstring inherited
for key in self._store_dict:
Expand Down
67 changes: 67 additions & 0 deletions tests/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import zarr
from zarr import create_array
from zarr.abc.store import SupportsSetRange
from zarr.core.buffer import Buffer, cpu
from zarr.core.sync import sync
from zarr.storage import LocalStore
Expand Down Expand Up @@ -162,6 +163,72 @@ def test_get_json_sync_with_prototype_none(
result = store._get_json_sync(key, prototype=buffer_cls)
assert result == data

def test_supports_set_range(self, store: LocalStore) -> None:
    """A LocalStore instance must satisfy the runtime-checkable SupportsSetRange protocol."""
    assert isinstance(store, SupportsSetRange)

@pytest.mark.parametrize(
    ("start", "patch", "expected"),
    [
        (0, b"XX", b"XXAAAAAAAA"),
        (3, b"XX", b"AAAXXAAAAA"),
        (8, b"XX", b"AAAAAAAAXX"),
        (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
        (5, b"B", b"AAAAABAAAA"),
        (0, b"BCDE", b"BCDEAAAAAA"),
    ],
    ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
async def test_set_range(
    self, store: LocalStore, start: int, patch: bytes, expected: bytes
) -> None:
    """Patching a ten-byte value via set_range yields the expected bytes."""
    key = "test/key"
    # Seed a known value, patch it, then read the whole value back.
    await store.set(key, cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
    await store.set_range(key, cpu.Buffer.from_bytes(patch), start=start)
    stored = await store.get(key, prototype=cpu.buffer_prototype)
    assert stored is not None
    assert stored.to_bytes() == expected

@pytest.mark.parametrize(
    ("start", "patch", "expected"),
    [
        (0, b"XX", b"XXAAAAAAAA"),
        (3, b"XX", b"AAAXXAAAAA"),
        (8, b"XX", b"AAAAAAAAXX"),
        (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
        (5, b"B", b"AAAAABAAAA"),
        (0, b"BCDE", b"BCDEAAAAAA"),
    ],
    ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
def test_set_range_sync(
    self, store: LocalStore, start: int, patch: bytes, expected: bytes
) -> None:
    """Patching a ten-byte value via set_range_sync yields the expected bytes."""
    key = "test/key"
    # Seed through the async API, then exercise the synchronous write/read path.
    sync(store.set(key, cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
    store.set_range_sync(key, cpu.Buffer.from_bytes(patch), start=start)
    stored = store.get_sync(key=key, prototype=cpu.buffer_prototype)
    assert stored is not None
    assert stored.to_bytes() == expected

async def test_set_range_not_open(self, store_not_open: LocalStore) -> None:
    """Calling set_range on a store that is not open opens it and applies the write."""
    key = "test/key"
    assert not store_not_open._is_open
    await self.set(store_not_open, key, cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
    await store_not_open.set_range(key, cpu.Buffer.from_bytes(b"XX"), start=0)
    # NOTE(review): getattr presumably sidesteps type-checker narrowing from the
    # earlier ``assert not ..._is_open`` — confirm before simplifying.
    assert getattr(store_not_open, "_is_open")  # noqa: B009
    patched = await self.get(store_not_open, key)
    assert patched.to_bytes() == b"XXAAAAAAAA"

def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None:
    """Calling set_range_sync on a store that is not open opens it and applies the write."""
    key = "test/key"
    assert not store_not_open._is_open
    sync(self.set(store_not_open, key, cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
    store_not_open.set_range_sync(key, cpu.Buffer.from_bytes(b"XX"), start=0)
    # NOTE(review): getattr presumably sidesteps type-checker narrowing from the
    # earlier ``assert not ..._is_open`` — confirm before simplifying.
    assert getattr(store_not_open, "_is_open")  # noqa: B009
    patched = sync(self.get(store_not_open, key))
    assert patched.to_bytes() == b"XXAAAAAAAA"


@pytest.mark.parametrize("exclusive", [True, False])
def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None:
Expand Down
66 changes: 66 additions & 0 deletions tests/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import zarr
from zarr.abc.store import SupportsSetRange
from zarr.core.buffer import Buffer, cpu, gpu
from zarr.core.sync import sync
from zarr.errors import ZarrUserWarning
Expand Down Expand Up @@ -127,6 +128,71 @@ def test_get_json_sync_with_prototype_none(
result = store._get_json_sync(key, prototype=buffer_cls)
assert result == data

def test_supports_set_range(self, store: MemoryStore) -> None:
    """A MemoryStore instance must satisfy the runtime-checkable SupportsSetRange protocol."""
    assert isinstance(store, SupportsSetRange)

@pytest.mark.parametrize(
    ("start", "patch", "expected"),
    [
        (0, b"XX", b"XXAAAAAAAA"),
        (3, b"XX", b"AAAXXAAAAA"),
        (8, b"XX", b"AAAAAAAAXX"),
        (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
        (5, b"B", b"AAAAABAAAA"),
        (0, b"BCDE", b"BCDEAAAAAA"),
    ],
    ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
async def test_set_range(
    self, store: MemoryStore, start: int, patch: bytes, expected: bytes
) -> None:
    """Patching a ten-byte value via set_range yields the expected bytes."""
    key = "test/key"
    # Seed a known value, patch it, then read the whole value back.
    await store.set(key, cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
    await store.set_range(key, cpu.Buffer.from_bytes(patch), start=start)
    stored = await store.get(key, prototype=cpu.buffer_prototype)
    assert stored is not None
    assert stored.to_bytes() == expected

@pytest.mark.parametrize(
    ("start", "patch", "expected"),
    [
        (0, b"XX", b"XXAAAAAAAA"),
        (3, b"XX", b"AAAXXAAAAA"),
        (8, b"XX", b"AAAAAAAAXX"),
        (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"),
        (5, b"B", b"AAAAABAAAA"),
        (0, b"BCDE", b"BCDEAAAAAA"),
    ],
    ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"],
)
def test_set_range_sync(
    self, store: MemoryStore, start: int, patch: bytes, expected: bytes
) -> None:
    """Patching a ten-byte value via set_range_sync yields the expected bytes."""
    key = "test/key"
    # Seed the backing dict directly so only the synchronous API is exercised.
    store._store_dict[key] = cpu.Buffer.from_bytes(b"AAAAAAAAAA")
    store.set_range_sync(key, cpu.Buffer.from_bytes(patch), start=start)
    stored = store.get_sync(key=key, prototype=cpu.buffer_prototype)
    assert stored is not None
    assert stored.to_bytes() == expected

async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None:
    """Calling set_range on a store that is not open opens it and applies the write."""
    key = "test/key"
    assert not store_not_open._is_open
    await self.set(store_not_open, key, cpu.Buffer.from_bytes(b"AAAAAAAAAA"))
    await store_not_open.set_range(key, cpu.Buffer.from_bytes(b"XX"), start=0)
    # NOTE(review): getattr presumably sidesteps type-checker narrowing from the
    # earlier ``assert not ..._is_open`` — confirm before simplifying.
    assert getattr(store_not_open, "_is_open")  # noqa: B009
    patched = await self.get(store_not_open, key)
    assert patched.to_bytes() == b"XXAAAAAAAA"

def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None:
    """Calling set_range_sync on a store that is not open opens it and applies the write."""
    key = "test/key"
    assert not store_not_open._is_open
    # Seed the backing dict directly so nothing but set_range_sync touches the store API.
    store_not_open._store_dict[key] = cpu.Buffer.from_bytes(b"AAAAAAAAAA")
    store_not_open.set_range_sync(key, cpu.Buffer.from_bytes(b"XX"), start=0)
    # NOTE(review): getattr presumably sidesteps type-checker narrowing from the
    # earlier ``assert not ..._is_open`` — confirm before simplifying.
    assert getattr(store_not_open, "_is_open")  # noqa: B009
    assert store_not_open._store_dict[key].to_bytes() == b"XXAAAAAAAA"


# TODO: fix this warning
@pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")
Expand Down
Loading