Posterior sampling with Messenger Pyro guides -> huge memory use. #1801

Open
vitkl opened this issue Nov 21, 2022 · 5 comments

@vitkl
Contributor

vitkl commented Nov 21, 2022

Hi

Posterior sampling with Messenger Pyro guides does not remove observed variables, which leads to huge memory use.
https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L184
I don't know whether it is possible to easily detect and exclude observed variables when sampling from Messenger guides. @fritzo any recommendations?

This can be addressed in one of two ways:

  1. By adding an additional argument that lets users exclude observed variables by name (as done here https://github.com/BayraktarLab/cell2location/blob/improved_posterior_quantile/cell2location/models/base/_pyro_mixin.py#L250).
  2. By deleting https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L182-L185
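
A minimal sketch of what excluding variables by name could look like, assuming the guide call returns a plain dict that maps site names to values. The helper `exclude_sites`, its `exclude_vars` argument, and the site names are hypothetical and are not part of scvi-tools, cell2location, or Pyro.

```python
def exclude_sites(sample, exclude_vars=()):
    """Return a copy of `sample` without the named sites (hypothetical helper)."""
    excluded = set(exclude_vars)
    return {name: value for name, value in sample.items() if name not in excluded}


# Stand-in for the dict returned by a guide call; the site names are made up.
sample = {"latent_z": [0.1, 0.2], "data_target": [[5, 0], [0, 3]]}
kept = exclude_sites(sample, exclude_vars=["data_target"])
print(sorted(kept))
```

The original dict is left untouched, so the caller decides whether the full sample is still needed elsewhere.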

Related to BayraktarLab/cell2location#144

@vitkl vitkl added the bug label Nov 21, 2022
@fritzo

fritzo commented Nov 21, 2022

Hi @vitkl, could you provide more details about the memory leak? Does the high memory use happen during training/inference or during prediction? Where are the observed data being saved? Are you poutine.traceing the guide or somehow saving sample statements beyond the invocation of a single .guide()? Could you point to some places in the code where the problem occurs?

EDIT: to clarify this bug, is it an issue with prediction in .median() and .quantiles(), or with the usual training method .__call__()?

@vitkl
Contributor Author

vitkl commented Nov 22, 2022

The goal of this function is to generate samples from the posterior; however, the same problem exists with .median() and .quantiles().

This is not really a memory leak (GPU memory use doesn't increase with time). It's simply that

    if isinstance(self.module.guide, poutine.messenger.Messenger):
        # This already includes trace-replay behavior.
        sample = self.module.guide(*args, **kwargs)

returns all sites, including observed sites. For gene expression data, that means returning the densified count matrix n_samples times, which is a huge matrix. During training, the model needs to compute the likelihood of this data, and that is not a problem because training uses minibatches and generates only one sample at a time. However, there is no need to return observed sites when generating posterior samples or when computing the median and quantiles.
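
To put rough numbers on this, the sketch below estimates the size of a dense array of shape (n_samples, n_cells, n_genes). The function name and all sizes are hypothetical and chosen only for illustration; they are not measurements from this issue.

```python
def dense_sample_bytes(n_cells, n_genes, n_samples, bytes_per_value=4):
    """Bytes needed to hold a dense (n_samples, n_cells, n_genes) array."""
    return n_cells * n_genes * n_samples * bytes_per_value


# Hypothetical sizes: one minibatch of 2,500 cells, 20,000 genes, float32 values.
one_copy = dense_sample_bytes(2_500, 20_000, n_samples=1)
many_copies = dense_sample_bytes(2_500, 20_000, n_samples=1_000)
print(f"{one_copy / 1e9:.1f} GB vs {many_copies / 1e9:.1f} GB")
```

The cost grows linearly with the number of posterior samples, which is why a single training step is cheap while returning the observed site for every sample is not.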

An alternative is the poutine.trace approach for sampling (https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L186-L208), which doesn't have this problem because observed sites can be excluded (https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L200). It doesn't address the issue with .median() and .quantiles(), though.
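
The trace-based filtering can be sketched as follows. To stay runnable without Pyro, this uses plain dicts as a stand-in for `trace.nodes`, relying only on the `type`, `is_observed` and `value` keys that Pyro trace nodes carry. The function name and site names are hypothetical, and this is not the scvi-tools implementation.

```python
def latent_sample_sites(nodes):
    """Keep the values of sample sites that are not observed."""
    return {
        name: site["value"]
        for name, site in nodes.items()
        # Sites without an `is_observed` key are conservatively treated as observed.
        if site["type"] == "sample" and not site.get("is_observed", True)
    }


# Stand-in for `trace.nodes`; a real trace would hold tensors, not lists.
nodes = {
    "_INPUT": {"type": "args"},
    "latent_z": {"type": "sample", "is_observed": False, "value": [0.1, 0.2]},
    "data_target": {"type": "sample", "is_observed": True, "value": [[5, 0], [0, 3]]},
}
print(latent_sample_sites(nodes))
```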

@canergen
Member

canergen commented Nov 22, 2022

To add my "solution" to this: it uses infer.Predictive and checks once whether each variable is defined per cell rather than being a global parameter. It records all samples. I am not sure what the problem with .median() is.

    adata = self._validate_anndata(adata)
    train_dl = self._make_data_loader(
        adata=adata, indices=indices, shuffle=False, batch_size=batch_size
    )

    self.to_device(device)
    model = model if model else self.module.model

    # sample local parameters
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        if library_size is not None:
            kwargs["library"] = torch.full_like(
                kwargs["library"], torch.log(torch.tensor(library_size))
            )

        samples_ = infer.Predictive(
            model,
            num_samples=num_samples,
            guide=self.module.guide,
            return_sites=return_sites,
        )(*args, **kwargs)
        if not samples:
            model_trace = poutine.trace(model).get_trace(*args, **kwargs)
            return_sites = [
                i
                for i in model_trace.nodes.keys()
                if model_trace.nodes[i].pop("cond_indep_stack", None)
            ]
            samples = {k: [v.cpu()] for k, v in samples_.items()}
        else:
            # Record minibatches if variable is minibatch dependent
            samples = {
                k: v + [samples_[k].cpu()] if k in return_sites else v
                for k, v in samples.items()
            }
    samples = {
        k: torch.cat(v, axis=1).numpy() for k, v in samples.items()
    }  # for each variable

It is critical for me to copy to CPU during execution (GPU memory is too small). I have to run gc.collect() afterwards to get the GPU emptied, so I think there is an additional leak.

@vitkl
Contributor Author

vitkl commented Nov 23, 2022

The code in scvi-tools already checks which variables are in the minibatch plate and aggregates minibatched variables along the correct dimensions in https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L357-L409. This code also copies each sample to CPU to avoid keeping 1000 samples on GPU: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L210.

Your code assumes that the presence of a plate means the variable is minibatched, and that the minibatch dimension is -1. The code above implements a more general solution that checks the plate name and fetches its dimension.
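
The difference can be illustrated with a small sketch that looks up a plate by name and returns its dimension. `Frame` is a stand-in for the entries of a Pyro site's `cond_indep_stack`, keeping only the `name` and `dim` fields; `plate_dim` and the plate names are hypothetical, and this is not the code linked above.

```python
from typing import NamedTuple, Optional, Sequence


class Frame(NamedTuple):
    """Stand-in for one entry of a site's `cond_indep_stack` (name and dim only)."""

    name: str
    dim: int


def plate_dim(stack: Sequence[Frame], plate_name: str) -> Optional[int]:
    """Return the dim of the plate called `plate_name`, or None if the site is outside it."""
    for frame in stack:
        if frame.name == plate_name:
            return frame.dim
    return None


# A site nested in two plates; the minibatch plate is not on the last dimension.
stack = [Frame("gene_plate", -1), Frame("obs_plate", -2)]
print(plate_dim(stack, "obs_plate"), plate_dim(stack, "batch_plate"))
```

A site that sits only in a non-minibatch plate returns None here, so it would not be mistaken for a minibatched variable just because its stack is non-empty.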

I don't see how your code solves the problem of not saving observed variables. Does infer.Predictive return only unobserved variables?

When I was implementing this, @martinjankowiak recommended not using infer.Predictive because we have a simpler problem.

vitkl added a commit to vitkl/scvi-tools that referenced this issue Nov 24, 2022
Create a separate method to detect observed sites and apply it everywhere necessary to automatically exclude observed sites.
@vitkl
Contributor Author

vitkl commented Nov 24, 2022

There is an additional bug: _get_obs_plate_sites correctly excludes observed variables from its list of obs_plate_sites - which means that observed sites are treated as global variables. This might explain why memory use increased from 3.8GB to 40GB rather than >> 40GB. Created a PR to fix both issues: #1805
