
Small Einsum is hanging #24929

Open
ryan112358 opened this issue Nov 16, 2024 · 5 comments · May be fixed by #25214
Labels
bug Something isn't working

Comments

@ryan112358

Description

Here's a small case where np.einsum works but jnp.einsum hangs:

import numpy as np
import jax.numpy as jnp

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'

arrays = [np.random.rand(*(2,)*len(key)) for key in formula.split('->')[0].split(',')]

np.einsum(formula, *arrays)
# Example output (the inputs are random, so the values differ between runs):
# array([[[6.26532636e-05, 9.94054312e-04],
#         [3.24902199e-05, 2.90052489e-03]],
#
#        [[1.21862902e-05, 9.85561040e-05],
#         [2.81959491e-06, 1.77314102e-04]]])


jnp.einsum(formula, *arrays)  # this hangs and does not complete

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='849fd340451c', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

ryan112358 added the bug label Nov 16, 2024
@jakevdp
Collaborator

jakevdp commented Nov 16, 2024

Thanks for the report! It looks like a bug in opt_einsum, which jnp.einsum uses. Here's a more direct reproduction of the issue:

import opt_einsum
# formula and arrays as defined in the snippet above
opt_einsum.contract_path(
    formula, *arrays, einsum_call=True, use_blas=True, optimize='optimal')

It would be worth reporting upstream I think – would you like to report the issue there, or would you like us to take over?

@ryan112358
Author

Ah, I see, interesting. In that case I can get unblocked immediately by changing the optimize kwarg for now. I went ahead and reported it upstream: dgasmith/opt_einsum#243
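A minimal sketch of that workaround, assuming the jax 0.4.33 behavior where jnp.einsum forwards its `optimize` argument to opt_einsum's path search (with `'optimal'` as the default at the time of this report). Picking `'greedy'` is our assumption: any non-exhaustive strategy should avoid the hang. The seeded `np.random.default_rng(0)` replaces the unseeded `np.random.rand` from the report so the check is reproducible.

```python
import numpy as np
import jax.numpy as jnp

# Same formula as in the report; seeded generator for reproducible inputs.
formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
rng = np.random.default_rng(0)
arrays = [rng.random((2,) * len(key)) for key in formula.split('->')[0].split(',')]

expected = np.einsum(formula, *arrays)

# Workaround: request the 'greedy' heuristic instead of the exhaustive
# 'optimal' search, so the contraction-path search finishes quickly.
result = jnp.einsum(formula, *arrays, optimize='greedy')

# JAX computes in float32 by default, so compare with a relative tolerance.
print(np.allclose(result, expected, rtol=1e-4))
```

The path strategy only changes the order of pairwise contractions, not the mathematical result, so the greedy answer should match NumPy up to floating-point rounding.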

@ryan112358
Author

According to dgasmith/opt_einsum#243, setting the path strategy to 'auto' (optimize='auto') might be a preferable default. As far as I understand, 'auto' uses the 'optimal' strategy when the number of operands is small and switches to a cheaper strategy when an exhaustive search would not finish in a reasonable amount of time.
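A sketch of what `optimize='auto'` does at the opt_einsum level, assuming opt_einsum is installed (jax 0.4.33 depends on it). `contract_path` returns the chosen contraction path plus a `PathInfo` summary; the path is then fed back to NumPy's `einsum` through its `['einsum_path', ...]` form as a sanity check. The exact path printed may vary between opt_einsum versions.

```python
import numpy as np
import opt_einsum

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
rng = np.random.default_rng(0)
arrays = [rng.random((2,) * len(key)) for key in formula.split('->')[0].split(',')]

# 'auto' lets opt_einsum choose a search strategy based on the number of
# operands; with 10 operands it should avoid the exhaustive 'optimal' search.
path, info = opt_einsum.contract_path(formula, *arrays, optimize='auto')
print(path)
print(info)

# Reuse the computed path with NumPy and compare against a plain einsum.
via_path = np.einsum(formula, *arrays, optimize=['einsum_path', *path])
print(np.allclose(via_path, np.einsum(formula, *arrays)))
```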

@jakevdp
Collaborator

jakevdp commented Dec 2, 2024

#25055 changes to optimize='auto' for multi_dot; perhaps we should do the same for einsum.
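Until the default changes, callers can pin the strategy themselves. A minimal sketch, assuming `jnp.einsum` accepts `optimize='auto'` and passes it through to opt_einsum; `einsum_auto` is a hypothetical helper name, not part of JAX.

```python
import functools

import numpy as np
import jax.numpy as jnp

# Hypothetical helper: jnp.einsum with the 'auto' path strategy baked in.
einsum_auto = functools.partial(jnp.einsum, optimize='auto')

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
rng = np.random.default_rng(0)
arrays = [rng.random((2,) * len(key)) for key in formula.split('->')[0].split(',')]

result = einsum_auto(formula, *arrays)
print(result.shape)  # (2, 2, 2): one axis each for d, b, c
```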

@jakevdp
Collaborator

jakevdp commented Dec 2, 2024

#25214

jakevdp linked a pull request Dec 2, 2024 that will close this issue