Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/waveresponse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def check_type(grid, grid_type):
raise ValueError("Grid objects have different wave conventions.")


def multiply(grid1, grid2, output_type="grid"):
def multiply(grid1, grid2, output_type="Grid"):
"""
Multiply values (element-wise).

Expand All @@ -53,23 +53,29 @@ def multiply(grid1, grid2, output_type="grid"):
Grid object.
grid2 : obj
Grid object.
output_type : str {"grid", "rao", "directional_spectrum", "wave_spectrum"}
output_type : {'Grid', 'RAO', 'DirectionalSpectrum', 'WaveSpectrum', 'DirectionalBinSpectrum', 'WaveBinSpectrum'}
Output grid type.
"""

TYPE_MAP = {
"grid": Grid,
"rao": RAO,
"directional_spectrum": DirectionalSpectrum,
"wave_spectrum": WaveSpectrum,
"Grid": Grid,
"RAO": RAO,
"DirectionalSpectrum": DirectionalSpectrum,
"DirectionalBinSpectrum": DirectionalBinSpectrum,
"WaveSpectrum": WaveSpectrum,
"WaveBinSpectrum": WaveBinSpectrum,
"grid": Grid, # for backward compatibility
"rao": RAO, # for backward compatibility
"directional_spectrum": DirectionalSpectrum, # for backward compatibility
"wave_spectrum": WaveSpectrum, # for backward compatibility
}

if output_type not in TYPE_MAP:
raise ValueError("The given `output_type` is not valid.")
output_type_ = TYPE_MAP.get(output_type, output_type)

_check_is_similar(grid1, grid2, exact_type=False)
if not (isinstance(output_type_, type) and issubclass(output_type_, Grid)):
raise ValueError(f"Invalid `output_type`: {output_type_!r}")

type_ = TYPE_MAP.get(output_type)
_check_is_similar(grid1, grid2, exact_type=False)

freq = grid1._freq
dirs = grid1._dirs
Expand All @@ -85,7 +91,7 @@ def multiply(grid1, grid2, output_type="grid"):
**convention,
)

return type_.from_grid(new)
return output_type_.from_grid(new)


def _cast_to_grid(grid):
Expand Down
119 changes: 105 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def test_rao_and_rao_to_default_grid(self, rao):
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_grid_and_grid_to_grid(self, grid):
out = wr.multiply(grid, grid.copy(), output_type="grid")
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_grid_and_grid_to_grid(self, output_type, grid):
out = wr.multiply(grid, grid.copy(), output_type=output_type)

vals_expect = grid._vals * grid._vals

Expand All @@ -203,8 +204,9 @@ def test_grid_and_grid_to_grid(self, grid):
np.testing.assert_array_almost_equal(out._dirs, grid._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_rao_and_wave_to_grid(self, rao, wave):
out = wr.multiply(rao, wave, output_type="grid")
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_rao_and_wave_to_grid(self, output_type, rao, wave):
out = wr.multiply(rao, wave, output_type=output_type)

vals_expect = rao._vals * wave._vals

Expand All @@ -217,8 +219,9 @@ def test_rao_and_wave_to_grid(self, rao, wave):
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_rao_and_rao_to_rao(self, rao):
out = wr.multiply(rao, rao.copy(), output_type="rao")
@pytest.mark.parametrize("output_type", ("rao", "RAO", RAO))
def test_rao_and_rao_to_rao(self, output_type, rao):
out = wr.multiply(rao, rao.copy(), output_type=output_type)

vals_expect = rao._vals * rao._vals

Expand All @@ -231,8 +234,9 @@ def test_rao_and_rao_to_rao(self, rao):
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_rao_and_rao_to_grid(self, rao):
out = wr.multiply(rao, rao.copy(), output_type="grid")
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_rao_and_rao_to_grid(self, output_type, rao):
out = wr.multiply(rao, rao.copy(), output_type=output_type)

vals_expect = rao._vals * rao._vals

Expand All @@ -246,8 +250,11 @@ def test_rao_and_rao_to_grid(self, rao):
np.testing.assert_array_almost_equal(out._dirs, rao._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_wave_and_wave_to_wave(self, wave):
out = wr.multiply(wave, wave.copy(), output_type="wave_spectrum")
@pytest.mark.parametrize(
"output_type", ("wave_spectrum", "WaveSpectrum", WaveSpectrum)
)
def test_wave_and_wave_to_wave(self, output_type, wave):
out = wr.multiply(wave, wave.copy(), output_type=output_type)

vals_expect = wave._vals * wave._vals

Expand All @@ -260,8 +267,12 @@ def test_wave_and_wave_to_wave(self, wave):
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_wave_and_wave_to_dir_spectrum(self, wave):
out = wr.multiply(wave, wave.copy(), output_type="directional_spectrum")
@pytest.mark.parametrize(
"output_type",
("directional_spectrum", "DirectionalSpectrum", DirectionalSpectrum),
)
def test_wave_and_wave_to_dir_spectrum(self, output_type, wave):
out = wr.multiply(wave, wave.copy(), output_type=output_type)

vals_expect = wave._vals * wave._vals

Expand All @@ -275,8 +286,9 @@ def test_wave_and_wave_to_dir_spectrum(self, wave):
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_wave_and_wave_to_grid(self, wave):
out = wr.multiply(wave, wave.copy(), output_type="grid")
@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_wave_and_wave_to_grid(self, output_type, wave):
out = wr.multiply(wave, wave.copy(), output_type=output_type)

vals_expect = wave._vals * wave._vals

Expand All @@ -290,6 +302,85 @@ def test_wave_and_wave_to_grid(self, wave):
np.testing.assert_array_almost_equal(out._dirs, wave._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_wavebin_and_wavebin_to_grid(self, output_type, wavebin):
grid1 = wavebin.copy()
grid2 = wavebin.copy()
out = wr.multiply(grid1, grid2, output_type=output_type)

vals_expect = wavebin._vals * wavebin._vals

assert isinstance(out, Grid)
assert not isinstance(out, WaveSpectrum)
assert out._freq_hz is False
assert out._degrees is False
assert out._clockwise == grid1._clockwise
assert out._waves_coming_from == grid1._waves_coming_from
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

@pytest.mark.parametrize("output_type", ("WaveBinSpectrum", WaveBinSpectrum))
def test_wavebin_and_wavebin_to_wavebin(self, output_type, wavebin):
grid1 = wavebin.copy()
grid2 = wavebin.copy()
out = wr.multiply(grid1, grid2, output_type=output_type)

vals_expect = wavebin._vals * wavebin._vals

assert isinstance(out, Grid)
assert not isinstance(out, WaveSpectrum)
assert out._freq_hz is False
assert out._degrees is False
assert out._clockwise == grid1._clockwise
assert out._waves_coming_from == grid1._waves_coming_from
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

@pytest.mark.parametrize("output_type", ("grid", "Grid", Grid))
def test_binspectrum_and_binspectrum_to_grid(
self, output_type, directional_bin_spectrum
):
grid1 = directional_bin_spectrum.copy()
grid2 = directional_bin_spectrum.copy()
out = wr.multiply(grid1, grid2, output_type=output_type)

vals_expect = grid1._vals * grid2._vals

assert isinstance(out, Grid)
assert not isinstance(out, WaveSpectrum)
assert out._freq_hz is False
assert out._degrees is False
assert out._clockwise == grid1._clockwise
assert out._waves_coming_from == grid1._waves_coming_from
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

@pytest.mark.parametrize(
"output_type",
("DirectionalBinSpectrum", DirectionalBinSpectrum),
)
def test_binspectrum_and_binspectrum_to_binspectrum(
self, output_type, directional_bin_spectrum
):
grid1 = directional_bin_spectrum.copy()
grid2 = directional_bin_spectrum.copy()
out = wr.multiply(grid1, grid2, output_type=output_type)

vals_expect = grid1._vals * grid2._vals

assert isinstance(out, Grid)
assert not isinstance(out, WaveSpectrum)
assert out._freq_hz is False
assert out._degrees is False
assert out._clockwise == grid1._clockwise
assert out._waves_coming_from == grid1._waves_coming_from
np.testing.assert_array_almost_equal(out._freq, grid1._freq)
np.testing.assert_array_almost_equal(out._dirs, grid1._dirs)
np.testing.assert_array_almost_equal(out._vals, vals_expect)

def test_raises_output_type(self, grid):
with pytest.raises(ValueError):
wr.multiply(grid, grid.copy(), output_type="invalid-type")
Expand Down