Then I ran it with the "output_expression" parameter removed. However, it stops and my Python session exits on its own (I am running Python interactively inside a conda env). I am wondering whether this is a memory issue (I am currently using 1 GPU with 128GB). Should I try increasing the memory?
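When the interpreter closes without any message, the standard-library faulthandler module can at least show where a native crash happened. A minimal sketch, meant to be run at the top of the interactive session before loading the model (starting Python with `python -X faulthandler` has the same effect):

```python
import faulthandler

# Dump the Python traceback of every thread to stderr if the process hits a
# fatal signal such as SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL.
faulthandler.enable(all_threads=True)
print("faulthandler enabled:", faulthandler.is_enabled())
```

If the session still disappears with no traceback, the process was probably terminated by a signal that faulthandler cannot intercept, such as the SIGKILL the Linux out-of-memory killer sends; the kernel log (for example `dmesg`) usually records that. Treat this as a hint, not proof, of a memory problem.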
The code the model runs is FlashAttention-2. It is not a dependency but part of the model; scPRINT runs it through Triton. I have never tested scPRINT on 11.4.
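Because Triton-compiled kernels depend on the GPU driver and CUDA stack, it helps to report both in the issue. A minimal sketch, assuming the "11.4" above refers to a CUDA version: it reads the CUDA version printed in the plain `nvidia-smi` header (the newest CUDA the driver supports, not necessarily the toolkit PyTorch was built with, which is `torch.version.cuda`) and the installed triton version, if any. The helper names are hypothetical.

```python
import importlib.metadata
import importlib.util
import re
import shutil
import subprocess


def parse_driver_cuda_version(smi_output):
    """Return (major, minor) from the 'CUDA Version: X.Y' field that plain
    `nvidia-smi` prints in its header, or None if it is absent."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def environment_report():
    """Collect the driver-supported CUDA version and the triton version."""
    report = {"driver_cuda": None, "triton": None}
    if shutil.which("nvidia-smi") is not None:
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
        report["driver_cuda"] = parse_driver_cuda_version(result.stdout)
    if importlib.util.find_spec("triton") is not None:
        try:
            report["triton"] = importlib.metadata.version("triton")
        except importlib.metadata.PackageNotFoundError:
            report["triton"] = "installed (version unknown)"
    return report


if __name__ == "__main__":
    print(environment_report())
```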
You would have to use pdb to check whether the model.predict() function gets called within the Embedder class. Could you also check whether GPU memory is being used?
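Stepping through with pdb works interactively; a non-interactive alternative is to wrap the method and record every call. The sketch below is generic Python: `spy_on` is a hypothetical helper, and DummyModel and DummyEmbedder are stand-ins, not scPRINT classes. It assumes the embedder reaches predict through the same model object you patch.

```python
import functools


def spy_on(obj, method_name):
    """Replace obj.method_name with a wrapper that logs each call.

    Returns the list the calls are appended to, so an empty list afterwards
    means the method was never invoked through this object.
    """
    original = getattr(obj, method_name)
    calls = []

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        calls.append((args, kwargs))
        # To stop in the debugger on each call instead, uncomment:
        # import pdb; pdb.set_trace()
        return original(*args, **kwargs)

    setattr(obj, method_name, wrapper)
    return calls


# Hypothetical stand-ins for the real model and embedder.
class DummyModel:
    def predict(self, batch):
        return [x * 2 for x in batch]


class DummyEmbedder:
    def __call__(self, model, data):
        return model.predict(data)


model = DummyModel()
calls = spy_on(model, "predict")
DummyEmbedder()(model, [1, 2, 3])
print(f"predict called {len(calls)} time(s)")
```

With the real objects you would call `calls = spy_on(model, "predict")` before running the embedder and inspect `len(calls)` afterwards. Patching the instance only catches calls made through that instance; if the method is invoked through the class or by a training framework, patch the class instead.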
Finally, to test it, set the input context to 200 and the minibatch size to 1 and see what happens; maybe it is not using the GPU. (These are parameters of the Embedder class.)
For example: embedder = Embedder(batch_size=1, num_workers=1, max_len=200), and perhaps use an AnnData object with only a couple of cells.
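To see whether the GPU is used at all, you can poll nvidia-smi from a second shell while the small Embedder run is in progress. The sketch below wraps nvidia-smi's CSV query mode; the helper names are hypothetical. Inside the same Python process, `torch.cuda.memory_allocated()` reports what PyTorch itself has allocated.

```python
import shutil
import subprocess

QUERY = ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits"]


def parse_gpu_memory(csv_text):
    """Turn nvidia-smi CSV lines like '1234, 81920' into (used, total) MiB."""
    usage = []
    for line in csv_text.strip().splitlines():
        used, total = (int(field) for field in line.split(","))
        usage.append((used, total))
    return usage


def query_gpu_memory():
    """Return [(used_mib, total_mib), ...] per GPU, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(QUERY, capture_output=True, text=True)
    if result.returncode != 0:
        return None
    return parse_gpu_memory(result.stdout)


if __name__ == "__main__":
    usage = query_gpu_memory()
    if usage is None:
        print("nvidia-smi unavailable; is a GPU visible from this shell?")
    else:
        for idx, (used, total) in enumerate(usage):
            print(f"GPU {idx}: {used} / {total} MiB in use")
```

If used memory stays near zero for the whole run, the model is probably running on the CPU.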
To confirm that this is caused by Triton, you can run the model with regular attention instead: model = scPrint.load_from_checkpoint(ckpt_path, precpt_gene_emb=None, transformer="normal")
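To compare the default attention with transformer="normal" without losing the interactive session on every crash, the risky call can run in a child interpreter whose exit status is then inspected. A standard-library sketch; `run_isolated` and `describe_exit` are hypothetical helpers, and the code you pass in (for example, the checkpoint loading and the small Embedder run described above) is up to you.

```python
import signal
import subprocess
import sys


def describe_exit(returncode):
    """Turn a child process return code into a readable diagnosis."""
    if returncode == 0:
        return "finished normally"
    if returncode < 0:
        # On POSIX, a negative code means the child was killed by that signal.
        try:
            name = signal.Signals(-returncode).name
        except ValueError:
            name = f"signal {-returncode}"
        return f"killed by {name}"
    return f"exited with status {returncode}"


def run_isolated(code):
    """Run Python code in a fresh interpreter (with faulthandler enabled) so
    a hard crash cannot close the interactive session."""
    proc = subprocess.run([sys.executable, "-X", "faulthandler", "-c", code])
    return describe_exit(proc.returncode)


if __name__ == "__main__":
    print(run_isolated("print('child ran')"))
```

On POSIX systems, SIGSEGV points to a native crash, for instance inside a compiled kernel, while SIGKILL with no traceback often means the operating system terminated the process, commonly for running out of host memory. If only the default attention crashes and transformer="normal" finishes, Triton is the likely culprit.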
I downloaded the checkpoints from Hugging Face and loaded them. I am up to the embedder step in this tutorial: https://github.com/jkobject/scPRINT/blob/main/docs/notebooks/cancer_usecase.ipynb
I first ran
Originally posted by @kavithakrishna1 in jkobject#9 (comment)