Skip to content

Commit

Permalink
Switch vmap back to jax.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 31, 2023
1 parent 5038e6f commit cb5b2e3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions harmonic/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
Returns:
jnp.ndarray (batch_size,): Predicted log_e posterior value.
"""
get_logprob = nn.vmap(self.__call__, in_axes=[0, None])
get_logprob = jax.vmap(self.__call__, in_axes=[0, None])
logprob = get_logprob(x, temperature)

return logprob
Expand Down Expand Up @@ -222,7 +222,7 @@ def setup(self):
self.conditioner = conditioner
self.scalar = scalar

self.vmap_call = nn.vmap(self.__call__)
self.vmap_call = jax.vmap(self.__call__)

def bijector_fn(params: jnp.ndarray):
return distrax.RationalQuadraticSpline(
Expand Down Expand Up @@ -326,7 +326,7 @@ def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
jnp.ndarray (batch_size,): Predicted log_e posterior value.
"""

get_logprob = nn.vmap(self.__call__, in_axes=[0, None])
get_logprob = jax.vmap(self.__call__, in_axes=[0, None])
logprob = get_logprob(x, temperature)

return logprob
Expand Down

0 comments on commit cb5b2e3

Please sign in to comment.