scPRINT is hanging forever #3

Open
jkobject opened this issue Sep 19, 2024 · 1 comment
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

@jkobject
Collaborator

I downloaded the checkpoints from Hugging Face and loaded them. I have reached the embedder step in this tutorial: https://github.com/jkobject/scPRINT/blob/main/docs/notebooks/cancer_usecase.ipynb

I first ran:

adata, metrics = embedder(model, adata, cache=False, output_expression="none")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Embedder.__call__() got an unexpected keyword argument 'output_expression'
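The TypeError means that `Embedder.__call__` in the installed scPRINT version does not accept `output_expression`. One possible explanation (an assumption, not confirmed in this thread) is that the notebook was written against a different version than the one installed. Printing `inspect.signature(embedder)` shows what the installed version accepts; the minimal sketch below goes one step further and splits a set of keyword arguments into accepted and rejected ones. `supported_kwargs` and `OldEmbedder` are hypothetical names, and `OldEmbedder` is only a stand-in, not scPRINT's class.

```python
import inspect


def supported_kwargs(func, **kwargs):
    """Split kwargs into those `func` accepts and those it would reject.

    If `func` takes **kwargs, every keyword is considered accepted.
    """
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs), {}
    accepted = {k: v for k, v in kwargs.items() if k in params}
    rejected = {k: v for k, v in kwargs.items() if k not in params}
    return accepted, rejected


# Hypothetical stand-in for an Embedder whose __call__ has no
# `output_expression` parameter (NOT scPRINT's real class).
class OldEmbedder:
    def __call__(self, model, adata, cache=True):
        return adata, {}


embedder = OldEmbedder()
accepted, rejected = supported_kwargs(embedder, cache=False, output_expression="none")
print(accepted)  # {'cache': False}
print(rejected)  # {'output_expression': 'none'}
```

With the real object, `supported_kwargs(embedder, cache=False, output_expression="none")` would report whether the installed `Embedder` knows about `output_expression` before the call fails.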

Then I ran it with the "output_expression" parameter removed. However, it stops and my Python terminal quits automatically (I am running Python interactively inside a conda env). I am wondering whether this is a memory issue (I am currently using 1 GPU with 128GB). Should I try increasing the memory?

adata, metrics = embedder(model, adata, cache=False)
0%|                                                           | 0/1304 [00:00<?, ?it/s]
(the Python terminal quits here)

Originally posted by @kavithakrishna1 in jkobject#9 (comment)

@jkobject jkobject added bug Something isn't working help wanted Extra attention is needed labels Sep 19, 2024
@jkobject
Collaborator Author

The model runs FlashAttention 2. It is not a dependency but part of the model: scPRINT implements it through Triton. I have never tested scPRINT on 11.4.
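FlashAttention (including FlashAttention 2) computes the same exact attention as the standard formula softmax(QKᵀ/√d)V; it processes keys and values in blocks with an online softmax, so the full score matrix is never stored at once. The pure-Python sketch below illustrates that idea for a single query vector. It is not scPRINT's Triton kernel, and `naive_attention` / `blockwise_attention` are illustrative names. It also suggests why swapping in regular attention should mainly change speed and memory use rather than the result (up to floating-point rounding).

```python
import math


def _scores(q, keys):
    """Scaled dot products q.k / sqrt(d) for each key."""
    d = len(q)
    return [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]


def naive_attention(q, keys, values):
    """Standard attention for one query: softmax over all scores, then a weighted sum of values."""
    scores = _scores(q, keys)
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def blockwise_attention(q, keys, values, block_size=2):
    """Same result, computed block by block with an online (running) softmax.

    Only a running max, a running normaliser and a running output are kept,
    so the scores of all keys are never held at the same time.
    """
    dim_v = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim_v
    for start in range(0, len(keys), block_size):
        scores = _scores(q, keys[start:start + block_size])
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old max (exp(-inf) == 0.0 on the first block).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(naive_attention(q, keys, values))
print(blockwise_attention(q, keys, values))  # matches the line above up to rounding
```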
So you would have to use pdb to check whether the model.predict() function gets called within the Embedder class. Also, can you check whether the GPU memory gets used?
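As a non-interactive alternative to stepping through with pdb, one can wrap `model.predict` so that each call is logged to stderr before it runs; if the log line appears before the terminal dies, the crash happens during prediction. The sketch also reads PyTorch's CUDA allocator counters (`torch.cuda.memory_allocated`, `torch.cuda.max_memory_allocated`); running `nvidia-smi` in a second terminal gives the device-wide view. `trace_calls`, `gpu_memory_report` and `FakeModel` are hypothetical names, and the sketch assumes the Embedder calls `model.predict` through the same object whose attribute is reassigned.

```python
import sys


def trace_calls(obj, method_name):
    """Replace obj.<method_name> with a wrapper that logs each call, then runs it.

    Returns a list that gets one entry (the call index) per call.
    """
    original = getattr(obj, method_name)
    calls = []

    def wrapper(*args, **kwargs):
        # Record only the call index, not the arguments, to avoid keeping large tensors alive.
        calls.append(len(calls) + 1)
        # flush=True so the line is written even if the process dies right after.
        print(f"{method_name} called (#{len(calls)})", file=sys.stderr, flush=True)
        return original(*args, **kwargs)

    setattr(obj, method_name, wrapper)
    return calls


def gpu_memory_report():
    """Allocated and peak CUDA memory (bytes) seen by PyTorch, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return {
        "allocated": torch.cuda.memory_allocated(),
        "peak": torch.cuda.max_memory_allocated(),
    }


# Hypothetical stand-in, used only to show the wrapper working.
class FakeModel:
    def predict(self, batch):
        return [x * 2 for x in batch]


model = FakeModel()
calls = trace_calls(model, "predict")
model.predict([1, 2, 3])  # logs "predict called (#1)" to stderr
print(len(calls))         # 1
print(gpu_memory_report())
```

With the real objects, one would call `trace_calls(model, "predict")` before `embedder(model, adata, cache=False)` and call `gpu_memory_report()` from a pdb breakpoint or right after the embedder returns.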

Finally, to test it, set the input context to 200 and the minibatch size to 1 and see what happens; maybe it is not using the GPU. (These are parameters of the Embedder class.)
e.g. embedder = Embedder(batch_size=1, num_workers=1, max_len=200), and maybe use an adata with only a couple of cells.
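A rough back-of-the-envelope calculation shows why shrinking both settings is a useful first test: regular attention stores a batch × heads × L × L score matrix per layer, so that term grows linearly with batch size and quadratically with context length L. The helper name, the 4 heads, the batch size of 64 and the context length of 2000 below are hypothetical placeholders, not scPRINT's actual configuration, and the estimate ignores every other activation and the weights.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_element=4):
    """Bytes for one layer's full attention score matrix
    (batch_size x num_heads x seq_len x seq_len), as regular attention materialises it."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


small = attention_scores_bytes(batch_size=1, num_heads=4, seq_len=200)    # fp32
large = attention_scores_bytes(batch_size=64, num_heads=4, seq_len=2000)  # fp32
print(small)          # 640000  (~0.6 MB)
print(large)          # 4096000000  (~4.1 GB)
print(large / small)  # 6400.0  (64x from batch size, 100x from context length)
```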

To make sure that this is due to Triton, you can run the model with regular attention instead:
model = scPrint.load_from_checkpoint(ckpt_path, precpt_gene_emb=None, transformer="normal")
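Because the symptom is the interpreter exiting without a Python traceback, it may help to run each of these experiments (the small Embedder configuration, or `transformer="normal"`) in a separate Python process, so that a crash does not take down the interactive session and its exit status can be inspected. On POSIX systems a negative return code -N means the child was killed by signal N: SIGKILL is often (not always) the Linux OOM killer, while SIGSEGV points at a native-code crash, for which `-X faulthandler` prints a Python-level traceback. `run_isolated` is a hypothetical helper; the code string passed to it would contain the tutorial's loading and embedding steps, which are not reproduced here.

```python
import signal
import subprocess
import sys


def run_isolated(code):
    """Run `code` in a fresh Python process and describe how it ended.

    A crash or kill in the child leaves the calling interpreter alive.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", code],
        capture_output=True,
        text=True,
    )
    rc = proc.returncode
    if rc == 0:
        status = "finished normally"
    elif rc < 0:
        # POSIX: -N means the child was terminated by signal N.
        try:
            status = f"killed by {signal.Signals(-rc).name}"
        except ValueError:
            status = f"killed by signal {-rc}"
    else:
        status = f"exited with code {rc}"
    return rc, status, proc.stderr  # stderr holds any faulthandler traceback


rc, status, _ = run_isolated("print('ok')")
print(status)  # finished normally
rc, status, _ = run_isolated("import sys; sys.exit(3)")
print(status)  # exited with code 3
```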
