diff --git a/activestorage/active.py b/activestorage/active.py index fb963f32..dc34cacb 100644 --- a/activestorage/active.py +++ b/activestorage/active.py @@ -116,7 +116,8 @@ def hfix(x): missing_value = ds.attrs.get('missing_value') # see https://github.com/NCAS-CMS/PyActiveStorage/pull/303 if isinstance(missing_value, np.ndarray): - missing_value = missing_value[0] + if missing_value.size == 1: + missing_value = missing_value[0] valid_min = hfix(ds.attrs.get('valid_min')) valid_max = hfix(ds.attrs.get('valid_max')) valid_range = hfix(ds.attrs.get('valid_range')) diff --git a/activestorage/storage.py b/activestorage/storage.py index eee0cabe..498b50da 100644 --- a/activestorage/storage.py +++ b/activestorage/storage.py @@ -129,10 +129,19 @@ def mask_missing(data, missing): fill_value, missing_value, valid_min, valid_max = missing if fill_value is not None: - data = np.ma.masked_equal(data, fill_value) + if isinstance(fill_value, np.ndarray) or isinstance(fill_value, list): + data = np.ma.masked_where(data == fill_value, data) + else: + data = np.ma.masked_equal(data, fill_value) if missing_value is not None: - data = np.ma.masked_equal(data, missing_value) + if isinstance(missing_value, np.ndarray) or isinstance(missing_value, list): + try: + data = np.ma.masked_where(data == missing_value, data) + except ValueError: # not broadcastable + raise ValueError("Data and missing_value arrays are not brodcastable!") + else: + data = np.ma.masked_equal(data, missing_value) if valid_max is not None: data = np.ma.masked_greater(data, valid_max) diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py index 56d3e627..96c4c0ee 100644 --- a/tests/unit/test_storage.py +++ b/tests/unit/test_storage.py @@ -6,6 +6,67 @@ import activestorage.storage as st +def test_mask_missing(): + """Test mask missing.""" + missing_1 = ([-900.], np.array([-900.]), None, None) + missing_2 = ([-900., 33.], np.array([-900., 33.]), None, None) + data_1 = np.ma.array( + [[[-900., 33.], [33., -900], [33., 44.]]], + mask=False, + fill_value=-900.0, + dtype=float + ) + data_2 = np.ma.array( + [[[-900., 33.], [33., -900], [33., 44.]]], + mask=False, + fill_value=[-900.0, 33.], + dtype=float + ) + res_1 = st.mask_missing(data_1, missing_1) + expected_1 = np.ma.array( + data_1, + mask=[[[True, False], [False, True], [False, False]]] + ) + np.testing.assert_array_equal(res_1, expected_1) + res_2 = st.mask_missing(data_2, missing_2) + expected_2 = np.ma.array( + data_2, + mask=[[[True, True], [False, False], [False, False]]] + ) + np.testing.assert_array_equal(res_2, expected_2) + + +def test_mask_missing_missing_broadcastable(): + """Test mask missing when fill_value cant be broadcast to data.""" + data = np.ma.array( + [[[-900., 33.], [33., -900], [33., 44.]]], + mask=False, + fill_value=np.array([-900.0]), + dtype=float + ) + missing = (-900, np.array([-900., 33.]), None, None) + res = st.mask_missing(data, missing) + expected = np.ma.array( + data, + mask=[[[True, True], [False, False], [False, False]]] + ) + np.testing.assert_array_equal(res, expected) + + +def test_mask_missing_missing_not_broadcastable(): + """Test mask missing when fill_value cant be broadcast to data.""" + data = np.ma.array( + [[[-900., 33.], [33., -900], [33., 44.]]], + mask=False, + fill_value=np.array([-900.0]), + dtype=float + ) + missing = (-900, np.array([-900., -900., 33.]), None, None) + msg = "Data and missing_value arrays are not brodcastable!" + with pytest.raises(ValueError, match=msg): + st.mask_missing(data, missing) + + def test_reduce_chunk(): """Test reduce chunk entirely.""" rfile = "tests/test_data/cesm2_native.nc"