Skip to content

Commit

Permalink
diffusion domain function
Browse files Browse the repository at this point in the history
  • Loading branch information
vitkl committed Jul 25, 2024
1 parent fc60db2 commit fd2babc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
17 changes: 15 additions & 2 deletions cell2location/models/_cellcomm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class CellCommModule(PyroModule):
dropout_rate = 0.0
min_distance = 25.0
r_l_affinity_alpha_prior = 10.0
use_global_cell_abundance_model = False
record_sr_occupancy = False
use_spatial_receptor_info_remove_sp_signal = True

Expand Down Expand Up @@ -64,6 +63,8 @@ def __init__(
fixed_w_sf_mean_var_ratio: Optional[float] = None,
use_non_negative_weights: bool = False,
n_pathways: int = 20,
use_diffusion_domain: bool = False,
use_global_cell_abundance_model: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -114,6 +115,8 @@ def __init__(
self.use_normal_likelihood = use_normal_likelihood
self.fixed_w_sf_mean_var_ratio = fixed_w_sf_mean_var_ratio
self.use_non_negative_weights = use_non_negative_weights
self.use_diffusion_domain = use_diffusion_domain
self.use_global_cell_abundance_model = use_global_cell_abundance_model

self.weights = PyroModule()

Expand Down Expand Up @@ -277,6 +280,9 @@ def cell_comm_effect(
tiles,
obs_plate,
average_distance_prior=50.0,
obs_in_use=None,
w_sf=None,
use_diffusion_domain=False,
):
# get module
module = self.get_cell_communication_module(
Expand All @@ -292,6 +298,9 @@ def cell_comm_effect(
distances,
tiles,
obs_plate,
obs_in_use=obs_in_use,
w_sf=w_sf,
use_diffusion_domain=use_diffusion_domain,
)
# compute cell abundance prediction
w_sf_mu = module.signal_receptor_tf_effect_spatial(
Expand Down Expand Up @@ -321,6 +330,7 @@ def forward(
positions=positions,
in_tissue=in_tissue,
)
obs_in_use = None
if tiles_unexpanded is not None:
tiles_in_use = (tiles.mean(0) > torch.tensor(0.99, device=tiles.device)).bool()
obs_in_use = (tiles_unexpanded[:, tiles_in_use].sum(1) > torch.tensor(0.0, device=tiles.device)).bool()
Expand Down Expand Up @@ -373,6 +383,9 @@ def forward(
tiles=tiles,
average_distance_prior=self.average_distance_prior,
obs_plate=obs_plate,
obs_in_use=obs_in_use,
w_sf=w_sf,
use_diffusion_domain=self.use_diffusion_domain,
)
if not self.training and self.record_sr_occupancy:
with obs_plate:
Expand All @@ -382,7 +395,7 @@ def forward(
"r (c f) -> f c r",
f=self.receptor_abundance.shape[-1],
).sum(-3)
if tiles_unexpanded is not None:
if obs_in_use is not None:
bound_receptor_abundance_src = bound_receptor_abundance_src[obs_in_use, :]
pyro.deterministic(
"bound_receptor_abundance_sr_c",
Expand Down
62 changes: 62 additions & 0 deletions cell2location/nn/CellCommunicationToEffectNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CellCommunicationToTfActivityNN(

promoter_distance_prior = 50
use_footprinting_masked_tn5_index = -1
n_spatial_domains = 5

use_cached_effects = False
cached_effects = dict()
Expand Down Expand Up @@ -1109,13 +1110,55 @@ def signal_receptor_occupancy(

return bound_receptor_abundance_src

def diffusion_domain_function(
self,
signal_abundance: torch.Tensor,
w_sf: torch.Tensor,
):
# Low dimensional diffusion limiter - for every signal limit where it can diffuse.
# Maybe this can lead to more reasonable distributions
# without requiring suppressing effects to get rid of the signal.
n_signals = signal_abundance.shape[-1]
n_cell_types = w_sf.shape[-1]

name = "diffusion_domain_function"
# x_cs = sum_q y_cq * y_qs
# y_qs ~ Beta(100, 1)
# y_cq = y_cq / sum_q y_cq
# y_cq = sum_f w_cf * y_fq # maybe w_cf is lvl3 but better lvl5
# y_fq ~ Gamma(1, 1)
y_fq = self.get_dist_prior(
layer="",
name=f"{name}_y_fq",
weights_shape=[n_cell_types, self.n_spatial_domains],
prior_alpha=1.0,
prior_beta=1.0,
prior_fun=pyro.distributions.Gamma,
)
y_cq = torch.einsum("cf,fq->cq", w_sf, y_fq)
y_cq = y_cq / y_cq.sum(dim=-1, keepdim=True)
y_qs = self.get_dist_prior(
layer="",
name=f"{name}_y_qs",
weights_shape=[self.n_spatial_domains, n_signals],
prior_alpha=100.0,
prior_beta=1.0,
prior_fun=pyro.distributions.Beta,
)
x_cs = torch.einsum("cq,qs->cs", y_cq, y_qs)
signal_abundance = signal_abundance * x_cs
return signal_abundance

def signal_receptor_occupancy_spatial(
self,
signal_abundance: torch.Tensor,
receptor_abundance: torch.Tensor,
distances: torch.Tensor = None,
tiles: torch.Tensor = None,
obs_plate=None,
obs_in_use=None,
w_sf: torch.Tensor = None,
use_diffusion_domain: bool = False,
):
n_locations = signal_abundance.shape[-2]
n_signals = signal_abundance.shape[-1]
Expand Down Expand Up @@ -1174,6 +1217,25 @@ def signal_receptor_occupancy_spatial(
signal_distance_effect_ss_b,
)

if use_diffusion_domain:
signal_abundance = self.diffusion_domain_function(
signal_abundance=signal_abundance,
w_sf=w_sf,
)

if not self.training:
with obs_plate:
if obs_in_use is not None:
pyro.deterministic(
"signal_abundance_local",
signal_abundance[obs_in_use, :],
)
else:
pyro.deterministic(
"signal_abundance_local",
signal_abundance,
)

# 2. Computing bound receptor concentrations using learnable a_{r,s} affinity ============
# first reshape inputs to be locations * cell type specific
# d_{c,s} -> d_{c,f,s}
Expand Down

0 comments on commit fd2babc

Please sign in to comment.