From 0695876d7980318b13b08cb764b2cb587dd91db3 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Wed, 2 Oct 2024 16:25:58 -0700 Subject: [PATCH] Backport PR #3007 on branch 1.2.x (Change outlier detection for MRVI to ball admissibility calculation) (#3009) Backport PR #3007: Change outlier detection for MRVI to ball admissibility calculation Co-authored-by: Justin Hong --- CHANGELOG.md | 14 +++++++++++ src/scvi/external/mrvi/_model.py | 42 +++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb781b50a6..45ccc2f743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,20 @@ to [Semantic Versioning]. Full commit history is available in the ## Version 1.2 +### 1.2.1 (2024-XX-XX) + +#### Added + +#### Fixed + +- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI` + to correctly compute the maxmimum log-density across in-sample cells rather than the + aggregated posterior log-density {pr}`3007`. + +#### Changed + +#### Removed + ### 1.2.0 (2024-09-26) #### Added diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 4f3dabfbc4..853e1a4a63 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -14,6 +14,7 @@ from scvi.data import AnnDataManager, fields from scvi.external.mrvi._module import MRVAE from scvi.external.mrvi._types import MRVIReduction +from scvi.external.mrvi._utils import rowwise_max_excluding_diagonal from scvi.model.base import BaseModelClass, JaxTrainingMixin from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -745,7 +746,10 @@ def get_aggregated_posterior( indices: npt.ArrayLike | None = None, batch_size: int = 256, ) -> Distribution: - """Compute the aggregated posterior over the ``u`` latent representations. + """Computes the aggregated posterior over the ``u`` latent representations. + + For the specified samples, it computes the aggregated posterior over the ``u`` latent + representations. Returns a NumPyro MixtureSameFamily distribution. Parameters ---------- @@ -959,12 +963,13 @@ def get_outlier_cell_sample_pairs( admissibility_threshold: float = 0.0, batch_size: int = 256, ) -> xr.Dataset: - """Compute outlier cell-sample pairs. + """Compute admissibility scores for cell-sample pairs. - This function fits a GMM for each sample based on the latent representation of the cells in - the sample or computes an approximate aggregated posterior for each sample. Then, for every - cell, it computes the log-probability of the cell under the approximated posterior of each - sample as a measure of admissibility. + This function computes the posterior distribution for u for each cell. Then, for every + cell, it computes the log-probability of the cell under the posterior of each cell + each sample and takes the maximum value for a given sample as a measure of admissibility + for that sample. Additionally, it computes a threshold that determines if + a cell-sample pair is admissible based on the within-sample admissibility scores. Parameters ---------- @@ -995,21 +1000,34 @@ def get_outlier_cell_sample_pairs( adata_s = adata[sample_idxs] ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) - log_probs_s = jnp.quantile( - ap.log_prob(adata_s.obsm["U"]).sum(axis=1), q=quantile_threshold - ) - n_splits = adata.n_obs // batch_size + in_max_comp_log_probs = ap.component_distribution.log_prob( + np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) + ).sum(axis=1) + log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) + log_probs_ = [] + n_splits = adata.n_obs // batch_size for u_rep in np.array_split(adata.obsm["U"], n_splits): - log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True))) + log_probs_.append( + jax.device_get( + ap.component_distribution.log_prob( + np.expand_dims(u_rep, ap.mixture_dim) + ) # (n_cells_batch, n_cells_ap, n_latent_dim) + .sum(axis=1) # (n_cells_batch, n_latent_dim) + .max(axis=1, keepdims=True) # (n_cells_batch, 1) + ) + ) log_probs_ = np.concatenate(log_probs_, axis=0) # (n_cells, 1) threshs.append(np.array(log_probs_s)) log_probs.append(np.array(log_probs_)) + threshs_all = np.concatenate(threshs) + global_thresh = np.quantile(threshs_all, q=quantile_threshold) + threshs = np.array(len(log_probs) * [global_thresh]) + log_probs = np.concatenate(log_probs, 1) - threshs = np.array(threshs) log_ratios = log_probs - threshs coords = {