Skip to content

Commit

Permalink
minor fix in ClampModule to deal with low, high when given as floats
Browse files Browse the repository at this point in the history
  • Loading branch information
mayalenE committed Feb 22, 2023
1 parent 277fb0a commit 26978a3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions autodiscjax/modules/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ def __init__(self, out_treedef, out_shape, out_dtype, low=None, high=None):
super().__init__(out_treedef, out_shape, out_dtype)

if isinstance(low, float):
self.low = self.out_treedef.unflatten([low]*self.out_treedef.num_leaves)
self.low = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
low, self.out_shape, self.out_dtype)
else:
self.low = low

if isinstance(high, float):
self.high = self.out_treedef.unflatten([high]*self.out_treedef.num_leaves)
self.high = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
high, self.out_shape, self.out_dtype)
else:
self.high = high

Expand Down

0 comments on commit 26978a3

Please sign in to comment.