From 74200169a378d173afe9038b5a08f1e1ec550925 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Wed, 8 Apr 2026 14:31:09 -0500 Subject: [PATCH 1/3] TST: Add nonzero-field radial velocity tests for JAX and TensorFlow engines - Add _make_radvel_inputs() helper for building single-radar test inputs - Add test_calculate_rad_velocity_cost_nonzero_jax: verifies JAX cost function and gradient match numpy reference with float32 tolerance - Add test_calculate_rad_velocity_cost_nonzero_tf: same for TensorFlow Co-Authored-By: Claude Sonnet 4.6 --- pydda/tests/test_cost_functions.py | 103 +++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 8c6c6ed4..ec1bc1b1 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -22,6 +22,23 @@ JAX_AVAILABLE = False +def _make_radvel_inputs(): + """Build single-radar inputs for radial velocity cost function tests.""" + Grid = pyart.testing.make_empty_grid( + (20, 20, 20), ((0, 10000), (-10000, 10000), (-10000, 10000)) + ) + rng = np.random.default_rng(0) + fdata = rng.random((20, 20, 20)) + Grid.fields["vel_field"] = {"data": np.ma.array(fdata), "units": "m/s"} + Grid = pydda.io.read_from_pyart_grid(Grid) + vrs = [np.ma.array(Grid["vel_field"].values).squeeze()] + azs = [np.array(Grid["AZ"].values).squeeze()] + els = [np.array(Grid["EL"].values).squeeze()] + wts = [np.ma.zeros((20, 20, 20))] + weights = [np.ones((20, 20, 20))] + return vrs, azs, els, wts, weights + + def test_calculate_rad_velocity_cost(): Grid = pyart.testing.make_empty_grid( (20, 20, 20), ((0, 10000), (-10000, 10000), (-10000, 10000)) @@ -84,6 +101,49 @@ def test_calculate_rad_velocity_cost_jax(): assert np.all(grad == 0) +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_rad_velocity_cost_nonzero_jax(): + """Nonzero wind field produces nonzero cost and gradient matching numpy.""" + vrs, azs, els, wts, weights = _make_radvel_inputs() + rng = np.random.default_rng(42) + u = rng.random((20, 20, 20)) + v = rng.random((20, 20, 20)) + w = rng.random((20, 20, 20)) + rmsVr = 1.0 + + numpy_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights + ) + jax_cost = pydda.cost_functions.jax.calculate_radial_vel_cost_function( + [np.array(vrs[0])], + azs, + els, + jnp.array(u), + jnp.array(v), + jnp.array(w), + [jnp.array(np.array(wts[0]))], + rmsVr, + jnp.array(weights), + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=1e-5) + + numpy_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr + ) + jax_grad = pydda.cost_functions.jax.calculate_grad_radial_vel( + [np.array(vrs[0])], + els, + azs, + jnp.array(u), + jnp.array(v), + jnp.array(w), + [jnp.array(np.array(wts[0]))], + jnp.array(weights), + rmsVr, + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + @pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") def test_calculate_rad_velocity_cost_tf(): """Test with a zero velocity field radar""" @@ -116,6 +176,49 @@ def test_calculate_rad_velocity_cost_tf(): assert np.all(grad.numpy() == 0) +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_rad_velocity_cost_nonzero_tf(): + """Nonzero wind field produces nonzero cost and gradient matching numpy.""" + vrs, azs, els, wts, weights = _make_radvel_inputs() + rng = np.random.default_rng(42) + u = rng.random((20, 20, 20)) + v = rng.random((20, 20, 20)) + w = rng.random((20, 20, 20)) + rmsVr = 1.0 + + numpy_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights + ) + tf_cost = pydda.cost_functions.tf.calculate_radial_vel_cost_function( + [tf.constant(np.array(vrs[0]), dtype=tf.float32)], + [tf.constant(np.array(azs[0]), dtype=tf.float32)], + [tf.constant(np.array(els[0]), dtype=tf.float32)], + tf.constant(u.astype(np.float32), dtype=tf.float32), + tf.constant(v.astype(np.float32), dtype=tf.float32), + tf.constant(w.astype(np.float32), dtype=tf.float32), + [tf.constant(np.array(wts[0]), dtype=tf.float32)], + rmsVr, + [tf.constant(weights[0].astype(np.float32), dtype=tf.float32)], + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr + ) + tf_grad = pydda.cost_functions.tf.calculate_grad_radial_vel( + [tf.constant(np.array(vrs[0]), dtype=tf.float32)], + [tf.constant(np.array(els[0]), dtype=tf.float32)], + [tf.constant(np.array(azs[0]), dtype=tf.float32)], + tf.constant(u.astype(np.float32), dtype=tf.float32), + tf.constant(v.astype(np.float32), dtype=tf.float32), + tf.constant(w.astype(np.float32), dtype=tf.float32), + [tf.constant(np.array(wts[0]), dtype=tf.float32)], + [tf.constant(weights[0].astype(np.float32), dtype=tf.float32)], + rmsVr, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + def test_calculate_fall_speed(): """Check to see if fall speeds are realistic""" ref_field = 10 * np.ones((10, 100, 100)) From 402c6782859827452713c1c79a30128ca7e89644 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Wed, 8 Apr 2026 14:45:51 -0500 Subject: [PATCH 2/3] TST: Add nonzero cross-engine comparison tests for all cost functions For each cost function (smoothness, mass continuity, background, vertical vorticity, model, point), adds nonzero-field tests for both JAX and TensorFlow engines that verify: - Cost values match the numpy reference within float32 tolerance - Gradients are nonzero (or match numpy where implementations agree) Notable findings documented in test comments: - Smoothness gradient: numpy omits Cx/Cy/Cz coefficients and uses scipy.ndimage.laplace; cross-engine comparison is not meaningful - Mass continuity JAX gradient: jax.vjp differs from numpy's analytic adjoint at grid boundaries; cost comparison is used instead - Vertical vorticity TF: default coeff=1 vs numpy/JAX default coeff=1e-5; explicit coeff=1e-5 is now passed in tests Co-Authored-By: Claude Sonnet 4.6 --- pydda/tests/test_cost_functions.py | 409 +++++++++++++++++++++++++++++ 1 file changed, 409 insertions(+) diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index ec1bc1b1..36260375 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -827,3 +827,412 @@ def test_model_cost_tf(): u, v, w, weights, u - 1, v - 1, w ) assert cost2 > cost1 + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_smoothness_cost_nonzero_jax(): + """Nonzero field: JAX smoothness cost matches numpy; gradient is nonzero. + + Note: the numpy smoothness gradient omits Cx/Cy/Cz and uses a different + Laplacian operator (scipy.ndimage) so cross-engine gradient comparison is + not meaningful. Each engine's gradient is tested for being nonzero instead. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + + numpy_cost = pydda.cost_functions.calculate_smoothness_cost(u, v, w, dx, dy, dz) + jax_cost = pydda.cost_functions.jax.calculate_smoothness_cost(u, v, w, dx, dy, dz) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + jax_grad = pydda.cost_functions.jax.calculate_smoothness_gradient( + u, v, w, dx, dy, dz + ) + assert np.any(np.array(jax_grad) != 0) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_smoothness_cost_nonzero_tf(): + """Nonzero field: TensorFlow smoothness cost matches numpy; gradient is nonzero. + + Note: the numpy smoothness gradient omits Cx/Cy/Cz and uses a different + Laplacian operator (scipy.ndimage) so cross-engine gradient comparison is + not meaningful. Each engine's gradient is tested for being nonzero instead. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + + numpy_cost = pydda.cost_functions.calculate_smoothness_cost(u, v, w, dx, dy, dz) + tf_cost = pydda.cost_functions.tf.calculate_smoothness_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + tf_grad = pydda.cost_functions.tf.calculate_smoothness_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + ) + assert np.any(np.array(tf_grad) != 0) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_mass_continuity_nonzero_jax(): + """Nonzero divergent field: JAX mass continuity cost matches numpy; gradient is nonzero. + + Note: the numpy gradient uses an analytic adjoint formula while JAX uses + jax.vjp; the discrete transpose differs at grid boundaries, so cross-engine + gradient comparison is not used here. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + z = np.arange(0, 1000.0, 100) + + numpy_cost = pydda.cost_functions.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + jax_cost = pydda.cost_functions.jax.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + jax_grad = pydda.cost_functions.jax.calculate_mass_continuity_gradient( + u, v, w, z, dx, dy, dz, anel=0 + ) + assert np.any(np.array(jax_grad) != 0) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_mass_continuity_nonzero_tf(): + """Nonzero divergent field: TensorFlow mass continuity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + z = np.arange(0, 1000.0, 100) + + numpy_cost = pydda.cost_functions.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + tf_cost = pydda.cost_functions.tf.calculate_mass_continuity( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(z.astype(np.float32)), + dx, + dy, + dz, + anel=0, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_mass_continuity_gradient( + u, v, w, z, dx, dy, dz, anel=0 + ) + tf_grad = pydda.cost_functions.tf.calculate_mass_continuity_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(z.astype(np.float32)), + dx, + dy, + dz, + anel=0, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_background_cost_nonzero_jax(): + """Nonzero background mismatch: JAX background cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + u_back = 5.0 * np.ones(10) + v_back = 3.0 * np.ones(10) + + numpy_cost = pydda.cost_functions.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + jax_cost = pydda.cost_functions.jax.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + jax_grad = pydda.cost_functions.jax.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_background_cost_nonzero_tf(): + """Nonzero background mismatch: TF background cost and gradient match numpy. + + Note: TF background functions omit the w parameter. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + u_back = 5.0 * np.ones(10) + v_back = 3.0 * np.ones(10) + + numpy_cost = pydda.cost_functions.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + tf_cost = pydda.cost_functions.tf.calculate_background_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_back.astype(np.float32)), + tf.constant(v_back.astype(np.float32)), + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + tf_grad = pydda.cost_functions.tf.calculate_background_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_back.astype(np.float32)), + tf.constant(v_back.astype(np.float32)), + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_vert_vorticity_nonzero_jax(): + """Nonzero rotating field: JAX vorticity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + Ut, Vt = 10.0, 10.0 + + numpy_cost = pydda.cost_functions.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt + ) + jax_cost = pydda.cost_functions.jax.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt + ) + jax_grad = pydda.cost_functions.jax.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_vert_vorticity_nonzero_tf(): + """Nonzero rotating field: TensorFlow vorticity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + Ut, Vt = 10.0, 10.0 + + # TF default coeff=1 differs from numpy/JAX default coeff=1e-5; pass explicitly + numpy_cost = pydda.cost_functions.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt, coeff=1e-5 + ) + tf_cost = pydda.cost_functions.tf.calculate_vertical_vorticity_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + Ut, + Vt, + coeff=1e-5, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt, coeff=1e-5 + ) + tf_grad = pydda.cost_functions.tf.calculate_vertical_vorticity_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + Ut, + Vt, + coeff=1e-5, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_model_cost_nonzero_jax(): + """Nonzero model mismatch: JAX model cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + u_model = 5.0 * np.ones((10, 10, 10)) + v_model = 3.0 * np.ones((10, 10, 10)) + w_model = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + + numpy_cost = pydda.cost_functions.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + jax_cost = pydda.cost_functions.jax.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + jax_grad = pydda.cost_functions.jax.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_model_cost_nonzero_tf(): + """Nonzero model mismatch: TensorFlow model cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + u_model = 5.0 * np.ones((10, 10, 10)) + v_model = 3.0 * np.ones((10, 10, 10)) + w_model = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + + numpy_cost = pydda.cost_functions.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + tf_cost = pydda.cost_functions.tf.calculate_model_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_model.astype(np.float32)), + tf.constant(v_model.astype(np.float32)), + tf.constant(w_model.astype(np.float32)), + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + tf_grad = pydda.cost_functions.tf.calculate_model_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_model.astype(np.float32)), + tf.constant(v_model.astype(np.float32)), + tf.constant(w_model.astype(np.float32)), + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_point_cost_nonzero_jax(): + """Nonzero point observation mismatch: JAX point cost and gradient match numpy.""" + u = np.ones((10, 10, 10)) + v = np.ones((10, 10, 10)) + x_1d = np.linspace(-10, 10, 10) + y_1d = np.linspace(-10, 10, 10) + z_1d = np.linspace(-10, 10, 10) + x, y, z = np.meshgrid(x_1d, y_1d, z_1d) + my_point = {"x": 0, "y": 0, "z": 0, "u": 3.0, "v": 3.0, "w": 0.0} + + numpy_cost = pydda.cost_functions.calculate_point_cost( + u, v, x, y, z, [my_point], roi=2.0 + ) + jax_cost = pydda.cost_functions.jax.calculate_point_cost( + u, v, x, y, z, [my_point], roi=2.0 + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_point_gradient( + u, v, x, y, z, [my_point], roi=2.0 + ) + jax_grad = pydda.cost_functions.jax.calculate_point_gradient( + u, v, x, y, z, [my_point], roi=2.0 + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_point_cost_nonzero_tf(): + """Nonzero point observation mismatch: TensorFlow point cost and gradient match numpy.""" + u = np.ones((10, 10, 10)) + v = np.ones((10, 10, 10)) + x_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + y_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + z_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + x_tf, y_tf, z_tf = tf.meshgrid(x_1d, y_1d, z_1d) + x_np, y_np, z_np = np.meshgrid( + np.linspace(-10, 10, 10), np.linspace(-10, 10, 10), np.linspace(-10, 10, 10) + ) + my_point = {"x": 0, "y": 0, "z": 0, "u": 3.0, "v": 3.0, "w": 0.0} + + numpy_cost = pydda.cost_functions.calculate_point_cost( + u, v, x_np, y_np, z_np, [my_point], roi=2.0 + ) + tf_cost = pydda.cost_functions.tf.calculate_point_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + x_tf, + y_tf, + z_tf, + [my_point], + roi=2.0, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_point_gradient( + u, v, x_np, y_np, z_np, [my_point], roi=2.0 + ) + tf_grad = pydda.cost_functions.tf.calculate_point_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + x_tf, + y_tf, + z_tf, + [my_point], + roi=2.0, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) From fa4bcd9d47c4127cdc27fa24da5f111f5b6c6469 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Wed, 8 Apr 2026 14:53:17 -0500 Subject: [PATCH 3/3] FIX: Correct numpy smoothness gradient to include Cx/Cy/Cz coefficients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation used scipy.ndimage.laplace (biharmonic operator) without the Cx, Cy, Cz weighting coefficients, producing a gradient inconsistent with the actual cost function. Replace with the analytic gradient of the cost: grad = 2*Cx * d²fx/dx² + 2*Cy * d²fy/dy² + 2*Cz * d²fz/dz² where fx/fy/fz are the combined second-derivative terms from the cost. u, v, w share the same gradient since the cost treats them symmetrically. Also remove unused scipy and _nd_image imports, and strengthen the smoothness gradient tests to do full cross-engine numerical comparison now that all three engines are consistent. Co-Authored-By: Claude Sonnet 4.6 --- pydda/cost_functions/_cost_functions_numpy.py | 58 ++++++++++++------- pydda/tests/test_cost_functions.py | 20 ++----- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/pydda/cost_functions/_cost_functions_numpy.py b/pydda/cost_functions/_cost_functions_numpy.py index 754f9316..799f737a 100644 --- a/pydda/cost_functions/_cost_functions_numpy.py +++ b/pydda/cost_functions/_cost_functions_numpy.py @@ -1,9 +1,6 @@ import numpy as np -import scipy import pyart -from scipy.ndimage import _nd_image - laplace_filter = np.asarray([1, -2, 1], dtype=np.float64) @@ -236,24 +233,43 @@ def calculate_smoothness_gradient( y: float array value of gradient of smoothness cost function """ - du = np.zeros(w.shape) - dv = np.zeros(w.shape) - dw = np.zeros(w.shape) - grad_u = np.zeros(w.shape) - grad_v = np.zeros(w.shape) - grad_w = np.zeros(w.shape) - scipy.ndimage.laplace(u, du, mode="wrap") - scipy.ndimage.laplace(v, dv, mode="wrap") - scipy.ndimage.laplace(w, dw, mode="wrap") - du = du / dx - dv = dv / dy - dw = dw / dz - scipy.ndimage.laplace(du, grad_u, mode="wrap") - scipy.ndimage.laplace(dv, grad_v, mode="wrap") - scipy.ndimage.laplace(dw, grad_w, mode="wrap") - grad_u = grad_u / dx - grad_v = grad_v / dy - grad_w = grad_w / dz + # Recompute the combined second-derivative terms from the cost function + dudx = np.gradient(u, dx, axis=2) + dudy = np.gradient(u, dy, axis=1) + dudz = np.gradient(u, dz, axis=0) + dvdx = np.gradient(v, dx, axis=2) + dvdy = np.gradient(v, dy, axis=1) + dvdz = np.gradient(v, dz, axis=0) + dwdx = np.gradient(w, dx, axis=2) + dwdy = np.gradient(w, dy, axis=1) + dwdz = np.gradient(w, dz, axis=0) + + fx = ( + np.gradient(dudx, dx, axis=2) + + np.gradient(dvdx, dx, axis=2) + + np.gradient(dwdx, dx, axis=2) + ) + fy = ( + np.gradient(dudy, dy, axis=1) + + np.gradient(dvdy, dy, axis=1) + + np.gradient(dwdy, dy, axis=1) + ) + fz = ( + np.gradient(dudz, dz, axis=0) + + np.gradient(dvdz, dz, axis=0) + + np.gradient(dwdz, dz, axis=0) + ) + + # Gradient: 2*Ci * d²(fi)/di². u, v, w share the same gradient because + # the cost treats them symmetrically through the combined fx/fy/fz terms. + grad_all = ( + 2 * Cx * np.gradient(np.gradient(fx, dx, axis=2), dx, axis=2) + + 2 * Cy * np.gradient(np.gradient(fy, dy, axis=1), dy, axis=1) + + 2 * Cz * np.gradient(np.gradient(fz, dz, axis=0), dz, axis=0) + ) + grad_u = grad_all.copy() + grad_v = grad_all.copy() + grad_w = grad_all.copy() # Impermeability condition grad_w[0, :, :] = 0 diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 36260375..f1b57b4f 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -831,12 +831,7 @@ def test_model_cost_tf(): @pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") def test_calculate_smoothness_cost_nonzero_jax(): - """Nonzero field: JAX smoothness cost matches numpy; gradient is nonzero. - - Note: the numpy smoothness gradient omits Cx/Cy/Cz and uses a different - Laplacian operator (scipy.ndimage) so cross-engine gradient comparison is - not meaningful. Each engine's gradient is tested for being nonzero instead. - """ + """Nonzero field: JAX smoothness cost and gradient match numpy.""" rng = np.random.default_rng(42) u = rng.random((10, 10, 10)) v = rng.random((10, 10, 10)) @@ -847,20 +842,16 @@ def test_calculate_smoothness_cost_nonzero_jax(): jax_cost = pydda.cost_functions.jax.calculate_smoothness_cost(u, v, w, dx, dy, dz) np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + numpy_grad = pydda.cost_functions.calculate_smoothness_gradient(u, v, w, dx, dy, dz) jax_grad = pydda.cost_functions.jax.calculate_smoothness_gradient( u, v, w, dx, dy, dz ) - assert np.any(np.array(jax_grad) != 0) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) @pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") def test_calculate_smoothness_cost_nonzero_tf(): - """Nonzero field: TensorFlow smoothness cost matches numpy; gradient is nonzero. - - Note: the numpy smoothness gradient omits Cx/Cy/Cz and uses a different - Laplacian operator (scipy.ndimage) so cross-engine gradient comparison is - not meaningful. Each engine's gradient is tested for being nonzero instead. - """ + """Nonzero field: TensorFlow smoothness cost and gradient match numpy.""" rng = np.random.default_rng(42) u = rng.random((10, 10, 10)) v = rng.random((10, 10, 10)) @@ -878,6 +869,7 @@ def test_calculate_smoothness_cost_nonzero_tf(): ) np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + numpy_grad = pydda.cost_functions.calculate_smoothness_gradient(u, v, w, dx, dy, dz) tf_grad = pydda.cost_functions.tf.calculate_smoothness_gradient( tf.constant(u.astype(np.float32)), tf.constant(v.astype(np.float32)), @@ -886,7 +878,7 @@ def test_calculate_smoothness_cost_nonzero_tf(): dy, dz, ) - assert np.any(np.array(tf_grad) != 0) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) @pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed")