diff --git a/edward2/jax/nn/random_feature.py b/edward2/jax/nn/random_feature.py index acc717f6..5535c20d 100644 --- a/edward2/jax/nn/random_feature.py +++ b/edward2/jax/nn/random_feature.py @@ -207,7 +207,11 @@ def __call__(self, inputs: Array) -> Array: # Performs forward pass. inputs = jnp.asarray(inputs, self.dtype) - outputs = lax.dot_general(inputs, kernel.value, + # Cast the kernel to correct dtype in case the parameter is saved and + # restored with a different dtype. + # TODO(b/235921783): Avoid casting dtype here. + kernel_value = jnp.asarray(kernel.value, self.dtype) + outputs = lax.dot_general(inputs, kernel_value, (contracting_dims, batch_dims)) outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)