Skip to content

Commit

Permalink
changing distance prior, option to cap distance effect to 10x of aver…
Browse files Browse the repository at this point in the history
…age distance function, heatmap with vcenter
  • Loading branch information
vitkl committed Jul 28, 2024
1 parent fd2babc commit 92ed022
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
9 changes: 8 additions & 1 deletion cell2location/models/_cellcomm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def __init__(
use_non_negative_weights: bool = False,
n_pathways: int = 20,
use_diffusion_domain: bool = False,
use_global_cell_abundance_model: bool = True,
use_global_cell_abundance_model: bool = False,
use_max_distance_threshold: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
self.use_non_negative_weights = use_non_negative_weights
self.use_diffusion_domain = use_diffusion_domain
self.use_global_cell_abundance_model = use_global_cell_abundance_model
self.use_max_distance_threshold = use_max_distance_threshold

self.weights = PyroModule()

Expand Down Expand Up @@ -292,6 +294,10 @@ def cell_comm_effect(
average_distance_prior=average_distance_prior,
)
# compute LR occupancy
if self.use_max_distance_threshold:
max_distance_threshold = average_distance_prior * 10.0
else:
max_distance_threshold = None
bound_receptor_abundance_src = module.signal_receptor_occupancy_spatial(
signal_abundance,
receptor_abundance,
Expand All @@ -301,6 +307,7 @@ def cell_comm_effect(
obs_in_use=obs_in_use,
w_sf=w_sf,
use_diffusion_domain=use_diffusion_domain,
max_distance_threshold=max_distance_threshold,
)
# compute cell abundance prediction
w_sf_mu = module.signal_receptor_tf_effect_spatial(
Expand Down
11 changes: 9 additions & 2 deletions cell2location/nn/CellCommunicationToEffectNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ def inverse_sigmoid_distance_function(
layer=layer,
name=name_,
weights_shape=weights_shape,
prior_alpha=5.0,
prior_beta=5.0 / 5.0,
prior_alpha=2.0,
prior_beta=2.0 / 1.0,
prior_fun=pyro.distributions.Gamma,
)
name_ = f"{name}DistanceGammaDistance" # strictly positive
Expand Down Expand Up @@ -1159,6 +1159,7 @@ def signal_receptor_occupancy_spatial(
obs_in_use=None,
w_sf: torch.Tensor = None,
use_diffusion_domain: bool = False,
max_distance_threshold: float = None,
):
n_locations = signal_abundance.shape[-2]
n_signals = signal_abundance.shape[-1]
Expand Down Expand Up @@ -1211,6 +1212,12 @@ def signal_receptor_occupancy_spatial(
if tiles is not None:
tiles_mask = tiles @ tiles.T
signal_distance_effect_ss_b = torch.einsum("sop,op->sop", signal_distance_effect_ss_b, tiles_mask)
if max_distance_threshold is not None:
signal_distance_effect_ss_b = torch.einsum(
"sop,op->sop",
signal_distance_effect_ss_b,
(distances < torch.tensor(max_distance_threshold, device=distances.device)).float(),
)
signal_abundance = torch.einsum(
"ps,sop->os",
signal_abundance,
Expand Down
11 changes: 9 additions & 2 deletions cell2location/plt/plot_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def heatmap(
title="",
vmin=None,
vmax=None,
vcenter=None,
):
r"""Plot heatmap with row and column labels using plt.imshow
Expand All @@ -38,9 +39,13 @@ def heatmap(

array = np.array(array)
if log:
plt.imshow(array, interpolation="nearest", cmap=cmap, norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax))
norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
else:
plt.imshow(array, interpolation="nearest", cmap=cmap)
if vcenter is None:
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
else:
norm = matplotlib.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
plt.imshow(array, interpolation="nearest", cmap=cmap, norm=norm)

if cbar is True:
plt.colorbar()
Expand Down Expand Up @@ -163,6 +168,7 @@ def clustermap(
array_size=None,
vmin=None,
vmax=None,
vcenter=None,
):
r"""Plot heatmap with hierarchically clustered rows and columns using `cell2location.plt.plot_heatmap.heatmap()`
and `cell2location.plt.plot_heatmap.dotplot()`.
Expand Down Expand Up @@ -216,6 +222,7 @@ def clustermap(
title=title,
vmin=vmin,
vmax=vmax,
vcenter=vcenter,
)
elif fun_type == "dotplot":
# plot dotplot
Expand Down

0 comments on commit 92ed022

Please sign in to comment.