Small Einsum is hanging #24929
Comments
Thanks for the report! It looks like a bug in opt_einsum:

import opt_einsum
opt_einsum.contract_path(
    formula, *arrays, einsum_call=True, use_blas=True, optimize='optimal')

It would be worth reporting upstream, I think. Would you like to report the issue there, or would you like us to take over?
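One rough way to see why a path search with optimize='optimal' can stall even on a "small" einsum: with n operands, any 2 of the k remaining operands can be contracted at each step, so the number of ordered pairwise-contraction sequences is C(n,2)·C(n-1,2)·…·C(2,2) = n!(n-1)!/2^(n-1). The plain-Python sketch below only counts that search space. It is not opt_einsum's algorithm, which prunes by cost, so real run times will differ.

```python
from math import comb, factorial


def num_pairwise_orders(n: int) -> int:
    """Count ordered sequences of pairwise contractions reducing n operands to one.

    With k operands left there are C(k, 2) pairs to contract next; multiplying
    over k = n, n-1, ..., 2 counts every full sequence.
    """
    total = 1
    for k in range(n, 1, -1):
        total *= comb(k, 2)
    return total


def closed_form(n: int) -> int:
    # The same count in closed form: n! * (n-1)! / 2**(n-1)
    return factorial(n) * factorial(n - 1) // 2 ** (n - 1)


for n in (3, 5, 8, 10):
    print(n, num_pairwise_orders(n), closed_form(n))
```

Going from 5 operands to 10 grows the count from 180 to over 2.5 billion, which is why exhaustive search is only practical for a handful of operands.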
Ah, interesting. In that case I can unblock myself immediately by just changing the optimize kwarg for now. I went ahead and reported dgasmith/opt_einsum#243.
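The workaround amounts to passing optimize explicitly to jnp.einsum so it does not use the default 'optimal' path search. A minimal sketch, assuming hypothetical operands (the issue's actual formula and arrays are not shown here) and 'greedy' as the substitute strategy:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical operands; the issue's real formula and arrays are not reproduced.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((4, 5))
c = rng.standard_normal((5, 2))

# Passing optimize overrides jnp.einsum's default path strategy; 'greedy'
# picks a contraction order with a cheap heuristic instead of exhaustive search.
out = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")

# NumPy reference result. JAX computes in float32 by default, hence the tolerance.
ref = np.einsum("ij,jk,kl->il", a, b, c)
print(out.shape, np.allclose(out, ref, rtol=1e-4, atol=1e-4))
```

The path strategy only changes the order of pairwise contractions, not the mathematical result, so any valid strategy should agree with NumPy up to floating-point rounding.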
According to dgasmith/opt_einsum#243, optimize='auto' might be a preferable default. As far as I understand, 'auto' uses 'optimal' when the number of operands is small and switches to a cheaper strategy when an exhaustive search would not finish in a reasonable amount of time.
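The idea behind an 'auto'-style choice can be sketched as a cutoff on the operand count. This is a hypothetical illustration, not opt_einsum's code: the real 'auto' mode uses its own per-operand-count table (which may include intermediate strategies), and the cutoff of 4 and the 'greedy' fallback below are assumptions.

```python
def count_operands(subscripts: str) -> int:
    # "ij,jk,kl->il" has three comma-separated input terms.
    inputs = subscripts.split("->")[0]
    return len(inputs.split(","))


def choose_strategy(num_operands: int, optimal_limit: int = 4) -> str:
    """Hypothetical 'auto'-style selector; the threshold is illustrative only.

    Exhaustive search is affordable for a handful of operands, so use it below
    the cutoff and fall back to a cheap heuristic above it.
    """
    if num_operands <= optimal_limit:
        return "optimal"
    return "greedy"


print(choose_strategy(count_operands("ij,jk,kl->il")))
print(choose_strategy(count_operands("ab,bc,cd,de,ef,fg->ag")))
```

The returned string could then be passed as the optimize argument; the point of the design is that search cost stays bounded no matter how many operands the formula has.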
#25055 changes to
Description
Here's a small case where np.einsum works but jnp.einsum hangs.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='849fd340451c', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')