Skip to content

Commit

Permalink
Fix plots
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Apr 14, 2024
1 parent 05a6a5c commit 3ced4ff
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 80 deletions.
158 changes: 79 additions & 79 deletions pertpy/tools/_cinemaot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -639,6 +639,84 @@ def attribution_scatter(
s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
return c_effect, s_effect

def plot_vis_matching(
self,
adata: AnnData,
de: AnnData,
pert_key: str,
control: str,
de_label: str,
source_label: str,
matching_rep: str = "ot",
resolution: float = 0.5,
normalize: str = "col",
title: str = "CINEMA-OT matching matrix",
min_val: float = 0.01,
show: bool = True,
save: str | None = None,
ax: Axes | None = None,
**kwargs,
) -> None:
"""Visualize the CINEMA-OT matching matrix.
Args:
adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted.
de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted().
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
control: Control category from the `pert_key` column.
de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
source_label: the confounder / cell type label.
matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
normalize: normalize the coarse-grained matching matrix by row / column.
title: the title for the figure.
min_val: The min value to truncate the matching matrix.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
**kwargs: Other parameters to input for seaborn.heatmap.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.cinemaot_example()
>>> cot = pt.tl.Cinemaot()
>>> de = cot.causaleffect(
>>> adata, pert_key="perturbation", control="No stimulation", return_matching=True,
>>> thres=0.5, smoothness=1e-5, eps=1e-3, solver="Sinkhorn", preweight_label="cell_type0528")
>>> cot.plot_vis_matching(
>>> adata, de, pert_key="perturbation",control="No stimulation", de_label=None, source_label="cell_type0528")
"""
adata_ = adata[adata.obs[pert_key] == control]

df = pd.DataFrame(de.obsm[matching_rep])
if de_label is None:
de_label = "leiden"
sc.pp.neighbors(de, use_rep="X_embedding")
sc.tl.leiden(de, resolution=resolution)
df["de_label"] = de.obs[de_label].astype(str).values
df["de_label"] = "Response " + df["de_label"]
df = df.groupby("de_label").sum().T
df["source_label"] = adata_.obs[source_label].astype(str).values
df = df.groupby("source_label").sum()

if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T
df = df.clip(lower=min_val) - min_val
if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T

g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
plt.title(title)
_utils.savefig_or_show("matching_heatmap", show=show, save=save)
if not show:
if ax is not None:
return ax
else:
return g


class Xi:
"""
Expand Down Expand Up @@ -859,81 +937,3 @@ def fit(self, P: ArrayLike):
P_eps = np.diag(self._D1)[:, None] * P * np.diag(self._D2)[None, :]

return P_eps

def plot_vis_matching(
self,
adata: AnnData,
de: AnnData,
pert_key: str,
control: str,
de_label: str,
source_label: str,
matching_rep: str = "ot",
resolution: float = 0.5,
normalize: str = "col",
title: str = "CINEMA-OT matching matrix",
min_val: float = 0.01,
show: bool = True,
save: str | None = None,
ax: Axes | None = None,
**kwargs,
) -> None:
"""Visualize the CINEMA-OT matching matrix.
Args:
adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted.
de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted().
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
control: Control category from the `pert_key` column.
de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
source_label: the confounder / cell type label.
matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
normalize: normalize the coarse-grained matching matrix by row / column.
title: the title for the figure.
min_val: The min value to truncate the matching matrix.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
**kwargs: Other parameters to input for seaborn.heatmap.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.cinemaot_example()
>>> cot = pt.tl.Cinemaot()
>>> de = cot.causaleffect(
>>> adata, pert_key="perturbation", control="No stimulation", return_matching=True,
>>> thres=0.5, smoothness=1e-5, eps=1e-3, solver="Sinkhorn", preweight_label="cell_type0528")
>>> cot.plot_vis_matching(
>>> adata, de, pert_key="perturbation",control="No stimulation", de_label=None, source_label="cell_type0528")
"""
adata_ = adata[adata.obs[pert_key] == control]

df = pd.DataFrame(de.obsm[matching_rep])
if de_label is None:
de_label = "leiden"
sc.pp.neighbors(de, use_rep="X_embedding")
sc.tl.leiden(de, resolution=resolution)
df["de_label"] = de.obs[de_label].astype(str).values
df["de_label"] = "Response " + df["de_label"]
df = df.groupby("de_label").sum().T
df["source_label"] = adata_.obs[source_label].astype(str).values
df = df.groupby("source_label").sum()

if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T
df = df.clip(lower=min_val) - min_val
if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T

g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
plt.title(title)
_utils.savefig_or_show("matching_heatmap", show=show, save=save)
if not show:
if ax is not None:
return ax
else:
return g
2 changes: 1 addition & 1 deletion pertpy/tools/_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,5 +461,5 @@ def plot_gsea(
n=n,
interactive_plot=interactive_plot,
)
fig.subtitle(cluster)
fig.suptitle(cluster)
fig.show()

0 comments on commit 3ced4ff

Please sign in to comment.