From cb5b2e34df11964457c7851870a9e4402776bfb8 Mon Sep 17 00:00:00 2001 From: alicjapolanska Date: Tue, 31 Oct 2023 17:37:35 +0000 Subject: [PATCH] Switch vmap back to jax. --- harmonic/flows.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/harmonic/flows.py b/harmonic/flows.py index 401a001f..a910b940 100644 --- a/harmonic/flows.py +++ b/harmonic/flows.py @@ -150,7 +150,7 @@ def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array: Returns: jnp.ndarray (batch_size,): Predicted log_e posterior value. """ - get_logprob = nn.vmap(self.__call__, in_axes=[0, None]) + get_logprob = jax.vmap(self.__call__, in_axes=[0, None]) logprob = get_logprob(x, temperature) return logprob @@ -222,7 +222,7 @@ def setup(self): self.conditioner = conditioner self.scalar = scalar - self.vmap_call = nn.vmap(self.__call__) + self.vmap_call = jax.vmap(self.__call__) def bijector_fn(params: jnp.ndarray): return distrax.RationalQuadraticSpline( @@ -326,7 +326,7 @@ def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array: jnp.ndarray (batch_size,): Predicted log_e posterior value. """ - get_logprob = nn.vmap(self.__call__, in_axes=[0, None]) + get_logprob = jax.vmap(self.__call__, in_axes=[0, None]) logprob = get_logprob(x, temperature) return logprob