Skip to content

Commit

Permalink
add everything as one commit
Browse files Browse the repository at this point in the history
  • Loading branch information
vitkl committed Nov 24, 2022
1 parent 91193e4 commit ed3a3f3
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions scvi/model/base/_pyromixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,18 @@ def _get_one_posterior_sample(
and (
(return_sites is None) or (name in return_sites)
) # selected in return_sites list
and (
(
(not site.get("is_observed", True)) or return_observed
) # don't save observed unless requested
or (site.get("infer", False).get("_deterministic", False))
) # unless it is deterministic
and not isinstance(
site.get("fn", None), poutine.subsample_messenger._Subsample
) # don't save plates
)
}

if not return_observed:
observed_not_deterministic = self._get_observed_sites(*args, **kwargs)
sample = {
k: v for k, v in sample.items() if k not in observed_not_deterministic
}

sample = {name: site.cpu().numpy() for name, site in sample.items()}

return sample
Expand Down Expand Up @@ -309,21 +309,58 @@ def _get_obs_plate_sites(
for name, site in trace.nodes.items()
if (
(site["type"] == "sample") # sample statement
and (
(
(not site.get("is_observed", True)) or return_observed
) # don't save observed unless requested
or (site.get("infer", False).get("_deterministic", False))
) # unless it is deterministic
and not isinstance(
site.get("fn", None), poutine.subsample_messenger._Subsample
) # don't save plates
)
if any(f.name == plate_name for f in site["cond_indep_stack"])
}
if not return_observed:
observed_not_deterministic = self._get_observed_sites(*args, **kwargs)
obs_plate = {
k: v
for k, v in obs_plate.items()
if k not in observed_not_deterministic
}

return obs_plate

def _get_observed_sites(
self,
args: list,
kwargs: dict,
):
"""
Automatically guess which model sites correspond to observed variables
This excludes pyro.deterministic variables.
Parameters
----------
args
Arguments to the model.
kwargs
Keyword arguments to the model.
Returns
-------
List with site names.
"""
trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
observed_sites = [
name
for name, site in trace.nodes.items()
if (
(site["type"] == "sample") # sample statement
and (
site.get("is_observed", True)
and site.get("infer", False).get("_deterministic", False)
) # exclude deterministic sites
)
]

return observed_sites

def _posterior_samples_minibatch(
self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
Expand Down

0 comments on commit ed3a3f3

Please sign in to comment.