diff --git a/changes/3628.misc.md b/changes/3628.misc.md new file mode 100644 index 0000000000..0aa706e5cd --- /dev/null +++ b/changes/3628.misc.md @@ -0,0 +1 @@ +Avoid reading lazy arrays or on device arrays twice when comparing them to 0 during the writing process. diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..a9f1cb8938 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, TypeVar from warnings import warn +import numpy as np + from zarr.abc.codec import ( ArrayArrayCodec, ArrayBytesCodec, @@ -19,6 +21,7 @@ from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning from zarr.registry import register_pipeline +from zarr.core.buffer import NDBuffer if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -26,7 +29,7 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer + from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -413,12 +416,28 @@ async def _read_key( if chunk_array is None: chunk_array_batch.append(None) # type: ignore[unreachable] else: - if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( - fill_value_or_default(chunk_spec) - ): - chunk_array_batch.append(None) - else: + if chunk_spec.config.write_empty_chunks: chunk_array_batch.append(chunk_array) + else: + # The operation array_equal operation below effectively will force the array + # into memory. + # if the result is useful, we want to avoid reading it twice + # from a potentially lazy operation. So we cache it here. + # If the result is not useful, we leave it for the garbage collector. + # We optimize this operation for the case that the GPU + if not hasattr(chunk_array._data, '__cuda_array_interface__'): + # I'm not sure why this implementation doesn't work + # it seems like something is getting missed by me + # chunk_array = NDBuffer(np.asarray(chunk_array._data)) + # This line here just feels more dirty + chunk_array._data = np.asarray(chunk_array._data) + + if chunk_array.all_equal( + fill_value_or_default(chunk_spec) + ): + chunk_array_batch.append(None) + else: + chunk_array_batch.append(chunk_array) chunk_bytes_batch = await self.encode_batch( [