Skip to content

Commit

Permalink
working conv2d pooling + distance function effect
Browse files Browse the repository at this point in the history
  • Loading branch information
vitkl committed Nov 26, 2023
1 parent adba6c2 commit a99ca5e
Showing 1 changed file with 74 additions and 67 deletions.
141 changes: 74 additions & 67 deletions cell2location/models/_cell2location_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,13 @@ def __init__(
init_vals: Optional[dict] = None,
init_alpha: float = 20.0,
dropout_p: float = 0.0,
use_factorisation_prior_on_w_sf: bool = True,
use_distance_function_prior_on_w_sf: bool = False,
use_distance_function_effect_on_w_sf: bool = False,
average_distance_prior: float = 50.0,
sliding_window_size: Optional[int] = 0,
amortised_sliding_window_size: Optional[int] = 0,
image_size: Optional[tuple] = None,
use_aggregated_w_sf: bool = False,
):
super().__init__()

Expand All @@ -130,12 +131,13 @@ def __init__(
if self.dropout_p is not None:
self.dropout = torch.nn.Dropout(p=self.dropout_p)

self.use_factorisation_prior_on_w_sf = use_factorisation_prior_on_w_sf
self.use_distance_function_prior_on_w_sf = use_distance_function_prior_on_w_sf
self.use_distance_function_effect_on_w_sf = use_distance_function_effect_on_w_sf
self.average_distance_prior = average_distance_prior
self.sliding_window_size = sliding_window_size
self.amortised_sliding_window_size = amortised_sliding_window_size
self.image_size = image_size
self.use_aggregated_w_sf = use_aggregated_w_sf

self.weights = PyroModule()

Expand Down Expand Up @@ -174,6 +176,7 @@ def __init__(
self.register_buffer("cell_state", torch.tensor(cell_state_mat.T))

self.register_buffer("N_cells_per_location", torch.tensor(N_cells_per_location))
self.register_buffer("A_factors_per_location", torch.tensor(A_factors_per_location))
self.register_buffer("factors_per_groups", torch.tensor(factors_per_groups))
self.register_buffer("B_groups_per_location", torch.tensor(B_groups_per_location))
self.register_buffer("N_cells_mean_var_ratio", torch.tensor(N_cells_mean_var_ratio))
Expand Down Expand Up @@ -209,7 +212,9 @@ def __init__(
self.register_buffer("n_groups_tensor", torch.tensor(self.n_groups))

self.register_buffer("ones", torch.ones((1, 1)))
self.register_buffer("zeros", torch.zeros((1, 1)))
self.register_buffer("ones_1_n_groups", torch.ones((1, self.n_groups)))
self.register_buffer("ones_1_n_factors", torch.ones((1, self.n_factors)))
self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1)))
self.register_buffer("eps", torch.tensor(1e-8))

Expand All @@ -226,6 +231,28 @@ def _get_fn_args_from_batch(tensor_dict):
def create_plates(self, x_data, idx, batch_index, positions: torch.Tensor = None):
return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx)

def conv2d_aggregate(self, x_data):
x_data_agg = self.aggregate_conv2d(
x_data,
size=max(self.amortised_sliding_window_size, self.sliding_window_size),
padding="same",
)
x_data = torch.cat([x_data, x_data_agg], dim=-1)
return torch.log1p(x_data)

def learnable_conv2d(self, x_data):
x_data = torch.log1p(x_data)
x_data_agg = self.learnable_neighbour_effect_conv2d_nn(
x_data,
name="amortised_sliding_window",
size=max(self.amortised_sliding_window_size, self.sliding_window_size),
n_out=self.n_hidden,
padding="same",
)
# x_data = self.aggregate_conv2d(x_data, padding="same")
x_data = torch.cat([x_data, x_data_agg], dim=-1)
return x_data

def list_obs_plate_vars(self):
"""
Create a dictionary with:
Expand All @@ -239,39 +266,17 @@ def list_obs_plate_vars(self):
* values - the dimensions in non-plate axis of each variable (used to construct output
layer of encoder network when using amortised inference)
"""

def learnable_conv2d(x_data):
x_data_agg = self.learnable_neighbour_effect_conv2d_nn(
x_data,
name="amortised_sliding_window",
size=self.amortised_sliding_window_size,
n_out=self.n_hidden,
padding="same",
)
# x_data = self.aggregate_conv2d(x_data, padding="same")
x_data = torch.cat([x_data, x_data_agg], dim=-1)
return x_data

def conv2d_aggregate(x_data):
x_data_agg = self.aggregate_conv2d(
x_data,
size=self.amortised_sliding_window_size,
padding="same",
)
x_data = torch.cat([x_data, x_data_agg], dim=-1)
return x_data

input_transform = torch.log1p
n_in = self.n_vars

if (self.amortised_sliding_window_size > 0) and (self.sliding_window_size == 0):
input_transform = learnable_conv2d
input_transform = self.learnable_conv2d
n_in = self.n_vars + self.n_hidden
elif (self.amortised_sliding_window_size == 0) and (self.sliding_window_size > 0):
input_transform = conv2d_aggregate
input_transform = self.conv2d_aggregate
n_in = self.n_vars * 2
elif (self.amortised_sliding_window_size > 0) and (self.sliding_window_size > 0):
input_transform = learnable_conv2d
input_transform = self.learnable_conv2d
n_in = self.n_vars + self.n_hidden

return {
Expand All @@ -297,9 +302,12 @@ def conv2d_aggregate(x_data):

def reshape_input_2d(self, x):
# conv2d expects 4d input: [batch, channels, height, width]
size = int(np.sqrt(x.shape[-2]))
if self.image_size is None:
sizex = sizey = int(np.sqrt(x.shape[-2]))
else:
sizex, sizey = self.image_size
# here batch dim has just one element
return rearrange(x, "(p o) g -> g p o", p=size, o=size).unsqueeze(-4)
return rearrange(x, "(p o) g -> g p o", p=sizex, o=sizey).unsqueeze(-4)

def reshape_input_2d_inverse(self, x):
# conv2d expects 4d input: [batch, channels, height, width]
Expand All @@ -311,10 +319,10 @@ def crop_according_to_valid_padding(self, x):
# reshape to 2d
x = self.reshape_input_2d(x)
# crop to valid observations
x = x[
self.sliding_window_size // 2 : -self.sliding_window_size // 2,
self.sliding_window_size // 2 : -self.sliding_window_size // 2,
]
indx = np.arange(self.sliding_window_size // 2, x.shape[-2] - (self.sliding_window_size // 2))
indy = np.arange(self.sliding_window_size // 2, x.shape[-1] - (self.sliding_window_size // 2))
x = np.take(x, indx, axis=-2)
x = np.take(x, indy, axis=-1)
# reshape back to 1d
x = self.reshape_input_2d_inverse(x)
return x
Expand Down Expand Up @@ -426,6 +434,7 @@ def distance_function_neighbour_effect(
):
# distances [observations, observations]
distances = distances.view(*[distances.shape[0], distances.shape[1], 1, 1])
distances = distances + self.eps
# pyro version
param_shape = [1, 1, self.n_factors, self.n_factors]
# sigmoid function ============
Expand All @@ -442,13 +451,14 @@ def distance_function_neighbour_effect(
) # [self.n_factors, self.n_factors]
sigmoid_distance_function = self.inverse_sigmoid_lm(distances, sigmoid_weight, sigmoid_bias, scaling=None)
# gamma function ============
prior = torch.tensor(5.0, device=distances.device)
prior = torch.tensor(1.0, device=distances.device)
gamma_concentration = pyro.sample(
f"{name}_gamma_concentration",
dist.Gamma(prior, prior).expand(param_shape).to_event(len(param_shape)),
) # [self.n_factors, self.n_factors]
prior = torch.tensor(3.0, device=distances.device)
gamma_distance = pyro.sample(
f"{name}_gamma_concentration",
f"{name}_gamma_distance",
dist.Gamma(prior, prior).expand(param_shape).to_event(len(param_shape)),
) # [self.n_factors, self.n_factors]
gamma_distance = gamma_distance / torch.tensor(average_distance_prior, device=distances.device)
Expand All @@ -469,20 +479,20 @@ def distance_function_neighbour_effect(
) # [self.n_factors, self.n_factors]
x = torch.einsum( # sigmoid function
"hm,pohm,om->ph",
sigmoid_effect,
sigmoid_effect / torch.tensor(np.sqrt(self.n_factors), device=distances.device),
sigmoid_distance_function,
x_cm,
) + torch.einsum( # gamma function
"hm,pohm,om->ph",
gamma_effect,
gamma_effect / torch.tensor(np.sqrt(self.n_factors), device=distances.device),
gamma_distance_function,
x_cm,
)
# scale independent input abundances by the output of the distance function
x = x_cm * (
torch.nn.functional.softplus(x / torch.tensor(10.0, device=distances.device))
/ torch.tensor(0.7, device=distances.device)
x = torch.nn.functional.softplus(x / torch.tensor(100.0, device=distances.device)) / torch.tensor(
0.7, device=distances.device
) # average effect of 1
x = x_cm * x
return x

def factorisation_prior_on_w_sf(self, obs_plate):
Expand Down Expand Up @@ -558,11 +568,12 @@ def independent_prior_on_w_sf(self, obs_plate):
return w_sf

def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
if self.sliding_window_size > 0:
# remove observations that will not be included after convolution with padding='valid'
idx = self.crop_according_to_valid_padding(idx)
batch_index = self.crop_according_to_valid_padding(batch_index)
positions = self.crop_according_to_valid_padding(positions)
# if self.sliding_window_size > 0:
# # remove observations that will not be included after convolution with padding='valid'
# idx = self.crop_according_to_valid_padding(idx.unsqueeze(-1)).squeeze(-1)
# batch_index = self.crop_according_to_valid_padding(batch_index)
# if positions is not None:
# positions = self.crop_according_to_valid_padding(positions)
obs2sample = one_hot(batch_index, self.n_batch)
obs_plate = self.create_plates(x_data, idx, batch_index, positions)

Expand Down Expand Up @@ -590,9 +601,7 @@ def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
) # (1, n_vars)

# =====================Cell abundances w_sf======================= #
if self.use_factorisation_prior_on_w_sf and not (
self.use_distance_function_prior_on_w_sf or self.use_distance_function_effect_on_w_sf
):
if not (self.use_distance_function_prior_on_w_sf or self.use_distance_function_effect_on_w_sf):
w_sf_mu = self.factorisation_prior_on_w_sf(obs_plate)
with obs_plate:
k = "w_sf"
Expand All @@ -605,14 +614,13 @@ def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
) # (self.n_obs, self.n_factors)
elif self.use_distance_function_prior_on_w_sf:
w_sf_mu = self.independent_prior_on_w_sf(obs_plate)
if positions is not None:
# compute distance using positions [observations, 2]
distances = (
(positions.unsqueeze(1) - positions.unsqueeze(0)) # [observations, 1, 2] # [1, observations, 2]
.pow(2)
.sum(-1)
.sqrt()
)
# compute distance using positions [observations, 2]
distances = (
(positions.unsqueeze(1) - positions.unsqueeze(0)) # [observations, 1, 2] # [1, observations, 2]
.pow(2)
.sum(-1)
.sqrt()
)
w_sf_mu = self.distance_function_neighbour_effect(
x_cm=w_sf_mu,
distances=distances,
Expand All @@ -630,14 +638,13 @@ def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
) # (self.n_obs, self.n_factors)
elif self.use_distance_function_effect_on_w_sf:
w_sf_mu = self.independent_prior_on_w_sf(obs_plate)
if positions is not None:
# compute distance using positions [observations, 2]
distances = (
(positions.unsqueeze(1) - positions.unsqueeze(0)) # [observations, 1, 2] # [1, observations, 2]
.pow(2)
.sum(-1)
.sqrt()
)
# compute distance using positions [observations, 2]
distances = (
(positions.unsqueeze(1) - positions.unsqueeze(0)) # [observations, 1, 2] # [1, observations, 2]
.pow(2)
.sum(-1)
.sqrt()
)
w_sf_mu = self.distance_function_neighbour_effect(
x_cm=w_sf_mu,
distances=distances,
Expand All @@ -648,8 +655,8 @@ def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
k = "w_sf"
w_sf = pyro.deterministic(k, w_sf_mu) # (self.n_obs, self.n_factors)

if self.sliding_window_size > 0:
w_sf = self.aggregate_conv2d(w_sf, padding="valid")
if (self.sliding_window_size > 0) and self.use_aggregated_w_sf:
w_sf = self.aggregate_conv2d(w_sf, padding="same")
pyro.deterministic("aggregated_w_sf", w_sf)

# =====================Location-specific detection efficiency ======================= #
Expand Down Expand Up @@ -730,7 +737,7 @@ def forward(self, x_data, idx, batch_index, positions: torch.Tensor = None):
if self.dropout_p != 0:
x_data = self.dropout(x_data)
if self.sliding_window_size > 0:
x_data = self.aggregate_conv2d(x_data)
x_data = self.aggregate_conv2d(x_data, padding="same")
with obs_plate:
pyro.sample(
"data_target",
Expand Down

0 comments on commit a99ca5e

Please sign in to comment.