diff --git a/b3d/camera.py b/b3d/camera.py index 8e4f8cd4..c311fb04 100644 --- a/b3d/camera.py +++ b/b3d/camera.py @@ -90,7 +90,7 @@ def camera_from_screen_and_depth( def camera_from_screen(uv: ScreenCoordinates, intrinsics) -> CameraCoordinates: - z = jnp.ones_like(uv.shape[-1:]) + z = jnp.ones(uv.shape[:-1]) return camera_from_screen_and_depth(uv, z, intrinsics) @@ -136,14 +136,12 @@ def screen_from_camera(xyz: CameraCoordinates, intrinsics) -> ScreenCoordinates: Returns: (...,2) array of screen coordinates. """ - # TODO: check this - xyz = jnp.clip(xyz, - jnp.array([-jnp.inf, -jnp.inf, intrinsics.near]), - jnp.array([jnp.inf, jnp.inf, intrinsics.far])) - _, _, fx, fy, cx, cy, _, _ = intrinsics + # TODO: We need to clip? Culling? + _, _, fx, fy, cx, cy, near, _ = intrinsics x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] u = x * fx / z + cx v = y * fy / z + cy + return jnp.stack([u, v], axis=-1) @@ -154,6 +152,9 @@ def screen_from_world(x, cam, intr): """Maps to screen coordintaes `uv` from world coordinates `xyz`.""" return screen_from_camera(cam.inv().apply(x), intr) +def world_from_screen(uv, cam, intr): + """Maps to world coordintaes `xyz` from screen coords `uv`.""" + return cam.apply(camera_from_screen(uv, intr)) def camera_matrix_from_intrinsics(intr: Intrinsics) -> CameraMatrix3x3: """ diff --git a/b3d/chisight/sfm/__init__.py b/b3d/chisight/sfm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/b3d/chisight/sfm/epipolar.py b/b3d/chisight/sfm/epipolar.py new file mode 100644 index 00000000..97b8f74d --- /dev/null +++ b/b3d/chisight/sfm/epipolar.py @@ -0,0 +1,307 @@ +import jax +import jax.numpy as jnp +from b3d.pose import Pose, Rot +from b3d.camera import ( + screen_from_world, + screen_from_camera, + camera_from_screen, + world_from_screen, + camera_from_screen_and_depth, +) +from b3d.utils import keysplit +from sklearn.utils import Bunch + + +# # # # # # # # # # # # # # # # # # # # +# +# Epipolar geometry +# +# # # # # # # # # # # # # # # # # # # # + +def get_epipole(cam, intr): + """Get epipole of a camera with respect to fixed standard camera (at origin).""" + e = screen_from_world(jnp.zeros(3), cam, intr) + return e + +def get_epipoles(cam0, cam1, intr): + """Get epipoles of two cameras.""" + e0 = screen_from_world(cam1.pos, cam0, intr) + e1 = screen_from_world(cam0.pos, cam1, intr) + return jnp.stack([e0, e1], axis=0) + +def dist_to_line(u, l): + """ + Returns the distance of 'u' to the line through 'l'. + """ + # Normalize and + # rotate by 90 degrees + l = l/jnp.sqrt(l[...,[0]]**2 + l[...,[1]]**2) + il = jnp.stack([-l[...,1],l[...,0]], axis=-1) + d = u[...,0]*il[...,0] + u[...,1]*il[...,1] + return jnp.abs(d) + + +def dist_to_and_along_line(u, l): + """ + Returns the distance of 'u' to the line through 'l', and + the amount that of `u` along `l/|l|`, that is, + + `|dot(u, il/|il|)|` and `dot(u, l/|l|)`. + """ + # Normalize and + # rotate by 90 degrees + l = l/jnp.sqrt(l[...,[0]]**2 + l[...,[1]]**2) + il = jnp.stack([-l[...,1],l[...,0]], axis=-1) + d = u[...,0]*il[...,0] + u[...,1]*il[...,1] + s = u[...,0]* l[...,0] + u[...,1]* l[...,1] + return jnp.abs(d), s + + +def _epi_constraint(cam, u0, u1, intr): + """ + Textbook epipolar constraint. Computes the (unsigned) alignment between the epipolar planes + spanned by `u0`, `u1`, and `cam`; zero means perfectly aligned. + + Args: + cam: Relative camera Pose + u0: Array of shape (..., 2) + u1: Array of shape (..., 2) (same shape as `u0`) + intr: Intrinsics + + Returns + Array of shape (...) + """ + # TODO: Add a reference. + # NOTE: We work with a relative pose here, that is, we assume + # at time 0 world and camera frames are the same. + v0 = camera_from_screen(u0, intr) + v1 = world_from_screen(u1, cam, intr) - cam.pos + c = cam.pos + + # Normalize + v0 = v0/jnp.sqrt(v0[...,[0]]**2 + v0[...,[1]]**2 + v0[...,[2]]**2) + v1 = v1/jnp.sqrt(v1[...,[0]]**2 + v1[...,[1]]**2 + v1[...,[2]]**2) + c = c/jnp.sqrt(c[...,[0]]**2 + c[...,[1]]**2+ c[...,[2]]**2) + n = jnp.cross(v0, c[None], axis=-1) + n = n/jnp.sqrt(n[...,[0]]**2 + n[...,[1]]**2+ n[...,[2]]**2) + d = (n * v1).sum(-1) + + # Project to epi plane spanned by v0 and c + v1_ = v1 - (v1*n).sum(-1)[:,None]*n + v1_ = v1_/jnp.linalg.norm(v1_, axis=-1, keepdims=True) + h = jnp.abs(d) + + aux = dict(v0=v0, v1=v1, v1_in_epiplane=v1_) + return h, aux + +vmap_epi_constraint = jax.vmap( + lambda cam, uv0, uv1, intr: _epi_constraint(cam, uv0, uv1, intr)[0], + (0,None,None,None) +) + + +# NOTE: Experimental, don't rely on this +def _epi_constraint_variation_1(cam, u0, u1, intr): + h, aux = _epi_constraint(cam, u0, u1, intr)[0] + v0 = aux["v0"] + v1_ = aux["v1_on_epiplane"] + c = cam.pos/jnp.linalg.norm(cam.pos) + h = h - jnp.sign( (v0 * c).sum(-1) - (v1_ * c).sum(-1) ).sum() + return h, None + + +def _epi_distance(cam, u0, u1, intr): + """ + Projected version of epipolar constraints. + + Computes the distances of `u1` to the epipolar line + on the sensor canvas induced by `u0` and `cam`. + + Args: + cam: Relative camera Pose + u0: Array of shape (..., 2) + u1: Array of shape (..., 2) (same shape as `u0`) + intr: Intrinsics + + Returns + Array of shape (...) + """ + # NOTE: We work with a relative pose here, that is, we assume + # at time 0 world and camera frames are the same. + # TODO: Constrain so that we only consider the + # positive part of the line. That "should" (might) get rid of + # weird local maxima with points behind the camera. + # One should also look at the far end of the line since + # beyond this one cannot reach. + + # Get epipole in frame 1 + e = screen_from_world(jnp.zeros(3), cam, intr) + + # Take a point on the ray shooting through u0, + # and project onto opposite screen + x = camera_from_screen(u0, intr) + v1 = screen_from_world(x, cam, intr) + l = v1 - e + u = u1 - e + + d, _ = dist_to_and_along_line(u, l) + aux = {"epipole": e, "line_direction": l,} + return d, aux + +vmap_epi_distance = jax.vmap( + lambda cam, uv0, uv1, intr: _epi_distance(cam, uv0, uv1, intr)[0], + (0,None,None,None) +) + +# # # # # # # # # # # # # # # # # # # # +# +# Debugging +# +# # # # # # # # # # # # # # # # # # # # + +def _get_epipolar_debugging_data(cam, u0, u1, intr): + + # Get epipole in frame 1 + e = screen_from_world(jnp.zeros(3), cam, intr) + + # Take a point on the ray shooting through u0, + # and project onto opposite screen + x = camera_from_screen(u0, intr) + v1 = screen_from_world(x, cam, intr) + l = v1 - e + u = u1 - e + d, s = dist_to_and_along_line(u, l) + + l_norm = jnp.sqrt(l[...,[0]]**2 + l[...,[1]]**2) + + proj_vec = s[...,None] * l/l_norm + error_vec = proj_vec - u + + + x_near = camera_from_screen_and_depth(u0, jnp.array([intr.near]), intr) + x_far = camera_from_screen_and_depth(u0, jnp.array([intr.far]), intr) + v_near = screen_from_world(x_near, cam, intr) + v_far = screen_from_world(x_far, cam, intr) + vs = jnp.stack([v_near, v_far], axis=1) + + + v0 = camera_from_screen(u0, intr) + v1 = world_from_screen(u1, cam, intr) - cam.pos + c = cam.pos + + + # Normalize + v0 = v0/jnp.sqrt(v0[...,[0]]**2 + v0[...,[1]]**2 + v0[...,[2]]**2) + v1 = v1/jnp.sqrt(v1[...,[0]]**2 + v1[...,[1]]**2 + v1[...,[2]]**2) + c = c/jnp.sqrt(c[...,[0]]**2 + c[...,[1]]**2+ c[...,[2]]**2) + n = jnp.cross(v0, c[None], axis=-1) + n = n/jnp.sqrt(n[...,[0]]**2 + n[...,[1]]**2+ n[...,[2]]**2) + + return dict( + epipole = e, + line_directions = l, + epi_distance = d, + epi_scalar = s, + projection_vector = proj_vec, + error_vector = error_vec, + near_far_screen = vs, + near_far_world = jnp.stack([x_near, x_far], axis=1), + v0 = v0, + v1 = v1, + c = c, + n = jnp.cross(v0, c[None], axis=-1) + ) + + +# # # # # # # # # # # # # # # # # # # # +# +# Helper +# +# # # # # # # # # # # # # # # # # # # # + +def angle(v,w): + v = v/jnp.linalg.norm(v, axis=-1, keepdims=True) + w = w/jnp.linalg.norm(w, axis=-1, keepdims=True) + return jnp.arccos((v*w).sum(-1)) + + +# # # # # # # # # # # # # # # # # # # # +# +# Proposal Factories +# +# # # # # # # # # # # # # # # # # # # # +from b3d.pose import uniform_pose_in_ball +vmap_uniform_pose = jax.jit(jax.vmap(uniform_pose_in_ball.sample, (0,None,None,None))) + + +def make_two_frame_proposal(loss_func): + """ + Returns a pose proposal, using the following recipe. + - Sample *uniformly* around target pose, then + - compute the lossess, and + - return the the argmin. + """ + + def proposal(key, p0, p1, uvs0, uvs1, intr, rx=1.5, rq=0.25, S=100): + """ + Return pose a proposal around target pose `p1` as follows: + - Sample *uniformly* around target pose `p1`, then + - compute the lossess, and + - return the the argmin. + """ + # Create new key branch + _, key = keysplit(key, 1, 1) + + # Switch to relative poses. + q = p0.inv() @ p1 + + # Sample and score + # test poses + key, keys = keysplit(key, 1, S) + qs = vmap_uniform_pose(keys, q, rx, rq) + losses_ = jax.vmap(loss_func, (0,None,None,None))(qs, uvs0, uvs1, intr)[0] + loss = jnp.nan_to_num(losses_.sum(1), nan=jnp.inf) + + # Pick best test pose + # TODO: Resample? + i = jnp.argmin(loss) + q = qs[i] + + aux = {"proposals": qs, "loss": loss, "winner_index": i, "winner_loss": loss[i]} + + return q, aux + + return proposal + + +# # # # # # # # # # # # # # # # # # # # +# +# Appendix +# +# # # # # # # # # # # # # # # # # # # # +# NOTE/TODO: This doesn't work as well as the other scorer. +# I am just keeping this for further analysis. +def _epi_scorer_other_version(cam, u0, u1, intr): + """ + Computes the distances of `u1` to the epipolar lines induced by `u0` and `cam`. + """ + e = get_epipole(cam, intr) + + x0 = camera_from_screen_and_depth(u0, intr.far*jnp.ones(u0.shape[:-1]), intr) + l = screen_from_world(x0, cam, intr) + l_norm = jnp.sqrt(l[...,0]**2 + l[...,1]**2) + + l = l - e + u = u1 - e + + # TODO: Constrain so that we only consider the + # positive part of the line. That "should" (might) get rid of + # weird local maxima with points behind the camera. + d, s = dist_to_and_along_line(u, l) + d = jnp.where(s > 0.0, d, 1e2) + d = jnp.where(s < l_norm, d, 1e2) + + s = jnp.clip(s, 0.0, jnp.inf) + ys = e + s[:,None]*l/l_norm[:,None] + + return d, ys \ No newline at end of file diff --git a/b3d/chisight/sfm/gradient_descent_particle_inference.py b/b3d/chisight/sfm/gradient_descent_particle_inference.py new file mode 100644 index 00000000..8511f26d --- /dev/null +++ b/b3d/chisight/sfm/gradient_descent_particle_inference.py @@ -0,0 +1,87 @@ +import jax +import jax.numpy as jnp +import optax +from b3d.utils import keysplit +from b3d.pose import Pose, Rot +from b3d.camera import camera_from_screen_and_depth +from .utils import reprojection_error + + +def map_nested_fn(fn): + '''Recursively apply `fn` to the key-value pairs of a nested dict.''' + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + for k, v in nested_dict.items()} + return map_fn +label_fn = map_nested_fn(lambda k, _: k) + + +def map_over_nested_dict_values(f): + '''Recursively apply `f` to the values of a nested dict.''' + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else f(v)) + for k, v in nested_dict.items()} + return map_fn + + +def init_params(key, uvs0, cam0, intr): + _, key = keysplit(key,1,1) + + # Initialize 3d keypoints in + # fixed camera frame + N = uvs0.shape[0] + z = jax.random.normal(key, (N,))*.1 + 6. + xs = cam0(camera_from_screen_and_depth(uvs0, z, intr)) + params = {"xs": xs} + + return params + +def get_particle_positions(params): + return params["xs"] + + +def loss_function(params, uvs0, uvs1, cam0, cam1, intr): + xs = get_particle_positions(params) + err0 = reprojection_error(xs, uvs0, cam0, intr) + err1 = reprojection_error(xs, uvs1, cam1, intr) + return ( + jnp.mean(err0 + err1) + ) + +loss_func_grad = jax.value_and_grad(loss_function, argnums=(0,)) + + + +def make_fit(key, uvs0, uvs1, cam0, cam1, intr, learning_rate=1e-3): + + optimizer = optax.multi_transform( + { + 'xs': optax.adam(learning_rate), + }, + label_fn + ) + + @jax.jit + def step(carry, _): + params, opt_state, loss_args = carry + ell, (grads,) = loss_func_grad(params, *loss_args) + updates, opt_state = optimizer.update(grads, opt_state) + updates = map_over_nested_dict_values(jnp.nan_to_num)(updates) + params = optax.apply_updates(params, updates) + params['xs'] = params['xs'].at[:,2].set(jnp.clip(params['xs'][:,2], 0., jnp.inf)) + return ((params, opt_state, loss_args), ell) + + + params = init_params(key, uvs0, cam0, intr) + loss_args = (uvs0, uvs1, cam0, cam1, intr) + + def fit(params, steps=1_000): + _, subkey = keysplit(key,1,1) + opt_state = optimizer.init(params) + loss_args = (uvs0, uvs1, cam0, cam1, intr) + (params, opt_state, loss_args), losses = jax.lax.scan(step, (params, opt_state, loss_args), xs=None, length=steps) + + return params, losses + + return params, fit + diff --git a/b3d/chisight/sfm/particle_inference.py b/b3d/chisight/sfm/particle_inference.py new file mode 100644 index 00000000..471c2cc0 --- /dev/null +++ b/b3d/chisight/sfm/particle_inference.py @@ -0,0 +1,166 @@ +import jax +import jax.numpy as jnp +from b3d.pose import Pose, Rot +from b3d.camera import ( + screen_from_world, + screen_from_camera, + camera_from_screen, + world_from_screen, + camera_from_screen_and_depth, +) +from b3d.utils import keysplit +from sklearn.utils import Bunch + + +# TODO: Check this. ChatGPT spit that out. +def closest_points_on_lines(x, v, x_prime, v_prime): + """ + Given two affine lines computes point on each line + with minimal distance between them. + """ + # Define the direction vectors + a = v + b = v_prime + + # Define the vector between the two points on the lines + w0 = x - x_prime + + # Calculate coefficients for the system of linear equations + a_dot_a = jnp.dot(a, a) + b_dot_b = jnp.dot(b, b) + a_dot_b = jnp.dot(a, b) + a_dot_w0 = jnp.dot(a, w0) + b_dot_w0 = jnp.dot(b, w0) + + # Solving the system of linear equations for t and s + denom = a_dot_a * b_dot_b - a_dot_b * a_dot_b + + t = (a_dot_b * b_dot_w0 - b_dot_b * a_dot_w0) / denom + s = (a_dot_a * b_dot_w0 - a_dot_b * a_dot_w0) / denom + + # Calculate the closest points on the lines + p1 = x + t * a + p2 = x_prime + s * b + + return (p1, p2) + + +def _latent_keypoint_from_lines(u0, u1, cam0, cam1, intr): + """ + Returns keypoint that is closest to both keypoint lines. + """ + x0 = world_from_screen(u0, cam0, intr) + x1 = world_from_screen(u1, cam1, intr) + + a, b = closest_points_on_lines( + cam0.pos, x0 - cam0.pos, + cam1.pos, x1 - cam1.pos) + + x = (a + b)/2 + + return x + + +# # # # # # # # # # # # # # # # # # # # # # # # +# +# Gaussian Inference +# +# # # # # # # # # # # # # # # # # # # # # # # # +from jax.scipy.linalg import inv + + +# TODO: Check that +def gaussian_pdf_product(mean1, cov1, mean2, cov2): + """ + Computes the product of two 3D Gaussian PDFs. + + Args: + mean1: Mean vector of the first Gaussian (3-dimensional). + cov1: Covariance matrix of the first Gaussian (3x3 matrix). + mean2: Mean vector of the second Gaussian (3-dimensional). + cov2: Covariance matrix of the second Gaussian (3x3 matrix). + + Returns: + mean_prod: Mean vector of the product Gaussian. + cov_prod: Covariance matrix of the product Gaussian. + + + """ + # Someone bless the internet: + # > https://math.stackexchange.com/questions/157172/product-of-two-multivariate-gaussians-distributions + cov_prod = inv(inv(cov1) + inv(cov2)) + mean_prod = cov_prod @ (inv(cov1) @ mean1 + inv(cov2) @ mean2) + + return mean_prod, cov_prod + + +# TODO: Check that +def gaussian_pdf_product_multiple(means, covariances): + """ + Computes the product of multiple 3D Gaussian PDFs. + + Args: + means: A 2D array where each row is a mean vector of a Gaussian (N x 3). + covariances: A 3D array where each slice along the first dimension is a 3x3 covariance matrix (N x 3 x 3). + + Returns: + mean_prod: Mean vector of the product Gaussian. + cov_prod: Covariance matrix of the product Gaussian. + """ + # Convert means and covariances to JAX arrays + means = jnp.asarray(means) + covariances = jnp.asarray(covariances) + + # Compute the inverse of each covariance matrix + inv_covariances = jax.vmap(inv)(covariances) + + # Sum of the inverses of covariance matrices + cov_prod_inv = jnp.sum(inv_covariances, axis=0) + + # Compute the product covariance matrix + cov_prod = inv(cov_prod_inv) + + # Compute the weighted sum of means + weighted_means_sum = jnp.sum(jax.vmap(lambda inv_cov, mean: inv_cov @ mean)(inv_covariances, means), axis=0) + + # Compute the product mean vector + mean_prod = cov_prod @ weighted_means_sum + + return mean_prod, cov_prod + + +def rotation_from_first_column(key, a): + b = jnp.cros(a, jax.random.normal(key, (3,))) + c = jnp.cross(a, b) + + a = a/jnp.sqrt(a[0]**2 + a[1]**2 + a[2]**2) + b = b/jnp.sqrt(b[0]**2 + b[1]**2 + b[2]**2) + c = b/jnp.sqrt(c[0]**2 + c[1]**2 + c[2]**2) + + return jnp.stack([a,b,c], axis=1) + + +from b3d.pose import Pose, Rot +def cov_from_dq_composition(diag, quat): + """ + Covariance matrix from particle representation `(diag, quat)`, + where `diag` is an array of eigenvalues and `quat` is a quaternion + representing the matrix of eigenvectors. + """ + U = Rot.from_quat(quat).as_matrix() + C = U @ jnp.diag(diag) @ U.T + return C + + +def gaussian_from_keypoint(z, diag, u, cam, intr): + x = cam(camera_from_screen_and_depth(u, z, cam, intr)) + p = Pose.from_position_and_target(cam.pos, x) + cov = cov_from_dq_composition(diag, p.quat) + return x, cov + + +def _gaussian_keypoint_posterior(z0, z1, diag0, diag1, u0, u1, cam0, cam1, intr): + mu0, cov0 = gaussian_from_keypoint(z0, diag0, u0, cam0, intr) + mu1, cov1 = gaussian_from_keypoint(z1, diag1, u1, cam1, intr) + mu, cov = gaussian_pdf_product(mu0, cov0, mu1, cov1) + return mu, cov \ No newline at end of file diff --git a/b3d/chisight/sfm/utils.py b/b3d/chisight/sfm/utils.py new file mode 100644 index 00000000..3f61e43e --- /dev/null +++ b/b3d/chisight/sfm/utils.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +from b3d.camera import screen_from_world +from b3d.pose import Pose, Rot + + +def reprojection_error(xs, us, cam, intr): + us_ = screen_from_world(xs, cam, intr) + err = jnp.linalg.norm(us_ - us, axis=-1).sum() + return err + + +def line_intersects_box(x, dx, width, height): + + + dx = dx/jnp.linalg.norm(dx) + + # Define the box boundaries + x_min, x_max = 0.0, width + y_min, y_max = 0.0, height + + # Define the parameter t for the line equation x + t * dx + # t0_x = jnp.where(dx[0] != 0.0, (x_min - x[0]) / dx[0], -jnp.inf) + # t1_x = jnp.where(dx[0] != 0.0, (x_max - x[0]) / dx[0], jnp.inf) + # t0_y = jnp.where(dx[1] != 0.0, (y_min - x[1]) / dx[1], -jnp.inf) + # t1_y = jnp.where(dx[1] != 0.0, (y_max - x[1]) / dx[1], jnp.inf) + t0_x = (x_min - x[0]) / dx[0] + t1_x = (x_max - x[0]) / dx[0] + t0_y = (y_min - x[1]) / dx[1] + t1_y = (y_max - x[1]) / dx[1] + + ps = jnp.array([ + x + t0_x * dx, + x + t1_x * dx, + x + t0_y * dx, + x + t1_y * dx + ]) + + eps=1e-3 + valid = ( + (-eps <= ps[:,0]) * + (ps[:,0] <= width+eps) * + (-eps <= ps[:,1]) * + (ps[:,1] <= height+eps) + ) + + # TODO: Fix edgecases when x is far out + ds = jnp.abs(ps - jnp.array([[width/2, height/2]])).sum(1) + inds = jnp.argsort(ds)[:2] + seg = jax.lax.cond(valid.sum()>=2, lambda: ps[inds], lambda: jnp.tile(-jnp.inf, (2,2))) + + + return seg[0], seg[1] \ No newline at end of file diff --git a/b3d/chisight/shared/particle_system.py b/b3d/chisight/shared/particle_system.py index 7d55bead..c68a1cb2 100644 --- a/b3d/chisight/shared/particle_system.py +++ b/b3d/chisight/shared/particle_system.py @@ -1,17 +1,22 @@ -import jax.numpy as jnp import b3d -from b3d import Pose import jax import jax.numpy as jnp import genjax from genjax import gen from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs -from b3d import Pose, Mesh +from b3d import Mesh +from typing import Any, TypeAlias +from genjax import ChoiceMapBuilder as C +SparseGPSModelTrace: TypeAlias = Any +DenseGPSModelTrace: TypeAlias = Any +GPSModelTrace: TypeAlias = Any from b3d.chisight.sparse.gps_utils import add_dummy_var -from b3d.pose import uniform_pose_in_ball -dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None)) - +from b3d.pose import Pose, uniform_pose_in_ball +from b3d.utils import Bunch +from tensorflow_probability.substrates import jax as tfp +tfp_normal = b3d.modeling_utils.tfp_distribution(tfp.distributions.Normal) +dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None)) uniform_pose_args = (Pose.identity(), 2.0, 0.5) @gen @@ -140,16 +145,55 @@ def latent_particle_model( init_retval, scan_retvals ), final_state +# # # # # # # # # # # # # # # # # # # # # # # # # # +# +# Sparse +# +# # # # # # # # # # # # # # # # # # # # # # # # # # @genjax.gen def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma): # TODO: add visibility uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const) - uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" + uv_ = tfp_normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" return uv_ @genjax.gen def sparse_gps_model(latent_particle_model_args, obs_model_args): - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) + """ + Args: + `latent_particle_model_args`: Tuple containing + - number_of_timesteps:genjax.Pytree.const + - number_of_particlegenjax.Pytree.const + - number_of_object_clusters:genjax.Pytree.const + - particle prior args: (p:Pose, x_radius:Float, q_radius:Float) + - object prior args: (p:Pose, x_radius:Float, q_radius:Float) + - camera prior args: (p:Pose, x_radius:Float, q_radius:Float) + `obs_model_args`: Tuple containing instrinsics, and noise std. dev. + + Example: + ``` + particle_prior_params = (Pose.identity(), .5, 0.25) + object_prior_params = (Pose.identity(), 2., 0.5) + camera_prior_params = (Pose.identity(), 0.1, 0.1) + instrinsics = genjax.Pytree.const(b3d.camera.Intrinsics(120, 100, 50., 50., 50., 50., 0.001, 16.)) + sigma_obs = 0.2 + + model = sparse_gps_model + latent_args = ( + genjax.Pytree.const(T), # const object + genjax.Pytree.const(N), # const object + genjax.Pytree.const(K), # const object + particle_prior_params, + object_prior_params, + camera_prior_params + ) + observation_args = ( + instrinsics, + sigma_obs + ) + args = (latent_args, observation_args) + ``` + """ particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))( particle_dynamics_summary["absolute_particle_poses"], @@ -157,10 +201,38 @@ def sparse_gps_model(latent_particle_model_args, obs_model_args): particle_dynamics_summary["vis_mask"], *obs_model_args ) @ "obs" - return (particle_dynamics_summary, final_state, obs) + return Bunch( + particle_dynamics_summary = particle_dynamics_summary, + final_state = final_state, + observation = obs + ) +def get_sparse_test_model_and_args(T=4, N=5, K=3): + particle_prior_params = (Pose.identity(), .5, 0.25) + object_prior_params = (Pose.identity(), 2., 0.5) + camera_prior_params = (Pose.identity(), 0.1, 0.1) + instrinsics = genjax.Pytree.const(b3d.camera.Intrinsics(120, 100, 50., 50., 50., 50., 0.001, 16.)) + sigma_obs = 0.2 + + model = sparse_gps_model + latent_args = ( + genjax.Pytree.const(T), # const object + genjax.Pytree.const(N), # const object + genjax.Pytree.const(K), # const object + particle_prior_params, + object_prior_params, + camera_prior_params + ) + observation_args = (instrinsics, sigma_obs) + args = (latent_args, observation_args) + return model, args +# # # # # # # # # # # # # # # # # # # # # # # # # # +# +# Dense +# +# # # # # # # # # # # # # # # # # # # # # # # # # # def make_dense_gps_model(renderer): dense_observation_model = make_dense_observation_model(renderer) @@ -175,11 +247,55 @@ def dense_gps_model(latent_particle_model_args, dense_likelihood_args): (meshes, likelihood_args) = dense_likelihood_args merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame) image = dense_observation_model(merged_mesh, likelihood_args) @ "obs" - return (particle_dynamics_summary, final_state, image) + return Bunch( + particle_dynamics_summary = particle_dynamics_summary, + final_state = final_state, + observation = image + ) return dense_gps_model - +# # # # # # # # # # # # # # # # # # # # # # # # # # +# +# Quick Access +# +# # # # # # # # # # # # # # # # # # # # # # # # # # +def get_cameras(tr: GPSModelTrace): + # TODO: Should we leave it like that or grab it from the choice addresses + latent = tr.get_retval()["particle_dynamics_summary"] + return latent["camera_pose"] + +def get_observations(tr: GPSModelTrace): + # TODO: Should we leave it like that or grab it from the choice addrtresses + return tr.get_retval()["observation"] + +def set_camera_choice(t, cam: Pose, ch=None): + if ch is None: ch = C.n() + if t == Ellipsis: + ch = ch.merge(C["particle_dynamics", "state0", + "initial_camera_pose"].set(cam[0])) + ch = ch.merge(C["particle_dynamics", "states1+", + jnp.arange(cam.shape[0]-1), "camera_pose"].set(cam[1:])) + else: + if t == 0: + ch = ch.merge(C["particle_dynamics", "state0", "initial_camera_pose"].set(cam)) + elif t > 0: + ch = ch.merge(C["particle_dynamics", "states1+", t-1, "camera_pose"].set(cam)) + return ch + +def set_sensor_coordinates_choice(t, uvs, ch=None): + if ch is None: ch = C.n() + if t == Ellipsis: + ch = ch.merge(C["obs", jnp.arange(uvs.shape[0]), "sensor_coordinates"].set(uvs)) + else: + ch = ch.merge(C["obs", t , "sensor_coordinates"].set(uvs)) + return ch + +# # # # # # # # # # # # # # # # # # # # # # # # # # +# +# Vis +# +# # # # # # # # # # # # # # # # # # # # # # # # # # def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state): import rerun as rr (dynamic_state, static_state) = final_state diff --git a/b3d/io/feature_track_data.py b/b3d/io/feature_track_data.py index 7dbdc989..e710f37e 100644 --- a/b3d/io/feature_track_data.py +++ b/b3d/io/feature_track_data.py @@ -12,9 +12,10 @@ DESCR = """ FeatureTrackData: - Timesteps: {data.uv.shape[0]} + Num Frames: {data.uv.shape[0]} Num Keypoints: {data.uv.shape[1]} - Sensor shape (width x height): {data.rgb.shape[2]} x {data.rgb.shape[1]} + Image shape (width x height): {data.rgb.shape[2]} x {data.rgb.shape[1]} + FPS: {data.fps:0.0f} """ @dataclass @@ -122,7 +123,22 @@ def uv(self): return self.observed_keypoints_positions def vis(self): return self.visibility @property - def rgb(self): return self.rgbd_images[...,:3] + def rgb_float(self): + rgb = self.rgbd_images[...,:3] + if rgb.max() > 1.: rgb = rgb/255 + return rgb + + @property + def rgb_uint(self): + rgb = self.rgbd_images[...,:3] + if rgb.max() <= 1.: rgb = (rgb*255).astype(jnp.uint8) + return rgb + + @property + def rgb(self): + rgb = self.rgbd_images[...,:3] + if rgb.max() > 1.: rgb = rgb/255 + return self.rgb_float @property def visibility(self): return self.keypoint_visibility @@ -340,10 +356,10 @@ def quick_plot(self, t=None, fname=None, ax=None, figsize=(3,3), downsize=10): ax.set_aspect(1) ax.axis("off") - rgb = downsize_images(self.rgb, downsize) + rgb = downsize_images(self.rgb_float, downsize) if t is None: h,w = self.rgb.shape[1:3] - ax.imshow(np.concatenate(rgb/255, axis=1)) + ax.imshow(np.concatenate(rgb, axis=1)) ax.scatter(*np.concatenate( [ self.uv[t, self.vis[t]]/downsize + np.array([t*w,0])/downsize @@ -351,7 +367,7 @@ def quick_plot(self, t=None, fname=None, ax=None, figsize=(3,3), downsize=10): ] ).T, s=1) else: - ax.imshow(rgb[t]/255) + ax.imshow(rgb[t]) ax.scatter(*(self.uv[t, self.vis[t]]/downsize).T, s=1) diff --git a/b3d/plotting.py b/b3d/plotting.py new file mode 100644 index 00000000..8a078469 --- /dev/null +++ b/b3d/plotting.py @@ -0,0 +1,46 @@ +import numpy as np +import matplotlib.pyplot as plt +import imageio +from sklearn.utils import Bunch +from PIL import Image +import io +import jax.numpy as jnp +from IPython.display import Image as IPImage, display + + +def fig_to_image(fig): + """Convert a Matplotlib figure to a PIL Image and return it""" + buf = io.BytesIO() + fig.savefig(buf, format="png") + buf.seek(0) + img = Image.open(buf) + return img + + +def save_as_gif(fname, images, fps=10, loop=0): + """Save a list of images as a gif""" + if isinstance(images[0], np.ndarray) or isinstance(images[0], jnp.ndarray): + images = [Image.fromarray(im) for im in images] + + if not isinstance(images[0], Image.Image): + raise Exception("images need to be `(j)numpy.ndarray` or `PIL.Image.Image`") + + images[0].save( + fname, + save_all=True, + append_images=images[1:], + optimize=False, + duration=1000.0 / fps, + loop=loop, + ) + + return fname + + +def display_gif(fname): + return display(IPImage(data=open(fname, "rb").read(), format="png")) + + +def save_and_display_gif(fname, images, fps=10, loop=0): + return display_gif(save_as_gif(fname, images, fps=fps, loop=loop)) + diff --git a/b3d/pose/core.py b/b3d/pose/core.py index 230855b8..e62866fb 100644 --- a/b3d/pose/core.py +++ b/b3d/pose/core.py @@ -220,6 +220,13 @@ def __next__(self): def __getitem__(self, index): return Pose(self.pos[index], self.quat[index]) + # TODO: implement `.at[].set()`` + def _at_set(self, key, p): + return Pose( + self.pos.at[key].set(p.pos), + self.quat.at[key].set(p.quat) + ) + def slice(self, i): return Pose(self.pos[i], self.quat[i]) diff --git a/b3d/pose/pose_utils.py b/b3d/pose/pose_utils.py index 01cf6176..852e67d5 100644 --- a/b3d/pose/pose_utils.py +++ b/b3d/pose/pose_utils.py @@ -26,7 +26,7 @@ def unit_disc_to_sphere(x): xs = jnp.concatenate( [jnp.sin(r * jnp.pi / 2) * phi, jnp.cos(r * jnp.pi / 2)], axis=1 ) - return xs + return jnp.where(r==0, jnp.array([0.,0.,0.,1.]), xs) def volume_of_3_ball(r): diff --git a/b3d/utils.py b/b3d/utils.py index ffd20a54..6bf9e6ee 100644 --- a/b3d/utils.py +++ b/b3d/utils.py @@ -15,8 +15,9 @@ import trimesh import rerun as rr import distinctipy +from jax.tree_util import register_pytree_node_class +from builtins import tuple as _tuple -from sklearn.utils import Bunch # # # # # # # # # # # # # @@ -73,8 +74,48 @@ def keysplit(key, *ns): for n in ns: keys.append(keysplit(key, n)) return keys - + +@register_pytree_node_class +class Bunch(tuple): + """ + A Pytree Tuple Bunch Class. + Can be accessed like Tuple, Dict, and Bunch. + + Example: + ``` + b = Bunch(0, x=1, y=2) + asssert 0 == b[0] + assert 1 == b[1] and 2 == b[2] + asssert 1 == b.x and 2 == b.y + asssert 1 == b["x"] and 2 == b["y"] + ``` + """ + def __new__(cls, *args, **kwargs): + # NOTE: Keyword argument order is preserved + # > https://docs.python.org/3/whatsnew/3.6.html#whatsnew36-pep468 + return _tuple.__new__(cls, list(args) + list(kwargs.values())) + + def __init__(self, *args, **kwargs): + self._d = dict() + self._keys = list(kwargs.keys()) + for k,v in kwargs.items(): + self._d[k] = v + setattr(self, k, v) + + def __getitem__(self, k: str): + if isinstance(k, int): return super().__getitem__(k) + return self._d[k] + + def tree_flatten(self): + return (self, self._keys) + + @classmethod + def tree_unflatten(cls, aux_data, children): + k = len(aux_data) + n = len(children) + return cls(*children[:n-k], **dict(zip(aux_data, children[n-k:]))) + # # # # # # # # # # # # # # Other diff --git a/tests/test_bunch.py b/tests/test_bunch.py new file mode 100644 index 00000000..f00017de --- /dev/null +++ b/tests/test_bunch.py @@ -0,0 +1,30 @@ +import unittest +import jax +import jax.numpy as jnp +import b3d +from b3d.utils import Bunch +import genjax + +class MeshTests(unittest.TestCase): + + def test_bunch(self): + b = Bunch(1, 2, 3) + assert 1 == b[0] and 2 == b[1] and 3 == b[2] + + b = Bunch(1, x="x", y=3) + assert 1 == b[0] and "x" == b["x"] and 3 == b["y"] + assert 1 == b[0] and "x" == b.x and 3 == b.y + + @genjax.gen + def model(): + x = genjax.normal(0.,1.) @ "x" + return Bunch(1, x=x, y=2) + + key = jax.random.PRNGKey(0) + jsimulate = jax.jit(model.simulate) + + tr = jsimulate(key, ()) + b = tr.get_retval() + + assert b.x == tr.get_choices()("x").v, f"{tr.get_choices()('x').v}, {b.x}, {b}" + assert b[1] == tr.get_choices()("x").v, f"{tr.get_choices()('x').v}, {b[1]}, {b}" \ No newline at end of file diff --git a/tests/test_particle_system_setters.py b/tests/test_particle_system_setters.py new file mode 100644 index 00000000..219e1581 --- /dev/null +++ b/tests/test_particle_system_setters.py @@ -0,0 +1,70 @@ +import unittest +import genjax +from genjax import Pytree +import jax +import jax.numpy as jnp +import numpy as np +from b3d import Pose +import b3d +from b3d.chisight.shared.particle_system import ( + get_sparse_test_model_and_args, + get_cameras, + set_camera_choice, + get_observations, + set_sensor_coordinates_choice +) + +# Get a minimal model for testing +key = jax.random.PRNGKey(np.random.randint(1_000)) +T = 4 +N = 5 +K = 3 +model, args = get_sparse_test_model_and_args(T, N, K) + + +class MeshTests(unittest.TestCase): + + def test_camera_setter(self): + global key; + global model, args; + for t in range(T): + key,_ = jax.random.split(key) + + ch = set_camera_choice(t, Pose.id()) + tr, w = model.importance(key, ch, args) + cams = get_cameras(tr) + + assert jnp.allclose(cams[t].pos, jnp.zeros(3)) + assert jnp.allclose(cams[t].quat, jnp.array([0.,0.,0.,1.])) + + + key,_ = jax.random.split(key) + ch = set_camera_choice(..., Pose( + jnp.zeros((T,3)), + jnp.tile(jnp.array([0.,0.,0.,1.]), (T,1)) + )) + tr, w = model.importance(key, ch, args) + cams = get_cameras(tr) + + assert jnp.allclose(cams.pos, jnp.zeros((T,3))) + + + def test_observation_setter(self): + global key; + global model, args; + for t in range(T): + key,_ = jax.random.split(key) + + ch = set_sensor_coordinates_choice(t, jnp.zeros((N,2))) + tr, w = model.importance(key, ch, args) + uvs = get_observations(tr) + + # TODO: Wait for vmap importance bug has been resolved + assert jnp.allclose(uvs[t], jnp.zeros((N,2))) + + key,_ = jax.random.split(key) + ch = set_sensor_coordinates_choice(..., jnp.zeros((T,N,2))) + tr, w = model.importance(key, ch, args) + uvs = get_observations(tr) + + assert jnp.allclose(uvs, jnp.zeros((T,N,2))) \ No newline at end of file