Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuebert authored and sritchie committed Feb 22, 2024
1 parent 82474b6 commit 37a0dc5
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 624 deletions.
13 changes: 6 additions & 7 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from genjax.incremental import Diff, NoChange, UnknownChange

import bayes3d as b
import bayes3d.scene_graph

from .genjax_distributions import (
contact_params_uniform,
Expand Down Expand Up @@ -128,14 +127,14 @@ def get_far_plane(trace):


def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = get_indices(trace).shape[0] + 1
N = b.get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]


add_object_jit = jax.jit(add_object)
Expand All @@ -152,7 +151,7 @@ def print_trace(trace):


def viz_trace_meshcat(trace, colors=None):
b.clear_visualizer()
b.clear()
b.show_cloud(
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
)
Expand Down Expand Up @@ -224,14 +223,14 @@ def enumerator(trace, key, *args):
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[0]
)[2]

def enumerator_with_weight(trace, key, *args):
return trace.update(
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[0:2]
)[1:3]

def enumerator_score(trace, key, *args):
return enumerator(trace, key, *args).get_score()
Expand Down Expand Up @@ -302,4 +301,4 @@ def update_address(trace, key, address, value):
key,
genjax.choice_map({address: value}),
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
)[0]
)[2]
58 changes: 0 additions & 58 deletions bayes3d/viser.py

This file was deleted.

373 changes: 0 additions & 373 deletions demo_c2f.ipynb

This file was deleted.

Loading

0 comments on commit 37a0dc5

Please sign in to comment.