From 0bbcafb6740066b95c4dedb2ea555bcd4100701d Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:21:43 +0200 Subject: [PATCH 01/19] add bin classes to multiply --- src/waveresponse/_core.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index a9cf236..16431fe 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -58,19 +58,37 @@ def multiply(grid1, grid2, output_type="grid"): """ TYPE_MAP = { + "Grid": Grid, + "RAO": RAO, + "DirectionalSpectrum": DirectionalSpectrum, + "DirectionalBinSpectrum": DirectionalBinSpectrum, + "WaveSpectrum": WaveSpectrum, + "WaveBinSpectrum": WaveBinSpectrum, + } + + # Deprecated types for backward compatibility + TYPE_MAP_DEPRECATED = { "grid": Grid, "rao": RAO, "directional_spectrum": DirectionalSpectrum, "wave_spectrum": WaveSpectrum, } - if output_type not in TYPE_MAP: - raise ValueError("The given `output_type` is not valid.") + if not isinstance(output_type, Grid): + if output_type not in TYPE_MAP and output_type not in TYPE_MAP_DEPRECATED: + raise ValueError("The given `output_type` is not valid.") + if output_type in TYPE_MAP_DEPRECATED: + warnings.warn( + f"The '{output_type}' type is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + output_type = TYPE_MAP_DEPRECATED[output_type] + else: + output_type = TYPE_MAP_DEPRECATED[output_type] _check_is_similar(grid1, grid2, exact_type=False) - type_ = TYPE_MAP.get(output_type) - freq = grid1._freq dirs = grid1._dirs vals = np.multiply(grid1._vals, grid2._vals) @@ -85,7 +103,7 @@ def multiply(grid1, grid2, output_type="grid"): **convention, ) - return type_.from_grid(new) + return output_type.from_grid(new) def _cast_to_grid(grid): From 9e27defc7df74271d2d237ac839923d1e1544ffc Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:24:56 +0200 Subject: [PATCH 02/19] update default type --- src/waveresponse/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 16431fe..d835304 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, output_type="grid"): +def multiply(grid1, grid2, output_type="Grid"): """ Multiply values (element-wise). @@ -85,7 +85,7 @@ def multiply(grid1, grid2, output_type="grid"): ) output_type = TYPE_MAP_DEPRECATED[output_type] else: - output_type = TYPE_MAP_DEPRECATED[output_type] + output_type = TYPE_MAP[output_type] _check_is_similar(grid1, grid2, exact_type=False) From f2271909db488f024a653bfcd0f114181a7f9956 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:30:46 +0200 Subject: [PATCH 03/19] check issubclass --- src/waveresponse/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index d835304..edd56d2 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -74,7 +74,7 @@ def multiply(grid1, grid2, output_type="Grid"): "wave_spectrum": WaveSpectrum, } - if not isinstance(output_type, Grid): + if not issubclass(output_type, Grid): if output_type not in TYPE_MAP and output_type not in TYPE_MAP_DEPRECATED: raise ValueError("The given `output_type` is not valid.") if output_type in TYPE_MAP_DEPRECATED: From d292fca878d620479490df53cc6bee83444057cf Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:33:47 +0200 Subject: [PATCH 04/19] docstring --- src/waveresponse/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index edd56d2..78d4634 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -53,7 +53,7 @@ def multiply(grid1, grid2, output_type="Grid"): Grid object. grid2 : obj Grid object. - output_type : str {"grid", "rao", "directional_spectrum", "wave_spectrum"} + output_type : str {"Grid", "RAO", "DirectionalSpectrum", "DirectionalBinSpectrum", "WaveSpectrum", "WaveBinSpectrum"} Output grid type. """ From b0ea674fa6926d4104bc43aef710b3418e66d7bd Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:45:47 +0200 Subject: [PATCH 05/19] some changes --- src/waveresponse/_core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 78d4634..c40b8ba 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -74,18 +74,18 @@ def multiply(grid1, grid2, output_type="Grid"): "wave_spectrum": WaveSpectrum, } - if not issubclass(output_type, Grid): - if output_type not in TYPE_MAP and output_type not in TYPE_MAP_DEPRECATED: - raise ValueError("The given `output_type` is not valid.") + if isinstance(output_type, type) and issubclass(output_type, Grid): + gridtype = output_type + elif output_type in (map_ := TYPE_MAP | TYPE_MAP_DEPRECATED): + gridtype = map_[output_type] if output_type in TYPE_MAP_DEPRECATED: warnings.warn( f"The '{output_type}' type is deprecated and will be removed in a future release.", DeprecationWarning, stacklevel=2, ) - output_type = TYPE_MAP_DEPRECATED[output_type] - else: - output_type = TYPE_MAP[output_type] + else: + raise ValueError("The given `output_type` is not valid.") _check_is_similar(grid1, grid2, exact_type=False) @@ -103,7 +103,7 @@ def multiply(grid1, grid2, output_type="Grid"): **convention, ) - return output_type.from_grid(new) + return gridtype.from_grid(new) def _cast_to_grid(grid): From fba8601b20a9df1318b27542a3988d153248f193 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 12:53:44 +0200 Subject: [PATCH 06/19] docstring --- src/waveresponse/_core.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index c40b8ba..ea7cc8a 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -53,7 +53,7 @@ def multiply(grid1, grid2, output_type="Grid"): Grid object. grid2 : obj Grid object. - output_type : str {"Grid", "RAO", "DirectionalSpectrum", "DirectionalBinSpectrum", "WaveSpectrum", "WaveBinSpectrum"} + output_type : grid-type, default 'Grid' Output grid type. """ @@ -74,18 +74,32 @@ def multiply(grid1, grid2, output_type="Grid"): "wave_spectrum": WaveSpectrum, } + if output_type in TYPE_MAP_DEPRECATED: + warnings.warn( + f"The '{output_type}' type is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + + # if isinstance(output_type, type) and issubclass(output_type, Grid): + # gridtype = output_type + # elif output_type in (map_ := TYPE_MAP | TYPE_MAP_DEPRECATED): + # gridtype = map_[output_type] + # if output_type in TYPE_MAP_DEPRECATED: + # warnings.warn( + # f"The '{output_type}' type is deprecated and will be removed in a future release.", + # DeprecationWarning, + # stacklevel=2, + # ) + # else: + # raise ValueError("The given `output_type` is not valid.") + if isinstance(output_type, type) and issubclass(output_type, Grid): gridtype = output_type - elif output_type in (map_ := TYPE_MAP | TYPE_MAP_DEPRECATED): - gridtype = map_[output_type] - if output_type in TYPE_MAP_DEPRECATED: - warnings.warn( - f"The '{output_type}' type is deprecated and will be removed in a future release.", - DeprecationWarning, - stacklevel=2, - ) + elif output_type in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): + gridtype = type_map[output_type] else: - raise ValueError("The given `output_type` is not valid.") + raise ValueError(f"Invalid `output_type`: {output_type!r}") _check_is_similar(grid1, grid2, exact_type=False) From 01a7869e38a84089df33e27446a9d0beef5f2b4a Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:04:32 +0200 Subject: [PATCH 07/19] gtype --- src/waveresponse/_core.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index ea7cc8a..51ef533 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, output_type="Grid"): +def multiply(grid1, grid2, gtype="Grid", **kwargs): """ Multiply values (element-wise). @@ -53,7 +53,7 @@ def multiply(grid1, grid2, output_type="Grid"): Grid object. grid2 : obj Grid object. - output_type : grid-type, default 'Grid' + gtype : grid-type, default 'Grid' Output grid type. """ @@ -74,32 +74,27 @@ def multiply(grid1, grid2, output_type="Grid"): "wave_spectrum": WaveSpectrum, } - if output_type in TYPE_MAP_DEPRECATED: + if "output_type" in kwargs: + gtype = kwargs.pop("output_type") warnings.warn( - f"The '{output_type}' type is deprecated and will be removed in a future release.", + "The 'output_type' keyword argument is deprecated and will be removed in a future release. Use 'gtype' instead.", DeprecationWarning, stacklevel=2, ) - # if isinstance(output_type, type) and issubclass(output_type, Grid): - # gridtype = output_type - # elif output_type in (map_ := TYPE_MAP | TYPE_MAP_DEPRECATED): - # gridtype = map_[output_type] - # if output_type in TYPE_MAP_DEPRECATED: - # warnings.warn( - # f"The '{output_type}' type is deprecated and will be removed in a future release.", - # DeprecationWarning, - # stacklevel=2, - # ) - # else: - # raise ValueError("The given `output_type` is not valid.") - - if isinstance(output_type, type) and issubclass(output_type, Grid): - gridtype = output_type - elif output_type in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): - gridtype = type_map[output_type] + if gtype in TYPE_MAP_DEPRECATED: + warnings.warn( + f"The '{gtype}' type is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(gtype, type) and issubclass(gtype, Grid): + gridtype = gtype + elif gtype in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): + gridtype = type_map[gtype] else: - raise ValueError(f"Invalid `output_type`: {output_type!r}") + raise ValueError(f"Invalid `gtype`: {gtype!r}") _check_is_similar(grid1, grid2, exact_type=False) From 596cf0eda12c8c227185819926014303b889c056 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:05:12 +0200 Subject: [PATCH 08/19] comment --- src/waveresponse/_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 51ef533..35fc41a 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -74,6 +74,7 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): "wave_spectrum": WaveSpectrum, } + # Check for deprecated 'output_type' argument if "output_type" in kwargs: gtype = kwargs.pop("output_type") warnings.warn( From 12eca0ef1839ef11d8743a928b4ee24c69f742ec Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:07:34 +0200 Subject: [PATCH 09/19] small refactor --- src/waveresponse/_core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 35fc41a..eb7063a 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -90,12 +90,11 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): stacklevel=2, ) - if isinstance(gtype, type) and issubclass(gtype, Grid): - gridtype = gtype - elif gtype in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): - gridtype = type_map[gtype] - else: - raise ValueError(f"Invalid `gtype`: {gtype!r}") + if not (isinstance(gtype, type) and issubclass(gtype, Grid)): + if gtype in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): + gtype = type_map[gtype] + else: + raise ValueError(f"Invalid `gtype`: {gtype!r}") _check_is_similar(grid1, grid2, exact_type=False) @@ -113,7 +112,7 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): **convention, ) - return gridtype.from_grid(new) + return gtype.from_grid(new) def _cast_to_grid(grid): From 6992ea19a1b194bf8028dbfe40d0070e96f6236d Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:10:01 +0200 Subject: [PATCH 10/19] improve warning --- src/waveresponse/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index eb7063a..0737d63 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -85,7 +85,7 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): if gtype in TYPE_MAP_DEPRECATED: warnings.warn( - f"The '{gtype}' type is deprecated and will be removed in a future release.", + f"The '{gtype}' type is deprecated and will be removed in a future release. Use one of {set(TYPE_MAP.keys())} instead.", DeprecationWarning, stacklevel=2, ) From 391b08e5569556d1c98b4feb6ef042cde8fbe2a8 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:35:05 +0200 Subject: [PATCH 11/19] revert changes --- src/waveresponse/_core.py | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 0737d63..3fae93a 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, gtype="Grid", **kwargs): +def multiply(grid1, grid2, output_type="Grid"): """ Multiply values (element-wise). @@ -53,7 +53,7 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): Grid object. grid2 : obj Grid object. - gtype : grid-type, default 'Grid' + output_type : {'Grid', 'RAO', 'DirectionalSpectrum', 'WaveSpectrum', 'DirectionalBinSpectrum', 'WaveBinSpectrum'} Output grid type. """ @@ -64,37 +64,19 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): "DirectionalBinSpectrum": DirectionalBinSpectrum, "WaveSpectrum": WaveSpectrum, "WaveBinSpectrum": WaveBinSpectrum, - } - - # Deprecated types for backward compatibility - TYPE_MAP_DEPRECATED = { "grid": Grid, "rao": RAO, "directional_spectrum": DirectionalSpectrum, + "directional_bin_spectrum": DirectionalBinSpectrum, "wave_spectrum": WaveSpectrum, + "wave_bin_spectrum": WaveBinSpectrum, } - # Check for deprecated 'output_type' argument - if "output_type" in kwargs: - gtype = kwargs.pop("output_type") - warnings.warn( - "The 'output_type' keyword argument is deprecated and will be removed in a future release. Use 'gtype' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if gtype in TYPE_MAP_DEPRECATED: - warnings.warn( - f"The '{gtype}' type is deprecated and will be removed in a future release. Use one of {set(TYPE_MAP.keys())} instead.", - DeprecationWarning, - stacklevel=2, - ) - - if not (isinstance(gtype, type) and issubclass(gtype, Grid)): - if gtype in (type_map := TYPE_MAP | TYPE_MAP_DEPRECATED): - gtype = type_map[gtype] + if not (isinstance(output_type, type) and issubclass(output_type, Grid)): + if output_type in TYPE_MAP: + output_type = TYPE_MAP[output_type] else: - raise ValueError(f"Invalid `gtype`: {gtype!r}") + raise ValueError(f"Invalid `output_type`: {output_type!r}") _check_is_similar(grid1, grid2, exact_type=False) @@ -112,7 +94,7 @@ def multiply(grid1, grid2, gtype="Grid", **kwargs): **convention, ) - return gtype.from_grid(new) + return output_type.from_grid(new) def _cast_to_grid(grid): From d6ae7812ee3341796777570bedf247f883b52aaa Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:38:20 +0200 Subject: [PATCH 12/19] some changes --- src/waveresponse/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 3fae93a..b194d0f 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, output_type="Grid"): +def multiply(grid1, grid2, output_type="grid"): """ Multiply values (element-wise). From 236f2e284aee832f78f943f6f9e4d844a74427a1 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:40:20 +0200 Subject: [PATCH 13/19] small fix --- src/waveresponse/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index b194d0f..3fae93a 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, output_type="grid"): +def multiply(grid1, grid2, output_type="Grid"): """ Multiply values (element-wise). From e1f31171bf31f8ae00ff48a7aa4cee47920ab572 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 13:48:47 +0200 Subject: [PATCH 14/19] expand tests --- tests/test_core.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index fcc51e5..d597734 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -189,8 +189,9 @@ def test_rao_and_rao_to_default_grid(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_grid_and_grid_to_grid(self, grid): - out = wr.multiply(grid, grid.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_grid_and_grid_to_grid(self, output_type, grid): + out = wr.multiply(grid, grid.copy(), output_type=output_type) vals_expect = grid._vals * grid._vals @@ -203,8 +204,9 @@ def test_grid_and_grid_to_grid(self, grid): np.testing.assert_array_almost_equal(out._dirs, grid._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_wave_to_grid(self, rao, wave): - out = wr.multiply(rao, wave, output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_rao_and_wave_to_grid(self, output_type, rao, wave): + out = wr.multiply(rao, wave, output_type=output_type) vals_expect = rao._vals * wave._vals @@ -217,8 +219,9 @@ def test_rao_and_wave_to_grid(self, rao, wave): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_rao_to_rao(self, rao): - out = wr.multiply(rao, rao.copy(), output_type="rao") + @pytest.mark.parametrize("output_type", ("rao", "RAO", RAO)) + def test_rao_and_rao_to_rao(self, output_type, rao): + out = wr.multiply(rao, rao.copy(), output_type=output_type) vals_expect = rao._vals * rao._vals @@ -231,8 +234,9 @@ def test_rao_and_rao_to_rao(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_rao_to_grid(self, rao): - out = wr.multiply(rao, rao.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_rao_and_rao_to_grid(self, output_type, rao): + out = wr.multiply(rao, rao.copy(), output_type=output_type) vals_expect = rao._vals * rao._vals @@ -246,8 +250,9 @@ def test_rao_and_rao_to_grid(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_wave(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="wave_spectrum") + @pytest.mark.parametrize("output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum)) + def test_wave_and_wave_to_wave(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals @@ -260,8 +265,9 @@ def test_wave_and_wave_to_wave(self, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_dir_spectrum(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="directional_spectrum") + @pytest.mark.parametrize("output_type", ("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum)) + def test_wave_and_wave_to_dir_spectrum(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals @@ -275,8 +281,9 @@ def test_wave_and_wave_to_dir_spectrum(self, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_grid(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_wave_and_wave_to_grid(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals From 4dbb185a69b4c57e65b735a9cd94829a9625e3e6 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 14:02:20 +0200 Subject: [PATCH 15/19] test bin spectra --- tests/test_core.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index d597734..dbd90b9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -297,6 +297,78 @@ def test_wave_and_wave_to_grid(self, output_type, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin): + grid1 = wavebin.copy() + grid2 = wavebin.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = wavebin._vals * wavebin._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize("output_type", ("wave_bin_spectrum", "WaveBinSpectrum", WaveBinSpectrum)) + def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin): + grid1 = wavebin.copy() + grid2 = wavebin.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = wavebin._vals * wavebin._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_binspectrum_and_binspectrum_to_grid(self, output_type, directional_bin_spectrum): + grid1 = directional_bin_spectrum.copy() + grid2 = directional_bin_spectrum.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = grid1._vals * grid2._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize("output_type", ("directional_bin_spectrum", "DirectionalBinSpectrum", DirectionalBinSpectrum)) + def test_binspectrum_and_binspectrum_to_binspectrum(self, output_type, directional_bin_spectrum): + grid1 = directional_bin_spectrum.copy() + grid2 = directional_bin_spectrum.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = grid1._vals * grid2._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + def test_raises_output_type(self, grid): with pytest.raises(ValueError): wr.multiply(grid, grid.copy(), output_type="invalid-type") From 291fb4cd5a71e1500633ae5a823ec8599b80964a Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 14:02:43 +0200 Subject: [PATCH 16/19] black --- tests/test_core.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index dbd90b9..c368730 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -250,7 +250,9 @@ def test_rao_and_rao_to_grid(self, output_type, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - @pytest.mark.parametrize("output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum)) + @pytest.mark.parametrize( + "output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum) + ) def test_wave_and_wave_to_wave(self, output_type, wave): out = wr.multiply(wave, wave.copy(), output_type=output_type) @@ -265,7 +267,10 @@ def test_wave_and_wave_to_wave(self, output_type, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - @pytest.mark.parametrize("output_type", ("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum)) + @pytest.mark.parametrize( + "output_type", + ("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum), + ) def test_wave_and_wave_to_dir_spectrum(self, output_type, wave): out = wr.multiply(wave, wave.copy(), output_type=output_type) @@ -315,7 +320,9 @@ def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin): np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - @pytest.mark.parametrize("output_type", ("wave_bin_spectrum", "WaveBinSpectrum", WaveBinSpectrum)) + @pytest.mark.parametrize( + "output_type", ("wave_bin_spectrum", "WaveBinSpectrum", WaveBinSpectrum) + ) def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin): grid1 = wavebin.copy() grid2 = wavebin.copy() @@ -334,7 +341,9 @@ def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin): np.testing.assert_array_almost_equal(out._vals, vals_expect) @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) - def test_binspectrum_and_binspectrum_to_grid(self, output_type, directional_bin_spectrum): + def test_binspectrum_and_binspectrum_to_grid( + self, output_type, directional_bin_spectrum + ): grid1 = directional_bin_spectrum.copy() grid2 = directional_bin_spectrum.copy() out = wr.multiply(grid1, grid2, output_type=output_type) @@ -351,8 +360,13 @@ def test_binspectrum_and_binspectrum_to_grid(self, output_type, directional_bin_ np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - @pytest.mark.parametrize("output_type", ("directional_bin_spectrum", "DirectionalBinSpectrum", DirectionalBinSpectrum)) - def test_binspectrum_and_binspectrum_to_binspectrum(self, output_type, directional_bin_spectrum): + @pytest.mark.parametrize( + "output_type", + ("directional_bin_spectrum", "DirectionalBinSpectrum", DirectionalBinSpectrum), + ) + def test_binspectrum_and_binspectrum_to_binspectrum( + self, output_type, directional_bin_spectrum + ): grid1 = directional_bin_spectrum.copy() grid2 = directional_bin_spectrum.copy() out = wr.multiply(grid1, grid2, output_type=output_type) From a02d4dbb18c49773fbefe9b14a18048bb5d583cb Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 14:08:52 +0200 Subject: [PATCH 17/19] few changes --- src/waveresponse/_core.py | 10 ++++------ tests/test_core.py | 6 ++---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 3fae93a..69b3e3f 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -64,12 +64,10 @@ def multiply(grid1, grid2, output_type="Grid"): "DirectionalBinSpectrum": DirectionalBinSpectrum, "WaveSpectrum": WaveSpectrum, "WaveBinSpectrum": WaveBinSpectrum, - "grid": Grid, - "rao": RAO, - "directional_spectrum": DirectionalSpectrum, - "directional_bin_spectrum": DirectionalBinSpectrum, - "wave_spectrum": WaveSpectrum, - "wave_bin_spectrum": WaveBinSpectrum, + "grid": Grid, # for backward compatibility + "rao": RAO, # for backward compatibility + "directional_spectrum": DirectionalSpectrum, # for backward compatibility + "wave_spectrum": WaveSpectrum, # for backward compatibility } if not (isinstance(output_type, type) and issubclass(output_type, Grid)): diff --git a/tests/test_core.py b/tests/test_core.py index c368730..fd82005 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -320,9 +320,7 @@ def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin): np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - @pytest.mark.parametrize( - "output_type", ("wave_bin_spectrum", "WaveBinSpectrum", WaveBinSpectrum) - ) + @pytest.mark.parametrize("output_type", ("WaveBinSpectrum", WaveBinSpectrum)) def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin): grid1 = wavebin.copy() grid2 = wavebin.copy() @@ -362,7 +360,7 @@ def test_binspectrum_and_binspectrum_to_grid( @pytest.mark.parametrize( "output_type", - ("directional_bin_spectrum", "DirectionalBinSpectrum", DirectionalBinSpectrum), + ("DirectionalBinSpectrum", DirectionalBinSpectrum), ) def test_binspectrum_and_binspectrum_to_binspectrum( self, output_type, directional_bin_spectrum From add70cf9680a7796ddfe7fa6f1eca70abf0e6d90 Mon Sep 17 00:00:00 2001 From: "Vegard R. Solum" Date: Wed, 9 Apr 2025 14:12:38 +0200 Subject: [PATCH 18/19] small refactor --- src/waveresponse/_core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index 69b3e3f..fc47239 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -71,10 +71,9 @@ def multiply(grid1, grid2, output_type="Grid"): } if not (isinstance(output_type, type) and issubclass(output_type, Grid)): - if output_type in TYPE_MAP: - output_type = TYPE_MAP[output_type] - else: + if output_type not in TYPE_MAP: raise ValueError(f"Invalid `output_type`: {output_type!r}") + output_type = TYPE_MAP[output_type] _check_is_similar(grid1, grid2, exact_type=False) From b323fb5a65c8751fdf9589a43dac7c4b2d0be24a Mon Sep 17 00:00:00 2001 From: Ali Cetin Date: Wed, 9 Apr 2025 21:25:59 +0200 Subject: [PATCH 19/19] some simplification --- src/waveresponse/_core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index fc47239..24d3389 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -70,10 +70,10 @@ def multiply(grid1, grid2, output_type="Grid"): "wave_spectrum": WaveSpectrum, # for backward compatibility } - if not (isinstance(output_type, type) and issubclass(output_type, Grid)): - if output_type not in TYPE_MAP: - raise ValueError(f"Invalid `output_type`: {output_type!r}") - output_type = TYPE_MAP[output_type] + output_type_ = TYPE_MAP.get(output_type, output_type) + + if not (isinstance(output_type_, type) and issubclass(output_type_, Grid)): + raise ValueError(f"Invalid `output_type`: {output_type_!r}") _check_is_similar(grid1, grid2, exact_type=False) @@ -91,7 +91,7 @@ def multiply(grid1, grid2, output_type="Grid"): **convention, ) - return output_type.from_grid(new) + return output_type_.from_grid(new) def _cast_to_grid(grid):