High resolution model & distance function prior [draft] #337

Open
wants to merge 49 commits into
base: master

@vitkl vitkl commented Nov 26, 2023

vitkl commented Oct 18, 2024

To use total cell abundance estimates derived from histology images, the following changes are necessary (enabled with use_proportion_factorisation_prior_on_w_sf = True):

  1. Changing the parameterisation of the factorisation prior so that it produces % of total cell abundance.
  2. Forcing the model to match the provided total cell abundance estimates by using them as a prior with a very narrow distribution around the provided values (N_cells_per_location_alpha_prior=1000.0, use_n_s_cells_per_location_limit = True).
  3. Setting detection_alpha=200.0, which returns it to a narrow distribution.
  4. Changing other priors.
  5. Code modifications to support N_cells_per_location of shape=(n_obs, 1).
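
To give a sense of how narrow "very narrow" is in point 2, here is a minimal sketch. It assumes the prior is a Gamma distribution with shape `alpha` and rate `alpha / mean`; that parameterisation is an assumption made for illustration and has not been checked against this branch. Under it, the relative standard deviation depends only on `alpha` and equals `1 / sqrt(alpha)`. The helper names and the example mean of 8 cells are hypothetical.

```python
import math


def gamma_mean_sd(mean, alpha):
    """Mean and standard deviation of Gamma(shape=alpha, rate=alpha / mean).

    The parameterisation is an assumption for illustration only.
    """
    rate = alpha / mean
    return alpha / rate, math.sqrt(alpha) / rate


def relative_sd(alpha):
    """Standard deviation divided by the mean; independent of the mean."""
    return 1.0 / math.sqrt(alpha)


# Compare the two alpha values mentioned in this comment,
# using a hypothetical prior mean of 8 cells per location.
for alpha in (200.0, 1000.0):
    mean, sd = gamma_mean_sd(mean=8.0, alpha=alpha)
    print(f"alpha={alpha:g}: mean={mean:.2f}, sd={sd:.3f}, "
          f"relative sd={relative_sd(alpha):.3f}")
```

If that assumption holds, alpha=1000 keeps the prior within roughly 3% of the provided value (one standard deviation), which is why the supplied estimates act almost as fixed values.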

This branch can be installed as follows (I have not tested this particular recipe so please let me know if it doesn't work):

export PYTHONNOUSERSITE="True"
conda create -y -n c2l_v015 python=3.10
conda activate c2l_v015
pip install git+https://github.com/vitkl/scvi-tools.git@pyro_fixes
pip install "cell2location[tutorials,dev] @ git+https://github.com/BayraktarLab/cell2location.git@hires_sliding_window"
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install jupyter ipykernel
conda activate c2l_v015
python -m ipykernel install --user --name=c2l_v015 --display-name='Environment (c2l_v015)'
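
Since the recipe above is untested, a quick sanity check after installation may help. The sketch below only reports which distributions are visible in the active environment, using the standard-library `importlib.metadata`. The function name is hypothetical, and the distribution names passed to it are assumed to match what pip installs.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names are assumptions; adjust them if pip reports different ones.
report = installed_versions(["cell2location", "scvi-tools", "torch"])
for name, version in report.items():
    print(f"{name}: {version if version is not None else 'NOT INSTALLED'}")
```

Run it inside the `c2l_v015` environment; a `NOT INSTALLED` entry points at the step of the recipe that failed.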

Temporary usage instructions. These settings will eventually be replaced by mode='exact_total_cell_abundance', which switches all of these options on:

detection_alpha = 200.0
N_cells_per_location_alpha_prior = 1000.0
use_per_cell_type_normalisation = False
# ideally this is not a raw count of cells but
# (% of spot occupied by cells) * (0.9999 quantile of N cells across the data)
N_cells_per_location = adata_vis.obs[['n_cell_occupancy']].values.astype('float32')

A_B_per_location_alpha_prior = None
A_factors_per_location = 40.0
B_groups_per_location = 5.0

use_proportion_factorisation_prior_on_w_sf = True
use_n_s_cells_per_location_limit = True

import numpy as np
import scvi
import torch
import cell2location

torch.set_float32_matmul_precision('high')

seed = 0
scvi.settings.seed = seed
np.random.seed(seed)

# adata_vis, inf_aver, training and scvi_run_name are assumed to be defined earlier

# prepare anndata for scVI model
cell2location.models.Cell2location.setup_anndata(
    adata=adata_vis, batch_key="sample"
)

if training:
    import pyro
    mod = cell2location.models.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        amortised=False,
        N_cells_per_location=N_cells_per_location,  # np.array shape (n_obs, 1)
        detection_alpha=detection_alpha,
        use_per_cell_type_normalisation=use_per_cell_type_normalisation,
        N_cells_per_location_alpha_prior=N_cells_per_location_alpha_prior,
        N_cells_mean_var_ratio=None,
        detection_hyp_prior={"mean_alpha": 1.0},
        detection_cell_type_prior_alpha=100.0,
        A_B_per_location_alpha_prior=A_B_per_location_alpha_prior,
        A_factors_per_location=A_factors_per_location,
        B_groups_per_location=B_groups_per_location,
        use_proportion_factorisation_prior_on_w_sf=use_proportion_factorisation_prior_on_w_sf,
        use_n_s_cells_per_location_limit=use_n_s_cells_per_location_limit,
        n_groups=50,
    )

    mod.view_anndata_setup()

    mod.train(max_epochs=80000,
              # train using full data (batch_size=None)
              batch_size=None,
              plan_kwargs={'optim': pyro.optim.Adam(optim_args={'lr': 0.002})},
              # use all data points in training because
              # we need to estimate cell abundance at all locations
              train_size=1,
              scale_elbo=1 / (adata_vis.n_obs * adata_vis.n_vars),
              accelerator='gpu')

    # Save model
    mod.save(f"{scvi_run_name}", overwrite=True)
else:
    # can be loaded later like this:
    mod = cell2location.models.Cell2location.load(f"{scvi_run_name}", adata_vis)

Note that this N_cells_per_location code doesn't support amortised=True.
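
The comment on N_cells_per_location above suggests using the occupied fraction of each spot multiplied by the 0.9999 quantile of cell counts, rather than raw counts. A minimal sketch of that calculation follows, producing the float32 array of shape (n_obs, 1) that the model expects. The function name, the toy input values, and the idea that both inputs are available as per-location arrays (for example as columns of `adata_vis.obs`) are assumptions for illustration.

```python
import numpy as np


def occupancy_to_n_cells(occupancy_fraction, n_cells_observed, q=0.9999):
    """Scale per-location occupancy by an upper quantile of observed cell counts.

    occupancy_fraction: fraction of each spot covered by cells, in [0, 1].
    n_cells_observed:   per-location cell counts (e.g. from segmentation).
    Returns a float32 array of shape (n_obs, 1).
    """
    occupancy = np.asarray(occupancy_fraction, dtype="float64")
    upper = np.quantile(np.asarray(n_cells_observed, dtype="float64"), q)
    return (occupancy * upper).astype("float32").reshape(-1, 1)


# Toy values standing in for two hypothetical columns of adata_vis.obs.
occupancy = [0.10, 0.50, 0.90, 1.00]
counts = [2, 9, 17, 20]

N_cells_per_location = occupancy_to_n_cells(occupancy, counts)
print(N_cells_per_location.shape, N_cells_per_location.dtype)
```

The explicit `reshape(-1, 1)` plays the same role as the double brackets in `adata_vis.obs[['n_cell_occupancy']].values`, which also yield a two-dimensional (n_obs, 1) array rather than a flat vector.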
