From b7e3de4e9c3513393f27d08a6f2ca5616b93c585 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Thu, 1 Jan 2026 14:15:53 +0100 Subject: [PATCH] TST: remove usages of deprecated `numpy.testing.assert_warns` Deprecated in numpy 2.4.0. Also try `pytest.raises` instead of `assert_raises` in one file; both should work, the pytest builtin functionality will be preferred for new code. --- pywt/tests/test_cwt_wavelets.py | 4 +- pywt/tests/test_deprecations.py | 24 ++++++---- pywt/tests/test_multilevel.py | 10 ++-- pywt/tests/test_swt.py | 81 ++++++++++++++++++++------------- 4 files changed, 73 insertions(+), 46 deletions(-) diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index f142404ce..9925a6971 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -10,7 +10,6 @@ assert_almost_equal, assert_equal, assert_raises, - assert_warns, ) import pywt @@ -341,7 +340,8 @@ def test_cwt_parameters_in_names(): for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]: for name in ['fbsp', 'cmor', 'shan']: # additional parameters should be specified within the name - assert_warns(FutureWarning, func, name) + with pytest.warns(FutureWarning): + func(name) for name in ['cmor', 'shan']: # valid names diff --git a/pywt/tests/test_deprecations.py b/pywt/tests/test_deprecations.py index aedaaa867..c8f85fdeb 100644 --- a/pywt/tests/test_deprecations.py +++ b/pywt/tests/test_deprecations.py @@ -1,34 +1,40 @@ import warnings import numpy as np -from numpy.testing import assert_array_equal, assert_warns +import pytest +from numpy.testing import assert_array_equal import pywt def test_intwave_deprecation(): wavelet = pywt.Wavelet('db3') - assert_warns(DeprecationWarning, pywt.intwave, wavelet) + with pytest.warns(DeprecationWarning): + pywt.intwave(wavelet) def test_centrfrq_deprecation(): wavelet = pywt.Wavelet('db3') - assert_warns(DeprecationWarning, pywt.centrfrq, wavelet) + with pytest.warns(DeprecationWarning): + pywt.centrfrq(wavelet) def test_scal2frq_deprecation(): wavelet = pywt.Wavelet('db3') - assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1) + with pytest.warns(DeprecationWarning): + pywt.scal2frq(wavelet, 1) def test_orthfilt_deprecation(): - assert_warns(DeprecationWarning, pywt.orthfilt, range(6)) + with pytest.warns(DeprecationWarning): + pywt.orthfilt(range(6)) def test_integrate_wave_tuple(): sig = [0, 1, 2, 3] xgrid = [0, 1, 2, 3] - assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid)) + with pytest.warns(DeprecationWarning): + pywt.integrate_wavelet((sig, xgrid)) old_modes = ['zpd', @@ -42,7 +48,8 @@ def test_integrate_wave_tuple(): def test_MODES_from_object_deprecation(): for mode in old_modes: - assert_warns(DeprecationWarning, pywt.Modes.from_object, mode) + with pytest.warns(DeprecationWarning): + pywt.Modes.from_object(mode) def test_MODES_attributes_deprecation(): @@ -50,7 +57,8 @@ def get_mode(Modes, name): return getattr(Modes, name) for mode in old_modes: - assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode) + with pytest.warns(DeprecationWarning): + get_mode(pywt.Modes, mode) def test_mode_equivalence(): diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py index cbfea564d..ea85e88ac 100644 --- a/pywt/tests/test_multilevel.py +++ b/pywt/tests/test_multilevel.py @@ -14,7 +14,6 @@ assert_equal, assert_raises, assert_raises_regex, - assert_warns, ) import pywt @@ -899,8 +898,9 @@ def test_fswavedecn_fswaverecn_variable_levels(): assert_raises(ValueError, pywt.fswavedecn, data, 'haar', levels=(1, 1, 1, 1)) # levels too large for array size - assert_warns(UserWarning, pywt.fswavedecn, data, 'haar', - levels=int(np.log2(np.min(data.shape)))+1) + with pytest.warns(UserWarning): + pywt.fswavedecn(data, 'haar', + levels=int(np.log2(np.min(data.shape)))+1) def test_fswavedecn_fswaverecn_variable_wavelets_and_modes(): @@ -967,8 +967,8 @@ def test_fswavedecnresult(): k, np.zeros(tuple([s + 1 for s in d.shape]))) # warns on assigning with a non-matching dtype - assert_warns(UserWarning, result.__setitem__, - k, np.zeros_like(d).astype(np.float32)) + with pytest.warns(UserWarning): + result.__setitem__(k, np.zeros_like(d).astype(np.float32)) # all coefficients are stacked into result.coeffs (same ndim) assert_equal(result.coeffs.ndim, data.ndim) diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py index 3eaa9426d..98e722ae2 100644 --- a/pywt/tests/test_swt.py +++ b/pywt/tests/test_swt.py @@ -12,8 +12,6 @@ assert_allclose, assert_array_equal, assert_equal, - assert_raises, - assert_warns, ) import pywt @@ -69,7 +67,9 @@ def test_swt_decomposition(): def test_swt_max_level(): # odd sized signal will warn about no levels of decomposition possible - assert_warns(UserWarning, pywt.swt_max_level, 11) + with pytest.warns(UserWarning): + pywt.swt_max_level(11) + with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) assert_equal(pywt.swt_max_level(11), 0) @@ -134,7 +134,8 @@ def test_swt_axis(): assert_array_equal(row, cD2) # axis too large - assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5) + with pytest.raises(ValueError): + pywt.swt(x, db1, level=2, axis=5) def test_swt_iswt_integration(): @@ -217,9 +218,8 @@ def test_swt_default_level_by_axis(): def test_swt2_ndim_error(): x = np.ones(8) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', FutureWarning) - assert_raises(ValueError, pywt.swt2, x, 'haar', level=1) + with pytest.raises(ValueError): + pywt.swt2(x, 'haar', level=1) @pytest.mark.slow @@ -298,10 +298,12 @@ def test_swt2_axes(): assert_allclose(X, r2, atol=atol) # duplicate axes not allowed - assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, - axes=(0, 0)) + with pytest.raises(ValueError): + pywt.swt2(X, current_wavelet, 1, axes=(0, 0)) + # too few axes - assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, )) + with pytest.raises(ValueError): + pywt.swt2(X, current_wavelet, 1, axes=(0, )) def test_swtn_axes(): @@ -325,21 +327,24 @@ def test_swtn_axes(): assert_equal(empty, []) # duplicate axes not allowed - assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0)) + with pytest.raises(ValueError): + pywt.swtn(X, current_wavelet, 1, axes=(0, 0)) # data.ndim = 0 - assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1) + with pytest.raises(ValueError): + pywt.swtn(np.asarray([]), current_wavelet, 1) # start_level too large - assert_raises(ValueError, pywt.swtn, X, current_wavelet, - level=1, start_level=2) + with pytest.raises(ValueError): + pywt.swtn(X, current_wavelet, level=1, start_level=2) # level < 1 in swt_axis call - assert_raises(ValueError, swt_axis, X, current_wavelet, level=0, - start_level=0) + with pytest.raises(ValueError): + swt_axis(X, current_wavelet, level=0, start_level=0) + # odd-sized data not allowed - assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0, - start_level=0, axis=0) + with pytest.raises(ValueError): + swt_axis( X[-1, :], current_wavelet, level=0, start_level=0, axis=0) @pytest.mark.slow @@ -401,12 +406,17 @@ def test_iswtn_errors(): coeffs = pywt.swtn(x, w, max_level, axes=axes) # more axes than dimensions transformed - assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2)) + with pytest.raises(ValueError): + pywt.iswtn(coeffs, w, axes=(0, 1, 2)) + # duplicate axes not allowed - assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0)) + with pytest.raises(ValueError): + pywt.iswtn(coeffs, w, axes=(0, 0)) + # mismatched coefficient size coeffs[0]['da'] = coeffs[0]['da'][:-1, :] - assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes) + with pytest.raises(RuntimeError): + pywt.iswtn(coeffs, w, axes=axes) def test_swtn_iswtn_unique_shape_per_axis(): @@ -441,8 +451,11 @@ def test_per_axis_wavelets(): assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14) # length of wavelets doesn't match the length of axes - assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level) - assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2]) + with pytest.raises(ValueError): + pywt.swtn(data, wavelets[:2], level) + + with pytest.raises(ValueError): + pywt.iswtn(coefs, wavelets[:2]) with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) @@ -458,11 +471,12 @@ def test_error_on_continuous_wavelet(): for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn], [pywt.iswt, pywt.iswt2, pywt.iswtn]): for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]: - assert_raises(ValueError, dec_func, data, wavelet=cwave, - level=3) + with pytest.raises(ValueError): + dec_func(data, wavelet=cwave, level=3) c = dec_func(data, 'db1', level=3) - assert_raises(ValueError, rec_func, c, wavelet=cwave) + with pytest.raises(ValueError): + rec_func(c, wavelet=cwave) def test_iswt_mixed_dtypes(): @@ -552,11 +566,13 @@ def test_iswtn_mixed_dtypes(): def test_swt_zero_size_axes(): # raise on empty input array - assert_raises(ValueError, pywt.swt, [], 'db2') + with pytest.raises(ValueError): + pywt.swt([], 'db2') # >1D case uses a different code path so check there as well x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis - assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,)) + with pytest.raises(ValueError): + pywt.swtn(x, 'db2', level=1, axes=(0,)) def test_swt_variance_and_energy_preservation(): @@ -575,7 +591,8 @@ def test_swt_variance_and_energy_preservation(): np.linalg.norm(np.concatenate(coeffs))) # non-orthogonal wavelet with norm=True raises a warning - assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True) + with pytest.warns(UserWarning): + pywt.swt(x, 'bior2.2', norm=True) def test_swt2_variance_and_energy_preservation(): @@ -598,7 +615,8 @@ def test_swt2_variance_and_energy_preservation(): np.linalg.norm(np.concatenate(coeff_list))) # non-orthogonal wavelet with norm=True raises a warning - assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True) + with pytest.warns(UserWarning): + pywt.swt2(x, 'bior2.2', level=4, norm=True) def test_swtn_variance_and_energy_preservation(): @@ -621,7 +639,8 @@ def test_swtn_variance_and_energy_preservation(): np.linalg.norm(np.concatenate(coeff_list))) # non-orthogonal wavelet with norm=True raises a warning - assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True) + with pytest.warns(UserWarning): + pywt.swtn(x, 'bior2.2', level=4, norm=True) def test_swt_ravel_and_unravel():