Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 449584663
  • Loading branch information
fehiepsi authored and edward-bot committed Jun 14, 2022
1 parent ff5e774 commit 1d79315
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion edward2/jax/nn/random_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ def __call__(self, inputs: Array) -> Array:

# Performs forward pass.
inputs = jnp.asarray(inputs, self.dtype)
outputs = lax.dot_general(inputs, kernel.value,
# Cast the kernel to correct dtype in case the parameter is saved and
# restored with a different dtype.
# TODO(b/235921783): Avoid casting dtype here.
kernel_value = jnp.asarray(kernel.value, self.dtype)
outputs = lax.dot_general(inputs, kernel_value,
(contracting_dims, batch_dims))
outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)

Expand Down

0 comments on commit 1d79315

Please sign in to comment.