feat: Support MPS during training and inference #3100

Open · wants to merge 24 commits into main
Conversation

@ori-kron-wis (Collaborator) commented on Dec 17, 2024:

Add support for Mac GPUs (M1/M2/M3) in scvi and restore MPS as the default option when running on a Mac (if available).

Testing was done manually on my Mac (M3), plus verification here that everything still works on CPU/CUDA.

References:
pytorch/pytorch#132605
pytorch/pytorch#77764
https://discourse.scverse.org/t/macbook-m1-m2-mps-acceleration-with-scvi/2075/7
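
For context (this sketch is not code from the PR): the auto-detection behavior described above can be expressed with PyTorch's public MPS availability check; the `pick_device` helper name is hypothetical.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # True on Apple-silicon Macs running an MPS-enabled PyTorch build.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
```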

@ori-kron-wis added the labels cuda tests (Run test suite on CUDA) and on-merge: backport to 1.2.x on Dec 17, 2024
@ori-kron-wis added this to the scvi-tools 1.2 milestone on Dec 17, 2024
@ori-kron-wis self-assigned this on Dec 17, 2024

codecov bot commented Dec 17, 2024

Codecov Report

Attention: Patch coverage is 75.96154% with 25 lines in your changes missing coverage. Please review.

Project coverage is 83.20%. Comparing base (6ae39a4) to head (cc9a16e).
Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/scvi/train/_trainingplans.py | 44.44% | 5 Missing ⚠️ |
| src/scvi/data/_preprocessing.py | 66.66% | 4 Missing ⚠️ |
| src/scvi/module/_autozivae.py | 63.63% | 4 Missing ⚠️ |
| src/scvi/distributions/_negative_binomial.py | 90.62% | 3 Missing ⚠️ |
| src/scvi/external/velovi/_module.py | 70.00% | 3 Missing ⚠️ |
| src/scvi/model/_utils.py | 40.00% | 3 Missing ⚠️ |
| src/scvi/external/decipher/_module.py | 66.66% | 1 Missing ⚠️ |
| src/scvi/model/base/_rnamixin.py | 91.66% | 1 Missing ⚠️ |
| src/scvi/nn/_base_components.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #3100      +/-   ##
==========================================
- Coverage   87.67%   83.20%   -4.48%
==========================================
  Files         180      180
  Lines       15187    15227      +40
==========================================
- Hits        13315    12669     -646
- Misses       1872     2558     +686
```
| Files with missing lines | Coverage Δ |
| --- | --- |
| src/scvi/external/cellassign/_module.py | 97.24% <100.00%> (ø) |
| src/scvi/model/_totalvi.py | 87.29% <100.00%> (ø) |
| src/scvi/module/_mrdeconv.py | 95.13% <100.00%> (ø) |
| src/scvi/module/_vae.py | 94.92% <100.00%> (ø) |
| src/scvi/external/decipher/_module.py | 98.76% <66.66%> (-1.24%) ⬇️ |
| src/scvi/model/base/_rnamixin.py | 94.46% <91.66%> (ø) |
| src/scvi/nn/_base_components.py | 94.77% <66.66%> (-0.34%) ⬇️ |
| src/scvi/distributions/_negative_binomial.py | 83.26% <90.62%> (-0.65%) ⬇️ |
| src/scvi/external/velovi/_module.py | 81.36% <70.00%> (+0.17%) ⬆️ |
| src/scvi/model/_utils.py | 88.74% <40.00%> (-3.21%) ⬇️ |

... and 3 more

... and 9 files with indirect coverage changes

@canergen (Member) commented:
Please enable the Mac runner for MPS-related changes.

@ori-kron-wis marked this pull request as ready for review on December 18, 2024, 15:57
@ori-kron-wis (Collaborator, Author) commented:

A comparison of training times on CPU, MPS (M3), and an NVIDIA RTX 6000 Ada (48 GB GDDR6):
[image: benchmark comparison]

@canergen (Member) commented:

The difference was much bigger in my hands when increasing the batch size (a larger batch size creates larger matrix multiplications, where MPS is more efficient than the CPU) and enabling compilation. It was actually faster than an A100 on Google Colab then. We should optimize this a bit.
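
For illustration (not code from the PR): a minimal sketch of the larger-batch configuration being described, assuming scvi-tools' standard `train()` signature, where `accelerator` is forwarded to the Lightning trainer and `batch_size` controls the data loaders; exact names may differ per release.

```python
import numpy as np
import anndata as ad
import scvi

# Toy data so the sketch is self-contained (random counts, 2000 cells x 200 genes).
adata = ad.AnnData(np.random.poisson(1.0, size=(2000, 200)).astype(np.float32))
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)

# Larger batches mean larger matrix multiplications, which is where MPS
# pulls ahead of the CPU; scvi-tools' default batch_size is 128.
model.train(
    max_epochs=10,
    accelerator="mps",  # requires an Apple-silicon Mac and an MPS-enabled PyTorch build
    batch_size=1024,    # illustrative value; tune for available memory
)
```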

A review thread was opened on this diff excerpt:

```python
    np.asarray(1.0 - (data > 0).sum(0) / data.shape[0]).ravel()
).to(device)
# On MPS we need to cast to float32 first, as the MPS framework doesn't support float64.
if device.type == "mps":
```
A reviewer (Member) replied:

We could also do it for other devices. Float32 is sufficient for all computations.
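
A sketch of the device-agnostic variant being suggested here, casting to float32 unconditionally instead of gating on `device.type == "mps"` (the surrounding names are illustrative, not the PR's actual code):

```python
import numpy as np
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Illustrative stand-in for the count matrix from the snippet above.
data = np.random.poisson(0.5, size=(100, 20))

# Casting to float32 unconditionally works on every backend (CPU, CUDA, MPS)
# and is sufficient precision here; MPS merely makes it mandatory, since the
# backend has no float64 support.
dropout_rate = torch.from_numpy(
    np.asarray(1.0 - (data > 0).sum(0) / data.shape[0]).ravel()
).float().to(device)
```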

@ori-kron-wis removed the mps label on Dec 22, 2024
@canergen changed the title from "feat: Support MPS during training" to "feat: Support MPS during training and inference" on Dec 22, 2024
Labels: cuda tests (Run test suite on CUDA), on-merge: backport to 1.2.x
Projects: None yet
Participants: 2