Skip to content

Commit

Permalink
add tests and bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
vitkl committed Nov 26, 2023
1 parent a99ca5e commit 444cbd4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = cell2location
version = 0.1.4
version = 0.1.5
description = cell2location: High-throughput spatial mapping of cell types
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down
83 changes: 83 additions & 0 deletions tests/test_cell2location.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch
from pyro.infer.autoguide import AutoHierarchicalNormalMessenger
from scvi.data import synthetic_iid
Expand Down Expand Up @@ -319,3 +320,85 @@ def test_cell2location():
sample_key="batch",
)
melt_signal_target_data_frame(weighted_avg_dict, distance_bins)


@pytest.mark.parametrize("sliding_window_size", [0, 4])
@pytest.mark.parametrize("use_distance_function_prior_on_w_sf", [True, False])
@pytest.mark.parametrize("use_distance_function_effect_on_w_sf", [True, False])
@pytest.mark.parametrize("use_aggregated_w_sf", [False])
@pytest.mark.parametrize("amortised", [True, False])
@pytest.mark.parametrize("amortised_sliding_window_size", [0, 4])
def test_cell2location_with_aggregation(
sliding_window_size,
use_distance_function_prior_on_w_sf,
use_distance_function_effect_on_w_sf,
use_aggregated_w_sf,
amortised,
amortised_sliding_window_size,
):
save_path = "./cell2location_model_test"
if torch.cuda.is_available():
accelerator = "gpu"
else:
accelerator = "cpu"
dataset = synthetic_iid(n_labels=5)
dataset.obsm["X_spatial"] = np.random.normal(0, 1, [dataset.n_obs, 2])
RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch")

# train regression model to get signatures of cell types
sc_model = RegressionModel(dataset)
# test minibatch training
sc_model.train(max_epochs=1, batch_size=100, accelerator=accelerator)
# export the estimated cell abundance (summary of the posterior distribution)
dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10})
# test quantile export
export_posterior_sc(sc_model, dataset)
sc_model.plot_QC(summary_name="q05")
# export estimated expression in each cluster
if "means_per_cluster_mu_fg" in dataset.varm.keys():
inf_aver = dataset.varm["means_per_cluster_mu_fg"][
[f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]
].copy()
else:
inf_aver = dataset.var[[f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]].copy()
inf_aver.columns = dataset.uns["mod"]["factor_names"]
### test cell2location model with convolutions ###
use_distance_fun = use_distance_function_prior_on_w_sf or use_distance_function_effect_on_w_sf
Cell2location.setup_anndata(
dataset,
batch_key="batch",
position_key=None if not use_distance_fun else "X_spatial",
)
## full data ##
st_model = Cell2location(
dataset,
cell_state_df=inf_aver,
N_cells_per_location=30,
detection_alpha=200,
average_distance_prior=5.0,
sliding_window_size=sliding_window_size,
image_size=[20, 20],
use_distance_function_prior_on_w_sf=use_distance_function_prior_on_w_sf,
use_distance_function_effect_on_w_sf=use_distance_function_effect_on_w_sf,
use_aggregated_w_sf=use_aggregated_w_sf,
amortised=amortised,
encoder_mode="multiple",
amortised_sliding_window_size=amortised_sliding_window_size,
)
shuffle = False if (sliding_window_size > 0) or (amortised_sliding_window_size > 0) else True
# test full data training
st_model.train(
max_epochs=1,
accelerator=accelerator,
shuffle_set_split=shuffle,
# datasplitter_kwargs={"shuffle": shuffle, "shuffle_set_split": shuffle},
)
# test save/load
st_model.save(save_path, overwrite=True, save_anndata=True)
st_model = Cell2location.load(save_path)
# export the estimated cell abundance (summary of the posterior distribution)
# full data
if not use_distance_fun:
dataset = st_model.export_posterior(
dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}
)

0 comments on commit 444cbd4

Please sign in to comment.