diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index fa5feaeb86..56c45ac840 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -195,18 +195,18 @@ def _get_one_posterior_sample( and ( (return_sites is None) or (name in return_sites) ) # selected in return_sites list - and ( - ( - (not site.get("is_observed", True)) or return_observed - ) # don't save observed unless requested - or (site.get("infer", False).get("_deterministic", False)) - ) # unless it is deterministic and not isinstance( site.get("fn", None), poutine.subsample_messenger._Subsample ) # don't save plates ) } + if not return_observed: + observed_not_deterministic = self._get_observed_sites(*args, **kwargs) + sample = { + k: v for k, v in sample.items() if k not in observed_not_deterministic + } + sample = {name: site.cpu().numpy() for name, site in sample.items()} return sample @@ -309,21 +309,58 @@ def _get_obs_plate_sites( for name, site in trace.nodes.items() if ( (site["type"] == "sample") # sample statement - and ( - ( - (not site.get("is_observed", True)) or return_observed - ) # don't save observed unless requested - or (site.get("infer", False).get("_deterministic", False)) - ) # unless it is deterministic and not isinstance( site.get("fn", None), poutine.subsample_messenger._Subsample ) # don't save plates ) if any(f.name == plate_name for f in site["cond_indep_stack"]) } + if not return_observed: + observed_not_deterministic = self._get_observed_sites(*args, **kwargs) + obs_plate = { + k: v + for k, v in obs_plate.items() + if k not in observed_not_deterministic + } return obs_plate + def _get_observed_sites( + self, + args: list, + kwargs: dict, + ): + """ + Automatically guess which model sites correspond to observed variables + + This excludes pyro.deterministic variables. + + Parameters + ---------- + args + Arguments to the model. + kwargs + Keyword arguments to the model. + + Returns + ------- + List with site names. + """ + trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) + observed_sites = [ + name + for name, site in trace.nodes.items() + if ( + (site["type"] == "sample") # sample statement + and ( + site.get("is_observed", True) + and site.get("infer", False).get("_deterministic", False) + ) # exclude deterministic sites + ) + ] + + return observed_sites + def _posterior_samples_minibatch( self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs ):