Some scripts still use JAX, so they no longer work with modules that have been converted to PyTorch (distutil.py). For example:
Line 30:
    gen_int = DeterministicPMF(
        jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
    )
fails with:
pyrenew/distutil.py:50 > if not torch.all(discrete_dist >= 0):
TypeError: all() received an invalid combination of arguments - got (jaxlib.xla_extension.ArrayImpl), but expected ... [bunch of tensor options]
Other calls fail the same way wherever a JAX array is passed to a torch function that expects a tensor.
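The direct fix is to update the calling scripts to build torch tensors (e.g. `torch.tensor([...])` instead of `jnp.array([...])`). As a defensive alternative, the check in distutil.py could coerce array-like input before calling torch functions. A minimal sketch, assuming a hypothetical helper named `validate_discrete_dist` (not the actual pyrenew code) and relying on the fact that JAX arrays support conversion through `np.asarray`:

```python
import numpy as np
import torch


def validate_discrete_dist(discrete_dist) -> torch.Tensor:
    """Hypothetical helper: coerce an array-like PMF to a torch tensor, then validate it.

    Accepts torch tensors, NumPy arrays, Python lists, or any object that
    NumPy can convert (JAX arrays included).
    """
    if not isinstance(discrete_dist, torch.Tensor):
        # np.asarray copies a JAX (or other array-like) value to a host NumPy array,
        # which torch.as_tensor can then wrap as a tensor.
        discrete_dist = torch.as_tensor(np.asarray(discrete_dist))
    # Same non-negativity check as distutil.py line 50, now safe for non-tensor input.
    if not torch.all(discrete_dist >= 0):
        raise ValueError("discrete_dist must have non-negative entries")
    return discrete_dist
```

Coercing at the boundary hides the mismatch rather than removing it, so converting the remaining scripts to torch is still the cleaner long-term fix.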