[PyTorch] Training is very slow on Linux. #1504

Closed
haifengl opened this issue May 20, 2024 · 9 comments · Fixed by #1510

Comments

@haifengl

Training 10 epochs of MNIST (the sample code from your project README) takes > 500 seconds on Linux (24 cores, Ubuntu 22.04), but only about 50 seconds on an old Mac (4 cores). Both runs use the CPU (no GPU or MPS).

@saudet
Member

saudet commented May 20, 2024

Try reducing the number of threads used by PyTorch to 6 or 12; see https://stackoverflow.com/questions/76084214/what-is-recommended-number-of-threads-for-pytorch-based-on-available-cpu-cores
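
For reference, a minimal sketch of capping the intra-op thread count from Java, assuming the JavaCPP PyTorch presets are on the classpath and expose set_num_threads/get_num_threads in org.bytedeco.pytorch.global.torch (as referenced later in this thread); the value 12 is just an example:

```java
import org.bytedeco.pytorch.global.torch;

public class ThreadConfig {
    public static void main(String[] args) {
        // Report the default chosen by libtorch/OpenMP on this machine.
        System.out.println("Default intra-op threads: " + torch.get_num_threads());

        // Cap the intra-op thread pool (12 is an example; the physical core
        // count is usually a better target than the hyper-threaded vCPU count).
        torch.set_num_threads(12);
        System.out.println("Now using: " + torch.get_num_threads());
    }
}
```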

@HGuillemet
Collaborator

It's most probably related to PyTorch not finding OpenBLAS and/or MKL on your path.
Have you added mkl-platform-redist to your dependencies?
You can also try downloading and using the official libtorch: add the directory containing its libraries to your library path and set -Dorg.bytedeco.javacpp.pathsFirst=true. The official binaries are statically linked with MKL.
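
A hedged sketch of that second suggestion, assuming the official libtorch has been unpacked under /opt/libtorch (a hypothetical path); the JVM flags in the comment are the usual way to pass these properties, and Loader.load is only there to trigger library resolution so the effect can be verified:

```java
// Launch with something like (paths are illustrative):
//   java -Djava.library.path=/opt/libtorch/lib \
//        -Dorg.bytedeco.javacpp.pathsFirst=true \
//        -cp <your classpath> CheckTorch
// Or, when staying with the bundled build, add the
// org.bytedeco:mkl-platform-redist artifact to the dependencies instead.

import org.bytedeco.javacpp.Loader;
import org.bytedeco.pytorch.global.torch;

public class CheckTorch {
    public static void main(String[] args) {
        Loader.load(torch.class);  // resolves and loads the native libraries
        // If the official, MKL-linked build is picked up, the default thread
        // count should match the physical cores rather than the vCPU count.
        System.out.println("Intra-op threads: " + torch.get_num_threads());
    }
}
```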

@haifengl
Author

haifengl commented May 22, 2024

Setting OMP_NUM_THREADS=12 on Linux helps a lot; the training speed is now on par with the Mac (4 threads). Without it, torch.get_num_threads() returns 48, so the slowness may be caused by hyper-threading. According to your link, PyTorch sets the number of threads to half the number of vCores. If so, we shouldn't have this issue on Linux, yet that is not the case with the JavaCPP build. Are we missing some build configuration for Linux? Thanks!

@saudet
Member

saudet commented May 22, 2024

So the default is 24 on that machine, but that doesn't mean it's going to give good results.

@haifengl
Author

haifengl commented May 23, 2024

The default is 48 with the JavaCPP build, which is too high. It should be 24 in this case.

@HGuillemet
Collaborator

Have you tried with the official libtorch?

@haifengl
Author

haifengl commented May 23, 2024

The official libtorch sets it to 24 by default on my box, and it works well. Why does JavaCPP build libtorch from source? Why not package the precompiled libtorch library from pytorch.org?

@HGuillemet
Collaborator

See discussion here

@HGuillemet
Collaborator

HGuillemet commented Jun 22, 2024

Here is the result of running the sample MNIST code on a machine with 32 vcores and 16 physical cores:

| OpenMP lib | Default num threads | Speed |
| --- | --- | --- |
| omp | 32 | Very slow |
| gomp | 32 | Somewhat slow |
| mkl static (official build) | 16 | Fast |

When forcing the number of threads to 16 using OMP_NUM_THREADS or torch.set_num_threads, it's fast in all cases.
I'll try to rationalize that in the PR so that torch is linked with gomp on Linux.
Also, the fact that the presets preload every possible OpenMP library they find, possibly leading to several different libraries being loaded at once, surely doesn't help.
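
A small sketch of the workaround discussed above, capping PyTorch to the physical core count; halving availableProcessors() is only a heuristic (it assumes 2-way hyper-threading), and setting OMP_NUM_THREADS in the environment before JVM startup achieves the same effect:

```java
import org.bytedeco.pytorch.global.torch;

public class CapToPhysicalCores {
    public static void main(String[] args) {
        // availableProcessors() reports logical cores (vCPUs); assume
        // 2-way hyper-threading to approximate the physical core count.
        int logical = Runtime.getRuntime().availableProcessors();
        int physical = Math.max(1, logical / 2);

        torch.set_num_threads(physical);
        System.out.println("Using " + torch.get_num_threads() + " intra-op threads");
    }
}
```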
