diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index ec1cef57..c9495abe 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -267,7 +267,7 @@ def calculate_smoothness_gradient( return y.flatten() -def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): +def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3): """ Calculates the cost function related to point observations. A mean square error cost function term is applied to points that are within the sphere of influence @@ -305,14 +305,15 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): """ J = 0.0 for the_point in point_list: - the_box = jnp.logical_and( - jnp.logical_and( - jnp.abs(x - the_point["x"]) < roi, jnp.abs(y - the_point["y"]) < roi - ), - jnp.abs(z - the_point["z"]) < roi, + dist = jnp.sqrt( + (x - the_point["x"]) ** 2 + + (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) - the_box = jnp.where(the_box, 1.0, 0.0) - J += jnp.sum(((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2) * the_box) + dist = jnp.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / jnp.max(weight) + J += jnp.sum(weight * ((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2)) return J * Cp @@ -358,18 +359,16 @@ def calculate_point_gradient(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): gradJ_w = jnp.zeros_like(u) for the_point in point_list: - the_box = jnp.where( - jnp.logical_and( - jnp.logical_and( - np.abs(x - the_point["x"]) < roi, np.abs(y - the_point["y"]) < roi - ), - np.abs(z - the_point["z"]) < roi, - ), - 1.0, - 0.0, + dist = jnp.sqrt( + (x - the_point["x"]) ** 2 + + (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) - gradJ_u += 2 * (u - the_point["u"]) * the_box - gradJ_v += 2 * (v - the_point["v"]) * the_box + dist = jnp.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / jnp.max(weight) + gradJ_u += 2 * (u - 
the_point["u"]) * weight + gradJ_v += 2 * (v - the_point["v"]) * weight gradJ = jnp.stack([gradJ_u, gradJ_v, gradJ_w], axis=0).flatten() return gradJ * Cp diff --git a/pydda/cost_functions/_cost_functions_numpy.py b/pydda/cost_functions/_cost_functions_numpy.py index fc578511..95150117 100644 --- a/pydda/cost_functions/_cost_functions_numpy.py +++ b/pydda/cost_functions/_cost_functions_numpy.py @@ -303,7 +303,7 @@ def calculate_smoothness_gradient( return y.flatten() -def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): +def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3): """ Calculates the cost function related to point observations. A mean square error cost function term is applied to points that are within the sphere of influence @@ -339,18 +339,17 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): for the_point in point_list: # Instead of worrying about whole domain, just find points in radius of influence # Since we know that the weight will be zero outside the sphere of influence anyways - the_box = np.where( - np.logical_and.reduce( - ( - np.abs(x - the_point["x"]) < roi, - np.abs(y - the_point["y"]) < roi, - np.abs(z - the_point["z"]) < roi, - ) - ) - ) - J += np.sum( - ((u[the_box] - the_point["u"]) ** 2 + (v[the_box] - the_point["v"]) ** 2) + + dist = np.sqrt( + (x - the_point["x"]) ** 2 + + (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) + dist = np.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / np.max(weight) + + J += np.sum(weight * ((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2)) return J * Cp @@ -392,17 +391,16 @@ def calculate_point_gradient(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): gradJ_w = np.zeros_like(u) for the_point in point_list: - the_box = np.where( - np.logical_and.reduce( - ( - np.abs(x - the_point["x"]) < roi, - np.abs(y - the_point["y"]) < roi, - np.abs(z - the_point["z"]) < roi, - ) - ) + dist = np.sqrt( + (x - the_point["x"]) ** 2 + 
+ (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) - gradJ_u[the_box] += 2 * (u[the_box] - the_point["u"]) - gradJ_v[the_box] += 2 * (v[the_box] - the_point["v"]) + dist = np.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / np.max(weight) + gradJ_u += 2 * weight * (u - the_point["u"]) + gradJ_v += 2 * weight * (v - the_point["v"]) gradJ = np.stack([gradJ_u, gradJ_v, gradJ_w], axis=0).flatten() return gradJ * Cp diff --git a/pydda/cost_functions/_cost_functions_tensorflow.py b/pydda/cost_functions/_cost_functions_tensorflow.py index d687fd36..76668aed 100644 --- a/pydda/cost_functions/_cost_functions_tensorflow.py +++ b/pydda/cost_functions/_cost_functions_tensorflow.py @@ -333,23 +333,18 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): for the_point in point_list: # Instead of worrying about whole domain, just find points in radius of influence # Since we know that the weight will be zero outside the sphere of influence anyways - xp = tf.ones_like(x) * the_point["x"] - yp = tf.ones_like(y) * the_point["y"] - zp = tf.ones_like(z) * the_point["z"] up = tf.ones_like(u) * the_point["u"] vp = tf.ones_like(v) * the_point["v"] - - the_box = tf.where( - tf.math.logical_and( - tf.math.logical_and( - tf.math.abs(x - xp) < roi, tf.math.abs(y - yp) < roi - ), - tf.math.abs(z - zp) < roi, - ), - 1.0, - 0.0, + dist = tf.math.sqrt( + (x - the_point["x"]) ** 2 + + (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) - J.assign_add(tf.math.reduce_sum(((u - up) ** 2 + (v - vp) ** 2) * the_box)) + dist = tf.math.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / tf.reduce_max(weight) + + J.assign_add(tf.math.reduce_sum(((u - up) ** 2 + (v - vp) ** 2) * weight)) return J * Cp @@ -394,24 +389,19 @@ def calculate_point_gradient(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): for the_point in point_list: # Instead of worrying about whole domain, just find points in radius of influence # Since we know that the weight will be 
zero outside the sphere of influence anyways - xp = tf.ones_like(x, dtype=tf.float32) * the_point["x"] - yp = tf.ones_like(y, dtype=tf.float32) * the_point["y"] - zp = tf.ones_like(z, dtype=tf.float32) * the_point["z"] up = tf.ones_like(u, dtype=tf.float32) * the_point["u"] vp = tf.ones_like(v, dtype=tf.float32) * the_point["v"] - the_box = tf.where( - tf.math.logical_and( - tf.math.logical_and( - tf.math.abs(x - xp) < roi, tf.math.abs(y - yp) < roi - ), - tf.math.abs(z - zp) < roi, - ), - 1.0, - 0.0, + dist = tf.math.sqrt( + (x - the_point["x"]) ** 2 + + (y - the_point["y"]) ** 2 + + (z - the_point["z"]) ** 2 ) - gradJ_u.assign_add((2 * (u - up) * the_box)) - gradJ_v.assign_add((2 * (v - vp) * the_box)) + dist = tf.math.maximum(dist, 1.0) + weight = 1 / dist**2 + weight = weight / tf.reduce_max(weight) + gradJ_u.assign_add((2 * (u - up) * weight)) + gradJ_v.assign_add((2 * (v - vp) * weight)) gradJ = tf.stack([gradJ_u, gradJ_v, gradJ_w], axis=0) gradJ = tf.reshape(gradJ, (3 * np.prod(u.shape),)) return gradJ * Cp diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 6614c216..c234c720 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -538,10 +538,20 @@ def test_point_cost(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) assert cost > 0 @@ -551,19 +561,39 @@ def test_point_cost(): my_point2 = {"x": 3, "y": 3, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) 
assert cost > 0 assert np.all(grad >= 0) cost = pydda.cost_functions.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.calculate_point_gradient( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost > 0 assert ~np.all(grad >= 0) @@ -571,10 +601,20 @@ def test_point_cost(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 1.0, "v": 1.0, "w": 0.0} my_point2 = {"x": 3, "y": 3, "z": 0, "u": 1.0, "v": 1.0, "w": 0.0} cost = pydda.cost_functions.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.calculate_point_gradient( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost == 0 assert np.all(grad == 0) @@ -591,10 +631,20 @@ def test_point_cost_jax(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.jax.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.jax.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) assert cost > 0 @@ -604,19 +654,39 @@ def test_point_cost_jax(): my_point2 = {"x": 3, "y": 3, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.jax.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.jax.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) assert cost > 0 assert np.all(grad >= 0) cost = pydda.cost_functions.jax.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.jax.calculate_point_gradient( - u, v, x, y, z, [my_point1, 
my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost > 0 assert ~np.all(grad >= 0) @@ -624,10 +694,20 @@ def test_point_cost_jax(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 1.0, "v": 1.0, "w": 0.0} my_point2 = {"x": 3, "y": 3, "z": 0, "u": 1.0, "v": 1.0, "w": 0.0} cost = pydda.cost_functions.jax.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.jax.calculate_point_gradient( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost == 0 assert np.all(grad == 0) @@ -646,10 +726,20 @@ def test_point_cost_tf(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.tf.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.tf.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) assert cost.numpy() > 0 @@ -659,19 +749,39 @@ def test_point_cost_tf(): my_point2 = {"x": 3, "y": 3, "z": 0, "u": 2.0, "v": 2.0, "w": 0.0} cost = pydda.cost_functions.tf.calculate_point_cost( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) grad = pydda.cost_functions.tf.calculate_point_gradient( - u, v, x, y, z, [my_point1], roi=2.0 + u, + v, + x, + y, + z, + [my_point1], ) assert cost.numpy() > 0 assert tf.math.reduce_all(grad >= 0) cost = pydda.cost_functions.tf.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.tf.calculate_point_gradient( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost.numpy() > 0 assert ~tf.math.reduce_all(grad >= 0) @@ -679,10 +789,20 @@ def test_point_cost_tf(): my_point1 = {"x": 0, "y": 0, "z": 0, "u": 1.0, "v": 1.0, "w": 
0.0} my_point2 = {"x": 3, "y": 3, "z": 0, "u": 1.0, "v": 1.0, "w": 0.0} cost = pydda.cost_functions.tf.calculate_point_cost( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) grad = pydda.cost_functions.tf.calculate_point_gradient( - u, v, x, y, z, [my_point1, my_point2], roi=2.0 + u, + v, + x, + y, + z, + [my_point1, my_point2], ) assert cost.numpy() == 0 assert tf.math.reduce_all(grad == 0)