High resolution model & distance function prior [draft] #337
base: master
…d distance prior, Gamma affinity prior, cell-type-independent global effects model (one receptor one effect not cell-type-specific receptor effect)) + updates to Visium HD model (normalisation)
…, per cell type normalisation, minor changes
…hoe prior, more likelihood options, zero diag & upper tri pathways, sqrt normalisation
…age distance function, heatmap with vcenter
To enable using total cell abundance estimates from histology images the following changes are necessary (
This branch can be installed as follows (I have not tested this particular recipe, so please let me know if it doesn't work):
export PYTHONNOUSERSITE="True"
conda create -y -n c2l_v015 python=3.10
conda activate c2l_v015
pip install git+https://github.com/vitkl/scvi-tools.git@pyro_fixes
pip install "cell2location[tutorials,dev] @ git+https://github.com/BayraktarLab/cell2location.git@hires_sliding_window"
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install jupyter ipykernel
conda activate c2l_v015
python -m ipykernel install --user --name=c2l_v015 --display-name='Environment (c2l_v015)'
Temporary usage instructions. This will become
detection_alpha = 200.0
N_cells_per_location_alpha_prior = 1000.0
use_per_cell_type_normalisation = False
# ideally this is not a raw count of cells per location,
# but (% of the spot occupied by cells) * (0.9999 quantile of N cells across the data)
N_cells_per_location = adata_vis.obs[['n_cell_occupancy']].values.astype('float32')
A_B_per_location_alpha_prior = None
A_factors_per_location = 40.0
B_groups_per_location = 5.0
use_proportion_factorisation_prior_on_w_sf = True
use_n_s_cells_per_location_limit = True
import numpy as np
import scvi
import cell2location
import torch

torch.set_float32_matmul_precision('high')
seed = 0
scvi.settings.seed = seed
np.random.seed(seed)
# prepare anndata for scVI model
cell2location.models.Cell2location.setup_anndata(
    adata=adata_vis, batch_key="sample"
)
if training:
    import pyro

    mod = cell2location.models.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        amortised=False,
        N_cells_per_location=N_cells_per_location,  # np.array shape (n_obs, 1)
        detection_alpha=detection_alpha,
        use_per_cell_type_normalisation=use_per_cell_type_normalisation,
        N_cells_per_location_alpha_prior=N_cells_per_location_alpha_prior,
        N_cells_mean_var_ratio=None,
        detection_hyp_prior={"mean_alpha": float(1.0)},
        detection_cell_type_prior_alpha=float(100.0),
        A_B_per_location_alpha_prior=A_B_per_location_alpha_prior,
        A_factors_per_location=A_factors_per_location,
        B_groups_per_location=B_groups_per_location,
        use_proportion_factorisation_prior_on_w_sf=use_proportion_factorisation_prior_on_w_sf,
        use_n_s_cells_per_location_limit=use_n_s_cells_per_location_limit,
        n_groups=50,
    )
    mod.view_anndata_setup()
    mod.train(
        max_epochs=80000,
        # train using full data (batch_size=None)
        batch_size=None,
        plan_kwargs={'optim': pyro.optim.Adam(optim_args={'lr': 0.002})},
        # use all data points in training because
        # we need to estimate cell abundance at all locations
        train_size=1,
        scale_elbo=1 / (adata_vis.n_obs * adata_vis.n_vars),
        accelerator='gpu',
    )
    # Save model
    mod.save(f"{scvi_run_name}", overwrite=True)
else:
    # can be loaded later like this:
    mod = cell2location.models.Cell2location.load(f"{scvi_run_name}", adata_vis)
Note that this |
Depends on scverse/scvi-tools#2695
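The comment next to `N_cells_per_location` in the usage snippet suggests passing the occupied fraction of each spot multiplied by the 0.9999 quantile of the cell count across the data, rather than a raw cell count. A minimal sketch of that calculation, assuming you already have a per-location occupancy fraction (0 to 1) and some per-location cell count estimate (for example from segmentation). The function name `n_cells_from_occupancy` and both input names are hypothetical and not part of this branch:

```python
import numpy as np


def n_cells_from_occupancy(occupancy_fraction, n_cells_observed, q=0.9999):
    """Scale per-location occupancy by an upper quantile of the cell count.

    occupancy_fraction: fraction of each location occupied by cells (0..1),
        hypothetical input, e.g. derived from a histology image.
    n_cells_observed: per-location cell count estimates used only to find
        the upper quantile (hypothetical input).
    Returns a float32 array of shape (n_obs, 1), the shape noted for
    N_cells_per_location in the usage snippet.
    """
    occupancy = np.clip(np.asarray(occupancy_fraction, dtype="float64"), 0.0, 1.0)
    counts = np.asarray(n_cells_observed, dtype="float64")
    # Upper bound on cells per location: the q-th quantile across all locations
    upper = np.quantile(counts, q)
    return (occupancy * upper).reshape(-1, 1).astype("float32")


# Toy example with made-up numbers (not real data)
toy = n_cells_from_occupancy([0.0, 0.5, 1.0], [1, 2, 3, 4, 5], q=1.0)
```

The result could then be stored in a column such as `adata_vis.obs['n_cell_occupancy']`, which the usage snippet reads. Whether the model expects exactly this scaling is an assumption based on the code comment alone.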