feat: Support MPS during training and inference #3100
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3100      +/-   ##
==========================================
- Coverage   87.67%   83.20%    -4.48%
==========================================
  Files         180      180
  Lines       15187    15227       +40
==========================================
- Hits        13315    12669      -646
- Misses       1872     2558      +686
Please enable the macOS runner for changes related to MPS.
The difference was much larger in my hands when increasing the batch size (a larger batch size creates larger matrix multiplications, where MPS is more efficient than the CPU) and enabling compilation. In that setting it was actually faster than an A100 on Google Colab. We should optimize this a bit.
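One minimal way to see this effect is a matmul micro-benchmark. This is an illustrative sketch, not code from the PR: `bench_matmul` is a hypothetical helper, and it assumes PyTorch is installed.

```python
import time
import torch

def bench_matmul(device: torch.device, n: int = 256, repeats: int = 3) -> float:
    """Time `repeats` n x n matrix multiplications on `device`."""
    x = torch.randn(n, n, device=device)
    start = time.perf_counter()
    for _ in range(repeats):
        x @ x
    # GPU-style backends execute asynchronously, so wait for the queue to drain
    # before stopping the clock.
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

# Compare CPU against MPS when available; larger n widens the gap in MPS's favor.
elapsed_cpu = bench_matmul(torch.device("cpu"))
if torch.backends.mps.is_available():
    elapsed_mps = bench_matmul(torch.device("mps"))
```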
np.asarray(1.0 - (data > 0).sum(0) / data.shape[0]).ravel()
).to(device)
# On MPS we first need to cast to float32, as the MPS backend doesn't support float64.
if device.type == "mps":
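The dtype handling in the diff above can be sketched as a standalone helper. This is a hedged illustration, assuming PyTorch and NumPy; `to_device_tensor` is a hypothetical name, not the PR's actual function.

```python
import numpy as np
import torch

def to_device_tensor(arr, device: torch.device) -> torch.Tensor:
    """Move an array-like to `device`, downcasting float64 where MPS requires it."""
    t = torch.as_tensor(np.asarray(arr))
    # The MPS backend has no float64 support, so cast doubles to float32 first.
    if device.type == "mps" and t.dtype == torch.float64:
        t = t.to(torch.float32)
    return t.to(device)
```

On CPU and CUDA the tensor keeps its original dtype; only MPS triggers the downcast.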
We could also do this for other devices; float32 is sufficient for all computations.
for more information, see https://pre-commit.ci
Add support for Mac GPUs (M1/M2/M3) for scvi and make MPS the default option when running on a Mac (if available).
Testing was done manually on my Mac (M3), plus verification that everything still works on CPU/CUDA here.
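The "default when available" behavior described above can be sketched roughly as follows. This is an assumption-laden illustration (hypothetical helper name, assuming PyTorch); the PR's actual selection logic may differ.

```python
import torch

def default_device() -> torch.device:
    # Prefer Apple's MPS backend on Macs, then CUDA, and fall back to CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
```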
References:
pytorch/pytorch#132605
pytorch/pytorch#77764
https://discourse.scverse.org/t/macbook-m1-m2-mps-acceleration-with-scvi/2075/7