I tried to use this package with version 0.7.2, but I get an error with the following code:
import jax
import jax.numpy as jnp
from aqt.jax.v2 import config

# Quantized dot_general with 8-bit lhs and 8-bit rhs
dot_general = config.dot_general_make(8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))
print(jnp.einsum('ij,jk->ik', x, y))  # works
print(jnp.einsum('ij,jk->ik', x, y, _dot_general=dot_general))  # raises the ValueError below
ValueError: Non-hashable static arguments are not supported. An error occurred during a call to '_einsum' while trying to hash an object of type <class 'aqt.jax.v2.aqt_dot_general.DotGeneral'>, DotGeneral(fwd=DotGeneralRaw(lhs=Tensor(use_fwd_q
How am I supposed to use it?
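A likely explanation, assuming current JAX internals: `jnp.einsum` forwards `_dot_general` to an internal jitted helper (`_einsum` in the traceback) as a static argument, and JAX requires static arguments to be hashable. A dataclass declared without `frozen=True` gets `__hash__ = None`; that AQT's `DotGeneral` is defined this way is an assumption based on its dataclass-style repr in the error. The sketch below uses plain JAX only. `ToyDotGeneral` is a hypothetical stand-in, not part of AQT, and wrapping it in `functools.partial` is one possible workaround, not an official AQT API.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp
from jax import lax


@dataclasses.dataclass  # eq=True and frozen=False, so __hash__ is set to None
class ToyDotGeneral:
    """Hypothetical stand-in for a non-hashable dot_general config object."""

    scale: float = 1.0

    def __call__(self, lhs, rhs, dimension_numbers, precision=None,
                 preferred_element_type=None, **unused_kwargs):
        # Extra keyword arguments that some JAX versions may pass are ignored.
        out = lax.dot_general(lhs, rhs, dimension_numbers, precision=precision,
                              preferred_element_type=preferred_element_type)
        return out * self.scale  # toy modification so the effect is visible


x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))
toy = ToyDotGeneral(scale=2.0)

try:
    jnp.einsum('ij,jk->ik', x, y, _dot_general=toy)
except (ValueError, TypeError) as err:
    # Expected to match the reported failure: the static argument cannot be hashed.
    print("unwrapped object rejected:", type(err).__name__)

# functools.partial objects hash by identity, so jit accepts them as static
# arguments. Reuse the same wrapper object; a fresh one causes a retrace.
hashable_toy = functools.partial(toy)
out = jnp.einsum('ij,jk->ik', x, y, _dot_general=hashable_toy)
print(out)
```

Whether the same wrapping works for AQT's `DotGeneral` depends on its `__call__` accepting the arguments `jnp.einsum` passes in your JAX version; that is untested here. Because `functools.partial` compares by identity, create the wrapper once and reuse it rather than building a new one per call.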