diff --git a/pydda/cost_functions/_cost_functions_numpy.py b/pydda/cost_functions/_cost_functions_numpy.py index fc578511..ab9c02e0 100644 --- a/pydda/cost_functions/_cost_functions_numpy.py +++ b/pydda/cost_functions/_cost_functions_numpy.py @@ -1,9 +1,6 @@ import numpy as np -import scipy import pyart -from scipy.ndimage import _nd_image - laplace_filter = np.asarray([1, -2, 1], dtype=np.float64) @@ -274,24 +271,43 @@ def calculate_smoothness_gradient( y: float array value of gradient of smoothness cost function """ - du = np.zeros(w.shape) - dv = np.zeros(w.shape) - dw = np.zeros(w.shape) - grad_u = np.zeros(w.shape) - grad_v = np.zeros(w.shape) - grad_w = np.zeros(w.shape) - scipy.ndimage.laplace(u, du, mode="wrap") - scipy.ndimage.laplace(v, dv, mode="wrap") - scipy.ndimage.laplace(w, dw, mode="wrap") - du = du / dx - dv = dv / dy - dw = dw / dz - scipy.ndimage.laplace(du, grad_u, mode="wrap") - scipy.ndimage.laplace(dv, grad_v, mode="wrap") - scipy.ndimage.laplace(dw, grad_w, mode="wrap") - grad_u = grad_u / dx - grad_v = grad_v / dy - grad_w = grad_w / dz + # Recompute the combined second-derivative terms from the cost function + dudx = np.gradient(u, dx, axis=2) + dudy = np.gradient(u, dy, axis=1) + dudz = np.gradient(u, dz, axis=0) + dvdx = np.gradient(v, dx, axis=2) + dvdy = np.gradient(v, dy, axis=1) + dvdz = np.gradient(v, dz, axis=0) + dwdx = np.gradient(w, dx, axis=2) + dwdy = np.gradient(w, dy, axis=1) + dwdz = np.gradient(w, dz, axis=0) + + fx = ( + np.gradient(dudx, dx, axis=2) + + np.gradient(dvdx, dx, axis=2) + + np.gradient(dwdx, dx, axis=2) + ) + fy = ( + np.gradient(dudy, dy, axis=1) + + np.gradient(dvdy, dy, axis=1) + + np.gradient(dwdy, dy, axis=1) + ) + fz = ( + np.gradient(dudz, dz, axis=0) + + np.gradient(dvdz, dz, axis=0) + + np.gradient(dwdz, dz, axis=0) + ) + + # Gradient: 2*Ci * d²(fi)/di². u, v, w share the same gradient because + # the cost treats them symmetrically through the combined fx/fy/fz terms. + grad_all = ( + 2 * Cx * np.gradient(np.gradient(fx, dx, axis=2), dx, axis=2) + + 2 * Cy * np.gradient(np.gradient(fy, dy, axis=1), dy, axis=1) + + 2 * Cz * np.gradient(np.gradient(fz, dz, axis=0), dz, axis=0) + ) + grad_u = grad_all.copy() + grad_v = grad_all.copy() + grad_w = grad_all.copy() # Impermeability condition grad_w[0, :, :] = 0 diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 6614c216..3b899e43 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -113,6 +113,49 @@ def test_calculate_rad_velocity_cost_jax(): assert np.all(grad == 0) +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_rad_velocity_cost_nonzero_jax(): + """Nonzero wind field produces nonzero cost and gradient matching numpy.""" + vrs, azs, els, wts, weights = _make_radvel_inputs() + rng = np.random.default_rng(42) + u = rng.random((20, 20, 20)) + v = rng.random((20, 20, 20)) + w = rng.random((20, 20, 20)) + rmsVr = 1.0 + + numpy_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights + ) + jax_cost = pydda.cost_functions.jax.calculate_radial_vel_cost_function( + [np.array(vrs[0])], + azs, + els, + jnp.array(u), + jnp.array(v), + jnp.array(w), + [jnp.array(np.array(wts[0]))], + rmsVr, + jnp.array(weights), + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=1e-5) + + numpy_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr + ) + jax_grad = pydda.cost_functions.jax.calculate_grad_radial_vel( + [np.array(vrs[0])], + els, + azs, + jnp.array(u), + jnp.array(v), + jnp.array(w), + [jnp.array(np.array(wts[0]))], + jnp.array(weights), + rmsVr, + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + @pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") def test_calculate_rad_velocity_cost_tf(): """Test with a zero velocity field radar""" @@ -145,6 +188,49 @@ def test_calculate_rad_velocity_cost_tf(): assert np.all(grad.numpy() == 0) +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_rad_velocity_cost_nonzero_tf(): + """Nonzero wind field produces nonzero cost and gradient matching numpy.""" + vrs, azs, els, wts, weights = _make_radvel_inputs() + rng = np.random.default_rng(42) + u = rng.random((20, 20, 20)) + v = rng.random((20, 20, 20)) + w = rng.random((20, 20, 20)) + rmsVr = 1.0 + + numpy_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights + ) + tf_cost = pydda.cost_functions.tf.calculate_radial_vel_cost_function( + [tf.constant(np.array(vrs[0]), dtype=tf.float32)], + [tf.constant(np.array(azs[0]), dtype=tf.float32)], + [tf.constant(np.array(els[0]), dtype=tf.float32)], + tf.constant(u.astype(np.float32), dtype=tf.float32), + tf.constant(v.astype(np.float32), dtype=tf.float32), + tf.constant(w.astype(np.float32), dtype=tf.float32), + [tf.constant(np.array(wts[0]), dtype=tf.float32)], + rmsVr, + [tf.constant(weights[0].astype(np.float32), dtype=tf.float32)], + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr + ) + tf_grad = pydda.cost_functions.tf.calculate_grad_radial_vel( + [tf.constant(np.array(vrs[0]), dtype=tf.float32)], + [tf.constant(np.array(els[0]), dtype=tf.float32)], + [tf.constant(np.array(azs[0]), dtype=tf.float32)], + tf.constant(u.astype(np.float32), dtype=tf.float32), + tf.constant(v.astype(np.float32), dtype=tf.float32), + tf.constant(w.astype(np.float32), dtype=tf.float32), + [tf.constant(np.array(wts[0]), dtype=tf.float32)], + [tf.constant(weights[0].astype(np.float32), dtype=tf.float32)], + rmsVr, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + def test_calculate_fall_speed(): """Check to see if fall speeds are realistic""" ref_field = 10 * np.ones((10, 100, 100)) @@ -753,3 +839,404 @@ def test_model_cost_tf(): u, v, w, weights, u - 1, v - 1, w ) assert cost2 > cost1 + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_smoothness_cost_nonzero_jax(): + """Nonzero field: JAX smoothness cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + + numpy_cost = pydda.cost_functions.calculate_smoothness_cost(u, v, w, dx, dy, dz) + jax_cost = pydda.cost_functions.jax.calculate_smoothness_cost(u, v, w, dx, dy, dz) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_smoothness_gradient(u, v, w, dx, dy, dz) + jax_grad = pydda.cost_functions.jax.calculate_smoothness_gradient( + u, v, w, dx, dy, dz + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_smoothness_cost_nonzero_tf(): + """Nonzero field: TensorFlow smoothness cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + + numpy_cost = pydda.cost_functions.calculate_smoothness_cost(u, v, w, dx, dy, dz) + tf_cost = pydda.cost_functions.tf.calculate_smoothness_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_smoothness_gradient(u, v, w, dx, dy, dz) + tf_grad = pydda.cost_functions.tf.calculate_smoothness_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_calculate_mass_continuity_nonzero_jax(): + """Nonzero divergent field: JAX mass continuity cost matches numpy; gradient is nonzero. + + Note: the numpy gradient uses an analytic adjoint formula while JAX uses + jax.vjp; the discrete transpose differs at grid boundaries, so cross-engine + gradient comparison is not used here. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + z = np.arange(0, 1000.0, 100) + + numpy_cost = pydda.cost_functions.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + jax_cost = pydda.cost_functions.jax.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + jax_grad = pydda.cost_functions.jax.calculate_mass_continuity_gradient( + u, v, w, z, dx, dy, dz, anel=0 + ) + assert np.any(np.array(jax_grad) != 0) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_calculate_mass_continuity_nonzero_tf(): + """Nonzero divergent field: TensorFlow mass continuity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + z = np.arange(0, 1000.0, 100) + + numpy_cost = pydda.cost_functions.calculate_mass_continuity( + u, v, w, z, dx, dy, dz, anel=0 + ) + tf_cost = pydda.cost_functions.tf.calculate_mass_continuity( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(z.astype(np.float32)), + dx, + dy, + dz, + anel=0, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_mass_continuity_gradient( + u, v, w, z, dx, dy, dz, anel=0 + ) + tf_grad = pydda.cost_functions.tf.calculate_mass_continuity_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(z.astype(np.float32)), + dx, + dy, + dz, + anel=0, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_background_cost_nonzero_jax(): + """Nonzero background mismatch: JAX background cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + u_back = 5.0 * np.ones(10) + v_back = 3.0 * np.ones(10) + + numpy_cost = pydda.cost_functions.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + jax_cost = pydda.cost_functions.jax.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + jax_grad = pydda.cost_functions.jax.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_background_cost_nonzero_tf(): + """Nonzero background mismatch: TF background cost and gradient match numpy. + + Note: TF background functions omit the w parameter. + """ + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + u_back = 5.0 * np.ones(10) + v_back = 3.0 * np.ones(10) + + numpy_cost = pydda.cost_functions.calculate_background_cost( + u, v, w, weights, u_back, v_back + ) + tf_cost = pydda.cost_functions.tf.calculate_background_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_back.astype(np.float32)), + tf.constant(v_back.astype(np.float32)), + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_background_gradient( + u, v, w, weights, u_back, v_back + ) + tf_grad = pydda.cost_functions.tf.calculate_background_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_back.astype(np.float32)), + tf.constant(v_back.astype(np.float32)), + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_vert_vorticity_nonzero_jax(): + """Nonzero rotating field: JAX vorticity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + Ut, Vt = 10.0, 10.0 + + numpy_cost = pydda.cost_functions.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt + ) + jax_cost = pydda.cost_functions.jax.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt + ) + jax_grad = pydda.cost_functions.jax.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_vert_vorticity_nonzero_tf(): + """Nonzero rotating field: TensorFlow vorticity cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = np.zeros((10, 10, 10)) + dx, dy, dz = 100.0, 100.0, 100.0 + Ut, Vt = 10.0, 10.0 + + # TF default coeff=1 differs from numpy/JAX default coeff=1e-5; pass explicitly + numpy_cost = pydda.cost_functions.calculate_vertical_vorticity_cost( + u, v, w, dx, dy, dz, Ut, Vt, coeff=1e-5 + ) + tf_cost = pydda.cost_functions.tf.calculate_vertical_vorticity_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + Ut, + Vt, + coeff=1e-5, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_vertical_vorticity_gradient( + u, v, w, dx, dy, dz, Ut, Vt, coeff=1e-5 + ) + tf_grad = pydda.cost_functions.tf.calculate_vertical_vorticity_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + dx, + dy, + dz, + Ut, + Vt, + coeff=1e-5, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_model_cost_nonzero_jax(): + """Nonzero model mismatch: JAX model cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + u_model = 5.0 * np.ones((10, 10, 10)) + v_model = 3.0 * np.ones((10, 10, 10)) + w_model = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + + numpy_cost = pydda.cost_functions.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + jax_cost = pydda.cost_functions.jax.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + jax_grad = pydda.cost_functions.jax.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_model_cost_nonzero_tf(): + """Nonzero model mismatch: TensorFlow model cost and gradient match numpy.""" + rng = np.random.default_rng(42) + u = rng.random((10, 10, 10)) + v = rng.random((10, 10, 10)) + w = rng.random((10, 10, 10)) + u_model = 5.0 * np.ones((10, 10, 10)) + v_model = 3.0 * np.ones((10, 10, 10)) + w_model = np.zeros((10, 10, 10)) + weights = np.ones((10, 10, 10)) + + numpy_cost = pydda.cost_functions.calculate_model_cost( + u, v, w, weights, u_model, v_model, w_model + ) + tf_cost = pydda.cost_functions.tf.calculate_model_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_model.astype(np.float32)), + tf.constant(v_model.astype(np.float32)), + tf.constant(w_model.astype(np.float32)), + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_model_gradient( + u, v, w, weights, u_model, v_model, w_model + ) + tf_grad = pydda.cost_functions.tf.calculate_model_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + tf.constant(w.astype(np.float32)), + tf.constant(weights.astype(np.float32)), + tf.constant(u_model.astype(np.float32)), + tf.constant(v_model.astype(np.float32)), + tf.constant(w_model.astype(np.float32)), + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") +def test_point_cost_nonzero_jax(): + """Nonzero point observation mismatch: JAX point cost and gradient match numpy.""" + u = np.ones((10, 10, 10)) + v = np.ones((10, 10, 10)) + x_1d = np.linspace(-10, 10, 10) + y_1d = np.linspace(-10, 10, 10) + z_1d = np.linspace(-10, 10, 10) + x, y, z = np.meshgrid(x_1d, y_1d, z_1d) + my_point = {"x": 0, "y": 0, "z": 0, "u": 3.0, "v": 3.0, "w": 0.0} + + numpy_cost = pydda.cost_functions.calculate_point_cost( + u, v, x, y, z, [my_point], roi=2.0 + ) + jax_cost = pydda.cost_functions.jax.calculate_point_cost( + u, v, x, y, z, [my_point], roi=2.0 + ) + np.testing.assert_allclose(float(jax_cost), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_point_gradient( + u, v, x, y, z, [my_point], roi=2.0 + ) + jax_grad = pydda.cost_functions.jax.calculate_point_gradient( + u, v, x, y, z, [my_point], roi=2.0 + ) + np.testing.assert_allclose(np.array(jax_grad), numpy_grad, rtol=0.03, atol=1e-4) + + +@pytest.mark.skipif(not TENSORFLOW_AVAILABLE, reason="TensorFlow not installed") +def test_point_cost_nonzero_tf(): + """Nonzero point observation mismatch: TensorFlow point cost and gradient match numpy.""" + u = np.ones((10, 10, 10)) + v = np.ones((10, 10, 10)) + x_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + y_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + z_1d = tf.constant(np.linspace(-10, 10, 10), dtype=tf.float32) + x_tf, y_tf, z_tf = tf.meshgrid(x_1d, y_1d, z_1d) + x_np, y_np, z_np = np.meshgrid( + np.linspace(-10, 10, 10), np.linspace(-10, 10, 10), np.linspace(-10, 10, 10) + ) + my_point = {"x": 0, "y": 0, "z": 0, "u": 3.0, "v": 3.0, "w": 0.0} + + numpy_cost = pydda.cost_functions.calculate_point_cost( + u, v, x_np, y_np, z_np, [my_point], roi=2.0 + ) + tf_cost = pydda.cost_functions.tf.calculate_point_cost( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + x_tf, + y_tf, + z_tf, + [my_point], + roi=2.0, + ) + np.testing.assert_allclose(float(tf_cost.numpy()), numpy_cost, rtol=0.03, atol=1e-4) + + numpy_grad = pydda.cost_functions.calculate_point_gradient( + u, v, x_np, y_np, z_np, [my_point], roi=2.0 + ) + tf_grad = pydda.cost_functions.tf.calculate_point_gradient( + tf.constant(u.astype(np.float32)), + tf.constant(v.astype(np.float32)), + x_tf, + y_tf, + z_tf, + [my_point], + roi=2.0, + ) + np.testing.assert_allclose(np.array(tf_grad), numpy_grad, rtol=0.03, atol=1e-4)