diff --git a/xarray_beam/_src/core.py b/xarray_beam/_src/core.py
index 3aa80d9..44ca4ee 100644
--- a/xarray_beam/_src/core.py
+++ b/xarray_beam/_src/core.py
@@ -20,6 +20,7 @@
 import itertools
 import math
 import pickle
 import time
 from typing import Any, Generic, TypeVar
+import warnings
 
@@ -272,7 +273,21 @@ def is_deterministic(self) -> bool:
     return False
 
   def estimate_size(self, value: xarray.Dataset) -> int:
-    return value.nbytes
+    """Returns the size of `value` in bytes, capped at 2**31 - 1.
+
+    Emits a UserWarning whenever the cap is applied.
+    """
+    # Largest value reported to Beam; larger sizes are clamped to this.
+    max_size = 2**31 - 1
+    nbytes = value.nbytes
+    if nbytes <= max_size:
+      return nbytes
+    warnings.warn(
+        f"Dataset size ({nbytes / 2**30:.1f} GiB) exceeds Beam's 2 GiB "
+        'counter limit; capping estimate_size to 2^31-1.',
+        stacklevel=2,
+    )
+    return max_size
 
   def to_type_hint(self) -> type[xarray.Dataset]:
     return xarray.Dataset
diff --git a/xarray_beam/_src/core_test.py b/xarray_beam/_src/core_test.py
index f44f6ac..2e4ddc1 100644
--- a/xarray_beam/_src/core_test.py
+++ b/xarray_beam/_src/core_test.py
@@ -309,6 +309,24 @@ def test_no_fallback_deterministic_coder_warnings(self):
       with beam.Pipeline(runner='DirectRunner') as p:
         p | beam.Create(inputs) | beam.GroupByKey()
 
+  def test_dataset_coder_estimate_size_small(self):
+    dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
+    self.assertEqual(core.DatasetCoder().estimate_size(dataset), dataset.nbytes)
+
+  def test_dataset_coder_estimate_size_overflow(self):
+    # A dask-backed array reports nbytes above the cap without allocating
+    # memory: 2**29 float64 values * 8 bytes = 2**32 bytes (4 GiB).
+    # NOTE(review): assumes `da` (dask.array) is already imported by this
+    # test module -- confirm.
+    max_size = 2**31 - 1
+    dataset = xarray.Dataset(
+        {'foo': (('x',), da.zeros(2**29, dtype='float64'))}
+    )
+    self.assertGreater(dataset.nbytes, max_size)
+    with self.assertWarnsRegex(UserWarning, "exceeds Beam's 2 GiB"):
+      size = core.DatasetCoder().estimate_size(dataset)
+    self.assertEqual(size, max_size)
+
 
 class DatasetToChunksTest(test_util.TestCase):
 