diff --git a/b3d/chisight/shared/particle_system.py b/b3d/chisight/shared/particle_system.py index 7d55bead..ed96e461 100644 --- a/b3d/chisight/shared/particle_system.py +++ b/b3d/chisight/shared/particle_system.py @@ -113,7 +113,8 @@ def particle_system_state_step(carried_state, _): @gen def latent_particle_model( - num_timesteps, # const object + max_num_timesteps, # const object + num_timesteps, num_particles, # const object num_clusters, # const object relative_particle_poses_prior_params, @@ -121,9 +122,15 @@ def latent_particle_model( camera_pose_prior_params ): """ - Retval is a dict with keys "relative_particle_poses", "absolute_particle_poses", - "object_poses", "camera_poses", "vis_mask" - Leading dimension for each timestep is the batch dimension. + The retval is a dict with keys "object_assignments" and "masked_dynamic_state". + The value at "masked_dynamic_state" is a genjax.Mask object `m`. + `m.value` is a dictionary with keys "relative_particle_poses", "absolute_particle_poses", + "object_poses", "camera_poses", "vis_mask". + The leading dimension for each will have size `max_num_timesteps`. + The boolean array `m.flag` will indicate which of these timesteps are valid + (and which are values >= `num_timesteps`). + The values at these invalid timesteps are undefined. + Using these values directly will cause silent errors. """ (state0, init_retval) = initial_particle_system_state( num_particles, num_clusters, @@ -132,32 +139,57 @@ def latent_particle_model( camera_pose_prior_params ) @ "state0" - final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+" + masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator( + particle_system_state_step, + n=(max_num_timesteps.const-1) + )( + state0, + genjax.Mask( + # This next line tells the scan combinator how many timesteps to run + jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1, + jnp.zeros(max_num_timesteps.const - 1) + ) + ) @ "states1+" + # concatenate each element of init_retval, scan_retvals - return jax.tree.map( + concatenated_states_possibly_invalid = jax.tree.map( lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0), - init_retval, scan_retvals - ), final_state + init_retval, masked_scan_retvals.value + ) + masked_concatenated_states = genjax.Mask( + jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]), + concatenated_states_possibly_invalid + ) + + object_assignments = state0[1][0] + latent_dynamics_summary = { + "object_assignments": object_assignments, + "masked_dynamic_state": masked_concatenated_states, + } + + return latent_dynamics_summary @genjax.gen def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma): # TODO: add visibility uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const) - uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" + uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" return uv_ @genjax.gen def sparse_gps_model(latent_particle_model_args, obs_model_args): - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))( - particle_dynamics_summary["absolute_particle_poses"], - particle_dynamics_summary["camera_pose"], - particle_dynamics_summary["vis_mask"], + latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"] + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))( + masked_particle_dynamics_summary.flag, + _UNSAFE_particle_dynamics_summary["absolute_particle_poses"], + _UNSAFE_particle_dynamics_summary["camera_pose"], + _UNSAFE_particle_dynamics_summary["vis_mask"], *obs_model_args ) @ "obs" - return (particle_dynamics_summary, final_state, obs) + return (latent_dynamics_summary, masked_obs) @@ -166,26 +198,28 @@ def make_dense_gps_model(renderer): @genjax.gen def dense_gps_model(latent_particle_model_args, dense_likelihood_args): - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1] - camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1] + latent_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"] + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + + last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1 + absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index] + camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index] absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame (meshes, likelihood_args) = dense_likelihood_args merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame) image = dense_observation_model(merged_mesh, likelihood_args) @ "obs" - return (particle_dynamics_summary, final_state, image) + return (latent_dynamics_summary, image) return dense_gps_model -def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state): +def visualize_particle_system(latent_particle_model_args, latent_dynamics_summary): import rerun as rr - (dynamic_state, static_state) = final_state - ( - num_timesteps, # const object + max_num_timesteps, # const object + num_timesteps, num_particles, # const object num_clusters, # const object relative_particle_poses_prior_params, @@ -194,17 +228,20 @@ def visualize_particle_system(latent_particle_model_args, particle_dynamics_summ ) = latent_particle_model_args colors = b3d.distinct_colors(num_clusters.const) - absolute_particle_poses = particle_dynamics_summary["absolute_particle_poses"] - object_poses = particle_dynamics_summary["object_poses"] - camera_pose = particle_dynamics_summary["camera_pose"] - object_assignments = static_state[0] + + masked_particle_dynamics_summary = latent_dynamics_summary["masked_dynamic_state"] + object_assignments = latent_dynamics_summary["object_assignments"] + _UNSAFE_absolute_particle_poses = masked_particle_dynamics_summary.value["absolute_particle_poses"] + _UNSAFE_object_poses = masked_particle_dynamics_summary.value["object_poses"] + _UNSAFE_camera_pose = masked_particle_dynamics_summary.value["camera_pose"] cluster_colors = jnp.array(b3d.distinct_colors(num_clusters.const)) - for t in range(num_timesteps.const): + for t in range(num_timesteps): rr.set_time_sequence("time", t) + assert masked_particle_dynamics_summary.flag[t], "Erroring before attempting to unmask invalid masked data." - cam_pose = camera_pose[t] + cam_pose = _UNSAFE_camera_pose[t] rr.log( f"/camera", rr.Transform3D(translation=cam_pose.position, rotation=rr.Quaternion(xyzw=cam_pose.xyzw)), @@ -220,10 +257,10 @@ def visualize_particle_system(latent_particle_model_args, particle_dynamics_summ rr.log( "absolute_particle_poses", rr.Points3D( - absolute_particle_poses[t].pos, + _UNSAFE_absolute_particle_poses[t].pos, colors=cluster_colors[object_assignments] ) ) for i in range(num_clusters.const): - b3d.rr_log_pose(f"cluster/{i}", object_poses[t][i]) + b3d.rr_log_pose(f"cluster/{i}", _UNSAFE_object_poses[t][i]) diff --git a/b3d/modeling_utils.py b/b3d/modeling_utils.py index 9df52310..ea57d9cc 100644 --- a/b3d/modeling_utils.py +++ b/b3d/modeling_utils.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp +from genjax import Mask uniform_discrete = genjax.exact_density( lambda key, vals: jax.random.choice(key, vals), @@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs): d = dist(*args, **kwargs) return jnp.sum(d.log_prob(v)) - return genjax.exact_density(sampler, logpdf) \ No newline at end of file + return genjax.exact_density(sampler, logpdf) + +normal = tfp_distribution(tfp.distributions.Normal) + +def masked_scan_combinator(step, **scan_kwargs): + """ + Given a generative function `step` so that `step.scan(n=N)` is valid, + return a generative function accepting an input + `(initial_state, masked_input_values_array)` and returning a pair + `(masked_final_state, masked_returnvalue_sequence)`. + This operates similarly to `step.scan`, but the input values can be masked. + """ + mstep = step.mask().dimap( + pre=lambda masked_state, masked_inval: ( + jnp.logical_and(masked_state.flag, masked_inval.flag), + masked_state.value, + masked_inval.value + ), + post=lambda args, masked_retval: ( + Mask(masked_retval.flag, masked_retval.value[0]), + Mask(masked_retval.flag, masked_retval.value[1]) + ) + ) + + # This should be given a pair ( + # Mask(True, initial_state), + # Mask(bools_indicating_active, input_vals) + # ). + # It wll output a pair (masked_final_state, masked_returnvalue_sequence). + scanned = mstep.scan(**scan_kwargs) + + scanned_nice = scanned.dimap( + pre=lambda initial_state, masked_input_values: ( + Mask(True, initial_state), + Mask(masked_input_values.flag, masked_input_values.value) + ), + post=lambda args, retval: retval + ) + + return scanned_nice + +def variable_length_unfold_combinator(step, **scan_kwargs): + """ + Step should accept one arg, `state`, as input, + and should return a pair `(new_state, retval_for_this_timestep)`. + """ + scanned = masked_scan_combinator(step, **scan_kwargs) + return scanned.dimap( + pre=lambda initial_state, n_steps: ( + initial_state, + Mask(jnp.array()) + ) + ) \ No newline at end of file diff --git a/notebooks/integration.ipynb b/notebooks/integration.ipynb index d9da1c07..030ed799 100644 --- a/notebooks/integration.ipynb +++ b/notebooks/integration.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -21,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -53,13 +44,52 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "