Skip to content

Commit

Permalink
update example folder
Browse files Browse the repository at this point in the history
  • Loading branch information
mayalenE committed Feb 22, 2023
1 parent 26978a3 commit 0c83d06
Show file tree
Hide file tree
Showing 8 changed files with 3,231 additions and 3,275 deletions.
29 changes: 7 additions & 22 deletions autodiscjax/utils/create_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,10 @@ def create_intervention_module(intervention_config):
if intervention_config.intervention_type == "set_uniform":
intervention_fn = grn.PiecewiseSetConstantIntervention(
time_to_interval_fn=grn.TimeToInterval(intervals=intervention_config.controlled_intervals))
intervention_params_tree = DictTree()
for y_idx in intervention_config.controlled_node_ids:
intervention_params_tree.y[y_idx] = "placeholder"
intervention_params_treedef = jtu.tree_structure(intervention_params_tree)
intervention_params_shape = jtu.tree_map(lambda _: (len(intervention_config.controlled_intervals),), intervention_params_tree)
intervention_params_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)

intervention_low = DictTree(intervention_config.low)
intervention_low = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
intervention_low, intervention_params_shape,
intervention_params_dtype)
intervention_high = DictTree(intervention_config.high)
intervention_high = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
intervention_high, intervention_params_shape,
intervention_params_dtype)
random_intervention_generator = imgep.UniformRandomGenerator(intervention_params_treedef,
intervention_params_shape,
intervention_params_dtype,
intervention_low, intervention_high)
random_intervention_generator = imgep.UniformRandomGenerator(intervention_config.out_treedef,
intervention_config.out_shape,
intervention_config.out_dtype,
intervention_config.low, intervention_config.high)
else:
raise ValueError
return random_intervention_generator, intervention_fn
Expand Down Expand Up @@ -217,8 +202,8 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
gc_intervention_optimizer_config.high,
gc_intervention_optimizer_config.n_optim_steps,
gc_intervention_optimizer_config.n_workers,
init_noise_std=gc_intervention_optimizer_config.init_noise_std,
lr=gc_intervention_optimizer_config.lr,
gc_intervention_optimizer_config.init_noise_std,
gc_intervention_optimizer_config.lr,
)


Expand All @@ -230,7 +215,7 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
gc_intervention_optimizer_config.high,
gc_intervention_optimizer_config.n_optim_steps,
gc_intervention_optimizer_config.n_workers,
init_noise_std=gc_intervention_optimizer_config.init_noise_std
gc_intervention_optimizer_config.init_noise_std
)

else:
Expand Down
124 changes: 109 additions & 15 deletions examples/analyze_imgep_evaluation.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 0c83d06

Please sign in to comment.