diff --git a/setup.cfg b/setup.cfg index de5b72d4..5d7d6c19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = cell2location -version = 0.1.4 +version = 0.1.5 description = cell2location: High-throughput spatial mapping of cell types long_description = file: README.md long_description_content_type = text/markdown diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index f762c3f7..bd2c9a1c 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from pyro.infer.autoguide import AutoHierarchicalNormalMessenger from scvi.data import synthetic_iid @@ -319,3 +320,85 @@ def test_cell2location(): sample_key="batch", ) melt_signal_target_data_frame(weighted_avg_dict, distance_bins) + + +@pytest.mark.parametrize("sliding_window_size", [0, 4]) +@pytest.mark.parametrize("use_distance_function_prior_on_w_sf", [True, False]) +@pytest.mark.parametrize("use_distance_function_effect_on_w_sf", [True, False]) +@pytest.mark.parametrize("use_aggregated_w_sf", [False]) +@pytest.mark.parametrize("amortised", [True, False]) +@pytest.mark.parametrize("amortised_sliding_window_size", [0, 4]) +def test_cell2location_with_aggregation( + sliding_window_size, + use_distance_function_prior_on_w_sf, + use_distance_function_effect_on_w_sf, + use_aggregated_w_sf, + amortised, + amortised_sliding_window_size, +): + save_path = "./cell2location_model_test" + if torch.cuda.is_available(): + accelerator = "gpu" + else: + accelerator = "cpu" + dataset = synthetic_iid(n_labels=5) + dataset.obsm["X_spatial"] = np.random.normal(0, 1, [dataset.n_obs, 2]) + RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch") + + # train regression model to get signatures of cell types + sc_model = RegressionModel(dataset) + # test minibatch training + sc_model.train(max_epochs=1, batch_size=100, accelerator=accelerator) + # export the estimated cell abundance (summary of the posterior distribution) + dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10}) + # test quantile export + export_posterior_sc(sc_model, dataset) + sc_model.plot_QC(summary_name="q05") + # export estimated expression in each cluster + if "means_per_cluster_mu_fg" in dataset.varm.keys(): + inf_aver = dataset.varm["means_per_cluster_mu_fg"][ + [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]] + ].copy() + else: + inf_aver = dataset.var[[f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]].copy() + inf_aver.columns = dataset.uns["mod"]["factor_names"] + ### test cell2location model with convolutions ### + use_distance_fun = use_distance_function_prior_on_w_sf or use_distance_function_effect_on_w_sf + Cell2location.setup_anndata( + dataset, + batch_key="batch", + position_key=None if not use_distance_fun else "X_spatial", + ) + ## full data ## + st_model = Cell2location( + dataset, + cell_state_df=inf_aver, + N_cells_per_location=30, + detection_alpha=200, + average_distance_prior=5.0, + sliding_window_size=sliding_window_size, + image_size=[20, 20], + use_distance_function_prior_on_w_sf=use_distance_function_prior_on_w_sf, + use_distance_function_effect_on_w_sf=use_distance_function_effect_on_w_sf, + use_aggregated_w_sf=use_aggregated_w_sf, + amortised=amortised, + encoder_mode="multiple", + amortised_sliding_window_size=amortised_sliding_window_size, + ) + shuffle = False if (sliding_window_size > 0) or (amortised_sliding_window_size > 0) else True + # test full data training + st_model.train( + max_epochs=1, + accelerator=accelerator, + shuffle_set_split=shuffle, + # datasplitter_kwargs={"shuffle": shuffle, "shuffle_set_split": shuffle}, + ) + # test save/load + st_model.save(save_path, overwrite=True, save_anndata=True) + st_model = Cell2location.load(save_path) + # export the estimated cell abundance (summary of the posterior distribution) + # full data + if not use_distance_fun: + dataset = st_model.export_posterior( + dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs} + )