From 5a07ed3f4657d8b720e21670fb5ffbdd1cb9e353 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 28 Jun 2024 19:56:56 +0000 Subject: [PATCH 1/3] variable length unfold model --- b3d/chisight/shared/particle_system.py | 56 ++++++++++----- b3d/modeling_utils.py | 55 ++++++++++++++- notebooks/integration.ipynb | 96 +++++++++++++++++++------- 3 files changed, 163 insertions(+), 44 deletions(-) diff --git a/b3d/chisight/shared/particle_system.py b/b3d/chisight/shared/particle_system.py index 6b7c1d13..ec63b387 100644 --- a/b3d/chisight/shared/particle_system.py +++ b/b3d/chisight/shared/particle_system.py @@ -8,7 +8,7 @@ from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs from b3d import Pose, Mesh from b3d.chisight.sparse.gps_utils import add_dummy_var -from b3d.chisight.sparse.pose_utils import uniform_pose_in_ball +from b3d.pose.pose_utils import uniform_pose_in_ball dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None)) @@ -113,7 +113,8 @@ def particle_system_state_step(carried_state, _): @gen def latent_particle_model( - num_timesteps, # const object + max_num_timesteps, # const object + num_timesteps, num_particles, # const object num_clusters, # const object relative_particle_poses_prior_params, @@ -132,32 +133,49 @@ def latent_particle_model( camera_pose_prior_params ) @ "state0" - final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+" + masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator( + particle_system_state_step, + n=(max_num_timesteps.const-1) + )( + state0, + genjax.Mask( + # This next line tells the scan combinator how many timesteps to run + jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1, + jnp.zeros(max_num_timesteps.const - 1) + ) + ) @ "states1+" + # concatenate each element of init_retval, scan_retvals - return jax.tree.map( + concatenated_states_possibly_invalid = jax.tree.map( lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0), - init_retval, scan_retvals + init_retval, masked_scan_retvals.value + ) + masked_concatenated_states = genjax.Mask( + jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]), + concatenated_states_possibly_invalid ) + return masked_concatenated_states @genjax.gen def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma): # TODO: add visibility uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const) - uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" + uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" return uv_ @genjax.gen def sparse_gps_model(latent_particle_model_args, obs_model_args): - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))( - particle_dynamics_summary["absolute_particle_poses"], - particle_dynamics_summary["camera_pose"], - particle_dynamics_summary["vis_mask"], + masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))( + masked_particle_dynamics_summary.flag, + _UNSAFE_particle_dynamics_summary["absolute_particle_poses"], + _UNSAFE_particle_dynamics_summary["camera_pose"], + _UNSAFE_particle_dynamics_summary["vis_mask"], *obs_model_args ) @ "observation" - return (particle_dynamics_summary, obs) + return (masked_particle_dynamics_summary, masked_obs) @@ -166,15 +184,17 @@ def make_dense_gps_model(renderer): @genjax.gen def dense_gps_model(latent_particle_model_args, dense_likelihood_args): - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1] - camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1] + masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + + last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1 + absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index] + camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index] absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame (meshes, likelihood_args) = dense_likelihood_args merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame) image = dense_observation_model(merged_mesh, likelihood_args) @ "observation" - return (particle_dynamics_summary, image) + return (masked_particle_dynamics_summary, image) return dense_gps_model \ No newline at end of file diff --git a/b3d/modeling_utils.py b/b3d/modeling_utils.py index 9df52310..ea57d9cc 100644 --- a/b3d/modeling_utils.py +++ b/b3d/modeling_utils.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp +from genjax import Mask uniform_discrete = genjax.exact_density( lambda key, vals: jax.random.choice(key, vals), @@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs): d = dist(*args, **kwargs) return jnp.sum(d.log_prob(v)) - return genjax.exact_density(sampler, logpdf) \ No newline at end of file + return genjax.exact_density(sampler, logpdf) + +normal = tfp_distribution(tfp.distributions.Normal) + +def masked_scan_combinator(step, **scan_kwargs): + """ + Given a generative function `step` so that `step.scan(n=N)` is valid, + return a generative function accepting an input + `(initial_state, masked_input_values_array)` and returning a pair + `(masked_final_state, masked_returnvalue_sequence)`. + This operates similarly to `step.scan`, but the input values can be masked. + """ + mstep = step.mask().dimap( + pre=lambda masked_state, masked_inval: ( + jnp.logical_and(masked_state.flag, masked_inval.flag), + masked_state.value, + masked_inval.value + ), + post=lambda args, masked_retval: ( + Mask(masked_retval.flag, masked_retval.value[0]), + Mask(masked_retval.flag, masked_retval.value[1]) + ) + ) + + # This should be given a pair ( + # Mask(True, initial_state), + # Mask(bools_indicating_active, input_vals) + # ). + # It wll output a pair (masked_final_state, masked_returnvalue_sequence). + scanned = mstep.scan(**scan_kwargs) + + scanned_nice = scanned.dimap( + pre=lambda initial_state, masked_input_values: ( + Mask(True, initial_state), + Mask(masked_input_values.flag, masked_input_values.value) + ), + post=lambda args, retval: retval + ) + + return scanned_nice + +def variable_length_unfold_combinator(step, **scan_kwargs): + """ + Step should accept one arg, `state`, as input, + and should return a pair `(new_state, retval_for_this_timestep)`. + """ + scanned = masked_scan_combinator(step, **scan_kwargs) + return scanned.dimap( + pre=lambda initial_state, n_steps: ( + initial_state, + Mask(jnp.array()) + ) + ) \ No newline at end of file diff --git a/notebooks/integration.ipynb b/notebooks/integration.ipynb index d9da1c07..a1944165 100644 --- a/notebooks/integration.ipynb +++ b/notebooks/integration.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -21,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -53,13 +44,52 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "StaticTrace(...)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "key = jax.random.PRNGKey(125)\n", "\n", - "num_timesteps = Pytree.const(4)\n", + "max_num_timesteps = Pytree.const(8)\n", + "num_timesteps = 4\n", "num_particles = Pytree.const(5)\n", "num_clusters = Pytree.const(3)\n", "relative_particle_poses_prior_params = (Pose.identity(), .5, 0.25)\n", @@ -70,7 +100,8 @@ "\n", "trace = ps.sparse_gps_model.simulate(key, (\n", " (\n", - " num_timesteps, # const object\n", + " max_num_timesteps, # const object\n", + " num_timesteps,\n", " num_particles, # const object\n", " num_clusters, # const object\n", " relative_particle_poses_prior_params,\n", @@ -78,18 +109,33 @@ " camera_pose_prior_params\n", " ),\n", " (instrinsics, sigma_obs)\n", - "))" + "))\n", + "trace" ] }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -101,7 +147,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -113,13 +159,13 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ - "XorChm(...)" + "StaticChm(...)" ] }, - "execution_count": 65, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } From ed3d5ca5a998942969df4e028cb181c108d9a27d Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 28 Jun 2024 19:56:56 +0000 Subject: [PATCH 2/3] variable length unfold model --- b3d/chisight/shared/particle_system.py | 61 ++++++++++++++-- b3d/modeling_utils.py | 55 ++++++++++++++- notebooks/integration.ipynb | 96 +++++++++++++++++++------- 3 files changed, 182 insertions(+), 30 deletions(-) diff --git a/b3d/chisight/shared/particle_system.py b/b3d/chisight/shared/particle_system.py index 7d55bead..4cef1cfe 100644 --- a/b3d/chisight/shared/particle_system.py +++ b/b3d/chisight/shared/particle_system.py @@ -8,7 +8,11 @@ from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs from b3d import Pose, Mesh from b3d.chisight.sparse.gps_utils import add_dummy_var +<<<<<<< HEAD from b3d.pose import uniform_pose_in_ball +======= +from b3d.pose.pose_utils import uniform_pose_in_ball +>>>>>>> 5a07ed3 (variable length unfold model) dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None)) @@ -113,7 +117,8 @@ def particle_system_state_step(carried_state, _): @gen def latent_particle_model( - num_timesteps, # const object + max_num_timesteps, # const object + num_timesteps, num_particles, # const object num_clusters, # const object relative_particle_poses_prior_params, @@ -132,23 +137,45 @@ def latent_particle_model( camera_pose_prior_params ) @ "state0" - final_state, scan_retvals = particle_system_state_step.scan(n=(num_timesteps.const - 1))(state0, None) @ "states1+" + masked_final_state, masked_scan_retvals = b3d.modeling_utils.masked_scan_combinator( + particle_system_state_step, + n=(max_num_timesteps.const-1) + )( + state0, + genjax.Mask( + # This next line tells the scan combinator how many timesteps to run + jnp.arange(max_num_timesteps.const - 1) < num_timesteps - 1, + jnp.zeros(max_num_timesteps.const - 1) + ) + ) @ "states1+" + # concatenate each element of init_retval, scan_retvals - return jax.tree.map( + concatenated_states_possibly_invalid = jax.tree.map( lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0), +<<<<<<< HEAD init_retval, scan_retvals ), final_state +======= + init_retval, masked_scan_retvals.value + ) + masked_concatenated_states = genjax.Mask( + jnp.concatenate([jnp.array([True]), masked_scan_retvals.flag]), + concatenated_states_possibly_invalid + ) + return masked_concatenated_states +>>>>>>> 5a07ed3 (variable length unfold model) @genjax.gen def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma): # TODO: add visibility uv = b3d.camera.screen_from_world(particle_absolute_poses.pos, camera_pose, instrinsics.const) - uv_ = genjax.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" + uv_ = b3d.modeling_utils.normal(uv, jnp.tile(sigma, uv.shape)) @ "sensor_coordinates" return uv_ @genjax.gen def sparse_gps_model(latent_particle_model_args, obs_model_args): +<<<<<<< HEAD # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))( @@ -158,6 +185,18 @@ def sparse_gps_model(latent_particle_model_args, obs_model_args): *obs_model_args ) @ "obs" return (particle_dynamics_summary, final_state, obs) +======= + masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))( + masked_particle_dynamics_summary.flag, + _UNSAFE_particle_dynamics_summary["absolute_particle_poses"], + _UNSAFE_particle_dynamics_summary["camera_pose"], + _UNSAFE_particle_dynamics_summary["vis_mask"], + *obs_model_args + ) @ "observation" + return (masked_particle_dynamics_summary, masked_obs) +>>>>>>> 5a07ed3 (variable length unfold model) @@ -166,19 +205,33 @@ def make_dense_gps_model(renderer): @genjax.gen def dense_gps_model(latent_particle_model_args, dense_likelihood_args): +<<<<<<< HEAD # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1] camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1] +======= + masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" + _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value + + last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1 + absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index] + camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index] +>>>>>>> 5a07ed3 (variable length unfold model) absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame (meshes, likelihood_args) = dense_likelihood_args merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame) +<<<<<<< HEAD image = dense_observation_model(merged_mesh, likelihood_args) @ "obs" return (particle_dynamics_summary, final_state, image) return dense_gps_model +======= + image = dense_observation_model(merged_mesh, likelihood_args) @ "observation" + return (masked_particle_dynamics_summary, image) +>>>>>>> 5a07ed3 (variable length unfold model) def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state): import rerun as rr diff --git a/b3d/modeling_utils.py b/b3d/modeling_utils.py index 9df52310..ea57d9cc 100644 --- a/b3d/modeling_utils.py +++ b/b3d/modeling_utils.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp +from genjax import Mask uniform_discrete = genjax.exact_density( lambda key, vals: jax.random.choice(key, vals), @@ -35,4 +36,56 @@ def logpdf(v, *args, **kwargs): d = dist(*args, **kwargs) return jnp.sum(d.log_prob(v)) - return genjax.exact_density(sampler, logpdf) \ No newline at end of file + return genjax.exact_density(sampler, logpdf) + +normal = tfp_distribution(tfp.distributions.Normal) + +def masked_scan_combinator(step, **scan_kwargs): + """ + Given a generative function `step` so that `step.scan(n=N)` is valid, + return a generative function accepting an input + `(initial_state, masked_input_values_array)` and returning a pair + `(masked_final_state, masked_returnvalue_sequence)`. + This operates similarly to `step.scan`, but the input values can be masked. + """ + mstep = step.mask().dimap( + pre=lambda masked_state, masked_inval: ( + jnp.logical_and(masked_state.flag, masked_inval.flag), + masked_state.value, + masked_inval.value + ), + post=lambda args, masked_retval: ( + Mask(masked_retval.flag, masked_retval.value[0]), + Mask(masked_retval.flag, masked_retval.value[1]) + ) + ) + + # This should be given a pair ( + # Mask(True, initial_state), + # Mask(bools_indicating_active, input_vals) + # ). + # It wll output a pair (masked_final_state, masked_returnvalue_sequence). + scanned = mstep.scan(**scan_kwargs) + + scanned_nice = scanned.dimap( + pre=lambda initial_state, masked_input_values: ( + Mask(True, initial_state), + Mask(masked_input_values.flag, masked_input_values.value) + ), + post=lambda args, retval: retval + ) + + return scanned_nice + +def variable_length_unfold_combinator(step, **scan_kwargs): + """ + Step should accept one arg, `state`, as input, + and should return a pair `(new_state, retval_for_this_timestep)`. + """ + scanned = masked_scan_combinator(step, **scan_kwargs) + return scanned.dimap( + pre=lambda initial_state, n_steps: ( + initial_state, + Mask(jnp.array()) + ) + ) \ No newline at end of file diff --git a/notebooks/integration.ipynb b/notebooks/integration.ipynb index d9da1c07..a1944165 100644 --- a/notebooks/integration.ipynb +++ b/notebooks/integration.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -21,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -53,13 +44,52 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "StaticTrace(...)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "key = jax.random.PRNGKey(125)\n", "\n", - "num_timesteps = Pytree.const(4)\n", + "max_num_timesteps = Pytree.const(8)\n", + "num_timesteps = 4\n", "num_particles = Pytree.const(5)\n", "num_clusters = Pytree.const(3)\n", "relative_particle_poses_prior_params = (Pose.identity(), .5, 0.25)\n", @@ -70,7 +100,8 @@ "\n", "trace = ps.sparse_gps_model.simulate(key, (\n", " (\n", - " num_timesteps, # const object\n", + " max_num_timesteps, # const object\n", + " num_timesteps,\n", " num_particles, # const object\n", " num_clusters, # const object\n", " relative_particle_poses_prior_params,\n", @@ -78,18 +109,33 @@ " camera_pose_prior_params\n", " ),\n", " (instrinsics, sigma_obs)\n", - "))" + "))\n", + "trace" ] }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -101,7 +147,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -113,13 +159,13 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ - "XorChm(...)" + "StaticChm(...)" ] }, - "execution_count": 65, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } From 81a1ac287b6ca49dd34f9bc1097d52997247d82d Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 1 Jul 2024 14:09:56 +0000 Subject: [PATCH 3/3] fix unfinished merge --- b3d/chisight/shared/particle_system.py | 36 -------------------------- 1 file changed, 36 deletions(-) diff --git a/b3d/chisight/shared/particle_system.py b/b3d/chisight/shared/particle_system.py index 4cef1cfe..30035fb9 100644 --- a/b3d/chisight/shared/particle_system.py +++ b/b3d/chisight/shared/particle_system.py @@ -8,11 +8,7 @@ from b3d.chisight.dense.dense_likelihood import make_dense_observation_model, DenseImageLikelihoodArgs from b3d import Pose, Mesh from b3d.chisight.sparse.gps_utils import add_dummy_var -<<<<<<< HEAD from b3d.pose import uniform_pose_in_ball -======= -from b3d.pose.pose_utils import uniform_pose_in_ball ->>>>>>> 5a07ed3 (variable length unfold model) dummy_mapped_uniform_pose = add_dummy_var(uniform_pose_in_ball).vmap(in_axes=(0,None,None,None)) @@ -153,10 +149,6 @@ def latent_particle_model( # concatenate each element of init_retval, scan_retvals concatenated_states_possibly_invalid = jax.tree.map( lambda t1, t2: jnp.concatenate([t1[None, :], t2], axis=0), -<<<<<<< HEAD - init_retval, scan_retvals - ), final_state -======= init_retval, masked_scan_retvals.value ) masked_concatenated_states = genjax.Mask( @@ -164,7 +156,6 @@ def latent_particle_model( concatenated_states_possibly_invalid ) return masked_concatenated_states ->>>>>>> 5a07ed3 (variable length unfold model) @genjax.gen def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, instrinsics, sigma): @@ -175,17 +166,6 @@ def sparse_observation_model(particle_absolute_poses, camera_pose, visibility, i @genjax.gen def sparse_gps_model(latent_particle_model_args, obs_model_args): -<<<<<<< HEAD - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - obs = sparse_observation_model.vmap(in_axes=(0, 0, 0, None, None))( - particle_dynamics_summary["absolute_particle_poses"], - particle_dynamics_summary["camera_pose"], - particle_dynamics_summary["vis_mask"], - *obs_model_args - ) @ "obs" - return (particle_dynamics_summary, final_state, obs) -======= masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value masked_obs = sparse_observation_model.mask().vmap(in_axes=(0, 0, 0, 0, None, None))( @@ -196,7 +176,6 @@ def sparse_gps_model(latent_particle_model_args, obs_model_args): *obs_model_args ) @ "observation" return (masked_particle_dynamics_summary, masked_obs) ->>>>>>> 5a07ed3 (variable length unfold model) @@ -205,33 +184,18 @@ def make_dense_gps_model(renderer): @genjax.gen def dense_gps_model(latent_particle_model_args, dense_likelihood_args): -<<<<<<< HEAD - # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1) - particle_dynamics_summary, final_state = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" - absolute_particle_poses_last_frame = particle_dynamics_summary["absolute_particle_poses"][-1] - camera_pose_last_frame = particle_dynamics_summary["camera_pose"][-1] -======= masked_particle_dynamics_summary = latent_particle_model(*latent_particle_model_args) @ "particle_dynamics" _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary.value last_timestep_index = jnp.sum(masked_particle_dynamics_summary.flag) - 1 absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary["absolute_particle_poses"][last_timestep_index] camera_pose_last_frame = _UNSAFE_particle_dynamics_summary["camera_pose"][last_timestep_index] ->>>>>>> 5a07ed3 (variable length unfold model) absolute_particle_poses_in_camera_frame = camera_pose_last_frame.inv() @ absolute_particle_poses_last_frame (meshes, likelihood_args) = dense_likelihood_args merged_mesh = Mesh.transform_and_merge_meshes(meshes, absolute_particle_poses_in_camera_frame) -<<<<<<< HEAD - image = dense_observation_model(merged_mesh, likelihood_args) @ "obs" - return (particle_dynamics_summary, final_state, image) - - return dense_gps_model - -======= image = dense_observation_model(merged_mesh, likelihood_args) @ "observation" return (masked_particle_dynamics_summary, image) ->>>>>>> 5a07ed3 (variable length unfold model) def visualize_particle_system(latent_particle_model_args, particle_dynamics_summary, final_state): import rerun as rr