diff --git a/pertpy/tools/_cinemaot.py b/pertpy/tools/_cinemaot.py index b614f67a..8e367deb 100644 --- a/pertpy/tools/_cinemaot.py +++ b/pertpy/tools/_cinemaot.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -639,6 +639,84 @@ def attribution_scatter( s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6) return c_effect, s_effect + def plot_vis_matching( + self, + adata: AnnData, + de: AnnData, + pert_key: str, + control: str, + de_label: str, + source_label: str, + matching_rep: str = "ot", + resolution: float = 0.5, + normalize: str = "col", + title: str = "CINEMA-OT matching matrix", + min_val: float = 0.01, + show: bool = True, + save: str | None = None, + ax: Axes | None = None, + **kwargs, + ) -> None: + """Visualize the CINEMA-OT matching matrix. + + Args: + adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted. + de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted(). + pert_key: The column of `.obs` with perturbation categories, should also contain `control`. + control: Control category from the `pert_key` column. + de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0. + source_label: the confounder / cell type label. + matching_rep: the place that stores the matching matrix. default de.obsm['ot']. + normalize: normalize the coarse-grained matching matrix by row / column. + title: the title for the figure. + min_val: The min value to truncate the matching matrix. + show: Show the plot, do not return axis. + save: If `True` or a `str`, save the figure. A string is appended to the default filename. + Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. + **kwargs: Other parameters to input for seaborn.heatmap. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.cinemaot_example() + >>> cot = pt.tl.Cinemaot() + >>> de = cot.causaleffect( + >>> adata, pert_key="perturbation", control="No stimulation", return_matching=True, + >>> thres=0.5, smoothness=1e-5, eps=1e-3, solver="Sinkhorn", preweight_label="cell_type0528") + >>> cot.plot_vis_matching( + >>> adata, de, pert_key="perturbation",control="No stimulation", de_label=None, source_label="cell_type0528") + """ + adata_ = adata[adata.obs[pert_key] == control] + + df = pd.DataFrame(de.obsm[matching_rep]) + if de_label is None: + de_label = "leiden" + sc.pp.neighbors(de, use_rep="X_embedding") + sc.tl.leiden(de, resolution=resolution) + df["de_label"] = de.obs[de_label].astype(str).values + df["de_label"] = "Response " + df["de_label"] + df = df.groupby("de_label").sum().T + df["source_label"] = adata_.obs[source_label].astype(str).values + df = df.groupby("source_label").sum() + + if normalize == "col": + df = df / df.sum(axis=0) + else: + df = (df.T / df.sum(axis=1)).T + df = df.clip(lower=min_val) - min_val + if normalize == "col": + df = df / df.sum(axis=0) + else: + df = (df.T / df.sum(axis=1)).T + + g = sns.heatmap(df, annot=True, ax=ax, **kwargs) + plt.title(title) + _utils.savefig_or_show("matching_heatmap", show=show, save=save) + if not show: + if ax is not None: + return ax + else: + return g + class Xi: """ @@ -859,81 +937,3 @@ def fit(self, P: ArrayLike): P_eps = np.diag(self._D1)[:, None] * P * np.diag(self._D2)[None, :] return P_eps - - def plot_vis_matching( - self, - adata: AnnData, - de: AnnData, - pert_key: str, - control: str, - de_label: str, - source_label: str, - matching_rep: str = "ot", - resolution: float = 0.5, - normalize: str = "col", - title: str = "CINEMA-OT matching matrix", - min_val: float = 0.01, - show: bool = True, - save: str | None = None, - ax: Axes | None = None, - **kwargs, - ) -> None: - """Visualize the CINEMA-OT matching matrix. - - Args: - adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted. - de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted(). - pert_key: The column of `.obs` with perturbation categories, should also contain `control`. - control: Control category from the `pert_key` column. - de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0. - source_label: the confounder / cell type label. - matching_rep: the place that stores the matching matrix. default de.obsm['ot']. - normalize: normalize the coarse-grained matching matrix by row / column. - title: the title for the figure. - min_val: The min value to truncate the matching matrix. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. - **kwargs: Other parameters to input for seaborn.heatmap. - - Examples: - >>> import pertpy as pt - >>> adata = pt.dt.cinemaot_example() - >>> cot = pt.tl.Cinemaot() - >>> de = cot.causaleffect( - >>> adata, pert_key="perturbation", control="No stimulation", return_matching=True, - >>> thres=0.5, smoothness=1e-5, eps=1e-3, solver="Sinkhorn", preweight_label="cell_type0528") - >>> cot.plot_vis_matching( - >>> adata, de, pert_key="perturbation",control="No stimulation", de_label=None, source_label="cell_type0528") - """ - adata_ = adata[adata.obs[pert_key] == control] - - df = pd.DataFrame(de.obsm[matching_rep]) - if de_label is None: - de_label = "leiden" - sc.pp.neighbors(de, use_rep="X_embedding") - sc.tl.leiden(de, resolution=resolution) - df["de_label"] = de.obs[de_label].astype(str).values - df["de_label"] = "Response " + df["de_label"] - df = df.groupby("de_label").sum().T - df["source_label"] = adata_.obs[source_label].astype(str).values - df = df.groupby("source_label").sum() - - if normalize == "col": - df = df / df.sum(axis=0) - else: - df = (df.T / df.sum(axis=1)).T - df = df.clip(lower=min_val) - min_val - if normalize == "col": - df = df / df.sum(axis=0) - else: - df = (df.T / df.sum(axis=1)).T - - g = sns.heatmap(df, annot=True, ax=ax, **kwargs) - plt.title(title) - _utils.savefig_or_show("matching_heatmap", show=show, save=save) - if not show: - if ax is not None: - return ax - else: - return g diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index 49991439..03699ed9 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -461,5 +461,5 @@ def plot_gsea( n=n, interactive_plot=interactive_plot, ) - fig.subtitle(cluster) + fig.suptitle(cluster) fig.show()