Posterior sampling with Messenger Pyro guides -> huge memory use. #1801
Comments
Hi @vitkl, could you provide more details about the memory leak? Does the high memory use occur during training/inference or during prediction? Where are the observed data being saved? Are you EDIT: to clarify this bug, is it an issue with prediction in |
The goal of this function is to generate samples from the posterior; however, the same problem exists with This is not really a memory leak (GPU memory use doesn't increase over time). It's simply that
returns all sites, including observed sites. For gene expression data, that means returning the densified count matrix n_samples times, which is a huge matrix. During training, the model obviously needs to compute the likelihood of this data, and that is not a problem because training is minibatched and only one sample is generated per step. However, there is no need to return observed sites when generating posterior samples or when computing the median and quantiles. An alternative is |
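To put rough numbers on the cost of returning the observed site with every sample, here is a back-of-the-envelope sketch. The dataset dimensions are hypothetical and chosen only for illustration, and float32 storage (4 bytes per value) is an assumption.

```python
def dense_samples_gib(n_samples: int, n_cells: int, n_genes: int,
                      bytes_per_value: int = 4) -> float:
    """Size in GiB of a dense (n_samples, n_cells, n_genes) array."""
    return n_samples * n_cells * n_genes * bytes_per_value / 2**30


# Hypothetical dataset: 50,000 cells x 10,000 genes, stored as float32.
one_copy = dense_samples_gib(1, 50_000, 10_000)
all_samples = dense_samples_gib(1000, 50_000, 10_000)

print(f"{one_copy:.2f} GiB for one densified copy")
print(f"{all_samples:.0f} GiB if the copy is returned with each of 1000 samples")
```

Under these assumptions a single densified copy is already close to 2 GiB, so keeping one copy per posterior sample is far beyond typical GPU or host memory.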
To add my "solution" to this: it uses infer.Predictive and checks once whether the variable is on a per-cell basis rather than a global parameter. It records all samples. I am not sure what problem there is with
It is critical for me to copy to CPU during execution (GPU memory is too small). I have to run gc.collect() afterwards to get the GPU memory freed, so I think there is an additional leak.
The code in scvi-tools already checks which variables are in the minibatch plate and aggregates minibatched variables along the correct dimensions in https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L357-L409. This code also copies each sample to CPU to avoid keeping 1000 samples on GPU: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L210 . Your code assumes that the presence of a plate means the variable is minibatched, and it assumes dimension -1. The code above implements a more general solution that checks the plate name and fetches the dimension. I don't see how your code solves the problem of not saving observed variables. Does When I was implementing this @martinjankowiak recommended to not use |
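The difference between "any plate, dimension -1" and "named plate, looked-up dimension" can be sketched without Pyro itself. In a Pyro trace, each site's `cond_indep_stack` lists the plates that enclose it, and each frame carries a `name` and a `dim`. The `Frame` tuple, the helper `minibatch_dim`, and the plate names below are hypothetical stand-ins for illustration; this is not the scvi-tools implementation.

```python
from typing import NamedTuple, Optional, Sequence


class Frame(NamedTuple):
    """Hypothetical stand-in for one entry of a Pyro site's cond_indep_stack."""
    name: str
    dim: int
    size: int


def minibatch_dim(stack: Sequence[Frame], plate_name: str) -> Optional[int]:
    """Return the dim of the named plate, or None if the site is not inside it."""
    for frame in stack:
        if frame.name == plate_name:
            return frame.dim
    return None


# A global variable inside some other plate: it must not be treated as minibatched.
global_site = [Frame("genes", -1, 2000)]
# A per-cell variable whose minibatch plate sits at dim -2, not -1.
local_site = [Frame("obs_plate", -2, 128), Frame("genes", -1, 2000)]

print(minibatch_dim(global_site, "obs_plate"))  # None
print(minibatch_dim(local_site, "obs_plate"))   # -2
```

A check that only asks "is there a plate?" would treat both sites as minibatched and concatenate them along -1; looking up the plate by name avoids both mistakes.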
Create a separate method to detect observed sites and apply it everywhere necessary to automatically exclude observed sites.
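One possible shape for such a helper, as a minimal sketch: in a Pyro trace, every node is a dict with a `type` entry, and sample sites also carry an `is_observed` flag. The function name `latent_site_names` and the toy node dicts are hypothetical; with Pyro, one would pass something like `poutine.trace(model).get_trace(*args).nodes` instead.

```python
from typing import Any, List, Mapping


def latent_site_names(nodes: Mapping[str, Mapping[str, Any]]) -> List[str]:
    """Names of sample sites that are not observed."""
    return [
        name
        for name, site in nodes.items()
        if site.get("type") == "sample" and not site.get("is_observed", False)
    ]


# Toy nodes mimicking the layout of Pyro trace nodes (hypothetical site names).
nodes = {
    "_INPUT": {"type": "args"},
    "per_cell_factor": {"type": "sample", "is_observed": False},
    "data_target": {"type": "sample", "is_observed": True},
    "_RETURN": {"type": "return"},
}

keep = latent_site_names(nodes)
print(keep)  # ['per_cell_factor']
```

The resulting list could then be used to restrict which sites are stored, for example through the `return_sites` argument of `pyro.infer.Predictive`. Depending on the model, auxiliary sites (such as plate subsampling sites) may need additional filtering.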
There is an additional bug: |
Hi
Posterior sampling with Messenger Pyro guides does not remove observed variables, which leads to huge memory use.
https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L184
I don't know whether it is possible to easily detect and exclude observed variables when sampling from Messenger guides. @fritzo any recommendations?
This can be addressed
Related to BayraktarLab/cell2location#144