diff --git a/src/waveresponse/_core.py b/src/waveresponse/_core.py index a9cf236..24d3389 100644 --- a/src/waveresponse/_core.py +++ b/src/waveresponse/_core.py @@ -43,7 +43,7 @@ def check_type(grid, grid_type): raise ValueError("Grid objects have different wave conventions.") -def multiply(grid1, grid2, output_type="grid"): +def multiply(grid1, grid2, output_type="Grid"): """ Multiply values (element-wise). @@ -53,23 +53,29 @@ def multiply(grid1, grid2, output_type="grid"): Grid object. grid2 : obj Grid object. - output_type : str {"grid", "rao", "directional_spectrum", "wave_spectrum"} + output_type : {'Grid', 'RAO', 'DirectionalSpectrum', 'WaveSpectrum', 'DirectionalBinSpectrum', 'WaveBinSpectrum'} Output grid type. """ TYPE_MAP = { - "grid": Grid, - "rao": RAO, - "directional_spectrum": DirectionalSpectrum, - "wave_spectrum": WaveSpectrum, + "Grid": Grid, + "RAO": RAO, + "DirectionalSpectrum": DirectionalSpectrum, + "DirectionalBinSpectrum": DirectionalBinSpectrum, + "WaveSpectrum": WaveSpectrum, + "WaveBinSpectrum": WaveBinSpectrum, + "grid": Grid, # for backward compatibility + "rao": RAO, # for backward compatibility + "directional_spectrum": DirectionalSpectrum, # for backward compatibility + "wave_spectrum": WaveSpectrum, # for backward compatibility } - if output_type not in TYPE_MAP: - raise ValueError("The given `output_type` is not valid.") + output_type_ = TYPE_MAP.get(output_type, output_type) - _check_is_similar(grid1, grid2, exact_type=False) + if not (isinstance(output_type_, type) and issubclass(output_type_, Grid)): + raise ValueError(f"Invalid `output_type`: {output_type_!r}") - type_ = TYPE_MAP.get(output_type) + _check_is_similar(grid1, grid2, exact_type=False) freq = grid1._freq dirs = grid1._dirs @@ -85,7 +91,7 @@ def multiply(grid1, grid2, output_type="grid"): **convention, ) - return type_.from_grid(new) + return output_type_.from_grid(new) def _cast_to_grid(grid): diff --git a/tests/test_core.py b/tests/test_core.py index fcc51e5..fd82005 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -189,8 +189,9 @@ def test_rao_and_rao_to_default_grid(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_grid_and_grid_to_grid(self, grid): - out = wr.multiply(grid, grid.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_grid_and_grid_to_grid(self, output_type, grid): + out = wr.multiply(grid, grid.copy(), output_type=output_type) vals_expect = grid._vals * grid._vals @@ -203,8 +204,9 @@ def test_grid_and_grid_to_grid(self, grid): np.testing.assert_array_almost_equal(out._dirs, grid._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_wave_to_grid(self, rao, wave): - out = wr.multiply(rao, wave, output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_rao_and_wave_to_grid(self, output_type, rao, wave): + out = wr.multiply(rao, wave, output_type=output_type) vals_expect = rao._vals * wave._vals @@ -217,8 +219,9 @@ def test_rao_and_wave_to_grid(self, rao, wave): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_rao_to_rao(self, rao): - out = wr.multiply(rao, rao.copy(), output_type="rao") + @pytest.mark.parametrize("output_type", ("rao", "RAO", RAO)) + def test_rao_and_rao_to_rao(self, output_type, rao): + out = wr.multiply(rao, rao.copy(), output_type=output_type) vals_expect = rao._vals * rao._vals @@ -231,8 +234,9 @@ def test_rao_and_rao_to_rao(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_rao_and_rao_to_grid(self, rao): - out = wr.multiply(rao, rao.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_rao_and_rao_to_grid(self, output_type, rao): + out = wr.multiply(rao, rao.copy(), output_type=output_type) vals_expect = rao._vals * rao._vals @@ -246,8 +250,11 @@ def test_rao_and_rao_to_grid(self, rao): np.testing.assert_array_almost_equal(out._dirs, rao._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_wave(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="wave_spectrum") + @pytest.mark.parametrize( + "output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum) + ) + def test_wave_and_wave_to_wave(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals @@ -260,8 +267,12 @@ def test_wave_and_wave_to_wave(self, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_dir_spectrum(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="directional_spectrum") + @pytest.mark.parametrize( + "output_type", + ("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum), + ) + def test_wave_and_wave_to_dir_spectrum(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals @@ -275,8 +286,9 @@ def test_wave_and_wave_to_dir_spectrum(self, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) - def test_wave_and_wave_to_grid(self, wave): - out = wr.multiply(wave, wave.copy(), output_type="grid") + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_wave_and_wave_to_grid(self, output_type, wave): + out = wr.multiply(wave, wave.copy(), output_type=output_type) vals_expect = wave._vals * wave._vals @@ -290,6 +302,85 @@ def test_wave_and_wave_to_grid(self, wave): np.testing.assert_array_almost_equal(out._dirs, wave._dirs) np.testing.assert_array_almost_equal(out._vals, vals_expect) + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin): + grid1 = wavebin.copy() + grid2 = wavebin.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = wavebin._vals * wavebin._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize("output_type", ("WaveBinSpectrum", WaveBinSpectrum)) + def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin): + grid1 = wavebin.copy() + grid2 = wavebin.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = wavebin._vals * wavebin._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize("output_type", ("grid", "Grid", Grid)) + def test_binspectrum_and_binspectrum_to_grid( + self, output_type, directional_bin_spectrum + ): + grid1 = directional_bin_spectrum.copy() + grid2 = directional_bin_spectrum.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = grid1._vals * grid2._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + + @pytest.mark.parametrize( + "output_type", + ("DirectionalBinSpectrum", DirectionalBinSpectrum), + ) + def test_binspectrum_and_binspectrum_to_binspectrum( + self, output_type, directional_bin_spectrum + ): + grid1 = directional_bin_spectrum.copy() + grid2 = directional_bin_spectrum.copy() + out = wr.multiply(grid1, grid2, output_type=output_type) + + vals_expect = grid1._vals * grid2._vals + + assert isinstance(out, Grid) + assert not isinstance(out, WaveSpectrum) + assert out._freq_hz is False + assert out._degrees is False + assert out._clockwise == grid1._clockwise + assert out._waves_coming_from == grid1._waves_coming_from + np.testing.assert_array_almost_equal(out._freq, grid1._freq) + np.testing.assert_array_almost_equal(out._dirs, grid1._dirs) + np.testing.assert_array_almost_equal(out._vals, vals_expect) + def test_raises_output_type(self, grid): with pytest.raises(ValueError): wr.multiply(grid, grid.copy(), output_type="invalid-type")