From 985d8155391fdfbddec024de428308b5a57ee280 Mon Sep 17 00:00:00 2001
From: Michaela Mueller <51025211+mumichae@users.noreply.github.com>
Date: Sun, 24 Oct 2021 18:44:40 +0200
Subject: [PATCH] Packaging (#272)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* renamed scIB to scib
* fixed imports
* Fixed name in setup.py
* fix name in CI
* refactored package options to setup.cfg
* separate command for pip install scib
* fix finding packages
* moved pytest.ini content to pyproject.toml
* removed old module notes
* removed redundant requirements files
* rename functions to snake case
* wrap integration functions with old functions names
* Better deprecation warning
* rename read_conos and read_scanorama
* moved test dependencies to setup.cfg
* Update README
* Rename batch in README usage
* add per batch trajectory score (#273)
* add per batch trajectory score
* Update trajectory.py add missing comma
* Update trajectory.py import pandas
* don't recompute trajectories per batch
* correct batch var handling
* Update trajectory.py
* Check batch key
* add tests for trajectory score
* update test values
* update test values
Co-authored-by: Strobl
Co-authored-by: Michaela Mueller
* rename trajectory batch function usage
* integrated code review
* add kwargs to integration methods
* Use tempfile for model paths
* rename packaging tools file
* minor code review changes
* use tempfile for conos saving
* Throw error when batches mismatch in trajectory conservation metric
* restructured utils functions in metrics module
* Revert "Throw error when batches mismatch in trajectory conservation metric"
This reverts commit fe0200cf8e7f5578a752c8170d23a65b7c821215.
* Throw error when batches mismatch in trajectory conservation metric
* Revert "restructured utils functions in metrics module"
This reverts commit c30a501a84ffeb529341dba03b3e3c6b052f7288.
* fix batch check in TI conservation * update import order * setup bumpversion * Bump version: 0.2.0 → 1.0.0 * add MANIFEST.in Co-authored-by: Daniel Strobl <50872326+danielStrobl@users.noreply.github.com> Co-authored-by: Strobl --- .github/workflows/python-package.yml | 6 +- .gitignore | 1 + MANIFEST.in | 1 + README.md | 87 +- VERSION.txt | 1 + pyproject.toml | 11 + pytest.ini | 4 - requirements.txt | 20 - requirements_extra.txt | 11 - scIB/__init__.py | 9 - scib/__init__.py | 35 + scib/_package_tools.py | 29 + {scIB => scib}/exceptions.py | 0 {scIB => scib}/integration.py | 124 ++- {scIB => scib}/knn_graph/README.md | 0 {scIB => scib}/knn_graph/knn_graph.cpp | 864 +++++++++--------- {scIB => scib}/knn_graph/knn_graph.o | Bin {scIB => scib}/knn_graph/makefile | 0 {scIB => scib}/metrics/__init__.py | 0 {scIB => scib}/metrics/ari.py | 8 +- {scIB => scib}/metrics/cell_cycle.py | 8 +- {scIB => scib/metrics}/clustering.py | 13 +- {scIB => scib}/metrics/graph_connectivity.py | 0 .../metrics/highly_variable_genes.py | 6 +- {scIB => scib}/metrics/isolated_labels.py | 2 +- {scIB => scib}/metrics/kbet.py | 19 +- {scIB => scib}/metrics/lisi.py | 35 +- {scIB => scib}/metrics/metrics.py | 27 +- {scIB => scib}/metrics/nmi.py | 10 +- {scIB => scib}/metrics/pcr.py | 8 +- {scIB => scib}/metrics/silhouette.py | 2 +- {scIB => scib}/metrics/trajectory.py | 22 +- {scIB => scib}/metrics/utils.py | 3 +- {scIB => scib}/preprocessing.py | 328 +++---- {scIB => scib}/resources/g2m_genes_tirosh.txt | 0 .../resources/g2m_genes_tirosh_hm.txt | 0 {scIB => scib}/resources/s_genes_tirosh.txt | 0 .../resources/s_genes_tirosh_hm.txt | 0 {scIB => scib}/trajectory_inference.py | 28 +- {scIB => scib}/utils.py | 34 +- setup.cfg | 83 +- setup.py | 32 +- tests/common.py | 2 +- tests/conftest.py | 8 +- tests/metrics/test_all.py | 6 +- tests/metrics/test_beyond_label_metrics.py | 6 +- tests/metrics/test_clisi.py | 6 +- tests/metrics/test_cluster_metrics.py | 12 +- tests/metrics/test_graph_connectivity.py | 2 +- tests/metrics/test_ilisi.py | 6 +- tests/metrics/test_kbet.py | 2 +- tests/metrics/test_pcr_metrics.py | 8 +- tests/metrics/test_silhouette_metrics.py | 6 +- tests/metrics/test_trajectory.py | 4 +- tests/preprocessing/test_clustering.py | 2 +- tests/preprocessing/test_preprocessing.py | 6 +- tests/requirements.txt | 3 - 57 files changed, 1064 insertions(+), 886 deletions(-) create mode 100644 MANIFEST.in create mode 100644 VERSION.txt create mode 100644 pyproject.toml delete mode 100644 pytest.ini delete mode 100644 requirements.txt delete mode 100644 requirements_extra.txt delete mode 100644 scIB/__init__.py create mode 100644 scib/__init__.py create mode 100644 scib/_package_tools.py rename {scIB => scib}/exceptions.py (100%) rename {scIB => scib}/integration.py (82%) rename {scIB => scib}/knn_graph/README.md (100%) rename {scIB => scib}/knn_graph/knn_graph.cpp (97%) rename {scIB => scib}/knn_graph/knn_graph.o (100%) rename {scIB => scib}/knn_graph/makefile (100%) rename {scIB => scib}/metrics/__init__.py (100%) rename {scIB => scib}/metrics/ari.py (91%) rename {scIB => scib}/metrics/cell_cycle.py (97%) rename {scIB => scib/metrics}/clustering.py (93%) rename {scIB => scib}/metrics/graph_connectivity.py (100%) rename {scIB => scib}/metrics/highly_variable_genes.py (94%) rename {scIB => scib}/metrics/isolated_labels.py (99%) rename {scIB => scib}/metrics/kbet.py (97%) rename {scIB => scib}/metrics/lisi.py (98%) rename {scIB => scib}/metrics/metrics.py (94%) rename {scIB => scib}/metrics/nmi.py (96%) rename {scIB => 
scib}/metrics/pcr.py (98%) rename {scIB => scib}/metrics/silhouette.py (97%) rename {scIB => scib}/metrics/trajectory.py (89%) rename {scIB => scib}/metrics/utils.py (99%) rename {scIB => scib}/preprocessing.py (77%) rename {scIB => scib}/resources/g2m_genes_tirosh.txt (100%) rename {scIB => scib}/resources/g2m_genes_tirosh_hm.txt (100%) rename {scIB => scib}/resources/s_genes_tirosh.txt (100%) rename {scIB => scib}/resources/s_genes_tirosh_hm.txt (100%) rename {scIB => scib}/trajectory_inference.py (78%) rename {scIB => scib}/utils.py (76%) delete mode 100644 tests/requirements.txt diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 43c3e667..43482b74 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -23,19 +23,19 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - pip install . - pip install -r tests/requirements.txt - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Install package + run: pip install .[test] - name: Import package run: | pip list cd tests - python -c 'import scIB' + python -c 'import scib' - name: Test with pytest run: | pytest --durations 0 -s diff --git a/.gitignore b/.gitignore index 699c7fb7..7863779c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ testing.h5ad data .ipynb_checkpoints *.egg-info +*dist/ *cache* .snakemake diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..74282fce --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include VERSION.txt diff --git a/README.md b/README.md index 6e6fd016..cc4c94e1 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ batches of gene expression and chromatin accessibility data. + On our [website](https://theislab.github.io/scib-reproducibility) we visualise the results of the study. + The reusable pipeline we used in the study can be found in the - separate [scIB pipeline](https://github.com/theislab/scib-pipeline.git) repository. It is reproducible and automates + separate [scib pipeline](https://github.com/theislab/scib-pipeline.git) repository. It is reproducible and automates the computation of preprocesssing combinations, integration methods and benchmarking metrics. + For reproducibility and visualisation we have a dedicated @@ -24,14 +24,24 @@ batches of gene expression and chromatin accessibility data. MD Luecken, M Büttner, K Chaichoompu, A Danese, M Interlandi, MF Mueller, DC Strobl, L Zappia, M Dugas, M Colomé-Tatché, FJ Theis bioRxiv 2020.05.22.111161; doi: https://doi.org/10.1101/2020.05.22.111161_ -## Package: `scIB` +## Package: `scib` -We created the python package called `scIB` that uses `scanpy` to streamline the integration of single-cell datasets and -evaluate the results. The evaluation of integration quality is based on a number of metrics. +We created the python package called `scib` that uses `scanpy` to streamline the integration of single-cell datasets and +evaluate the results. For evaluating the integration quality it provides a number of metrics. + +### Requirements + ++ Linux or UNIX system ++ Python >= 3.7 ++ 3.6 <= R <= 4.0 + +We recommend working with environments such as Conda or virtualenv, so that python and R dependencies are in one place. 
+Please also check out [scib pipeline](https://github.com/theislab/scib-pipeline.git) for ready-to-use environments. +Alternatively, manually install the package on your system using pip, as described in the next section. ### Installation -The `scIB` python package is in the folder scIB. You can install it from the root of this repository using +The `scib` python package is in the folder scib. You can simply install it from the root of this repository using ``` pip install . @@ -49,45 +59,64 @@ Additionally, in order to run the R package `kBET`, you need to install it throu devtools::install_github('theislab/kBET') ``` -We recommend to use a conda environment or something similar, so that python and R dependencies are in one place. Please -also check out [scIB pipeline](https://github.com/theislab/scib-pipeline.git) for ready-to-use environments. +> **Note:** By default dependencies for integration methods are not installed due to dependency clashes. +> In order to use integration methods, see the next section. ### Installing additional packages This package contains code for running integration methods as well as for evaluating their output. However, due to -dependency clashes, `scIB` is only installed with the packages needed for the metrics. In order to use the integration +dependency clashes, `scib` is only installed with the packages needed for the metrics. In order to use the integration wrapper functions, we recommend to work with different environments for different methods, each with their own -installation of `scIB`. Check out the `Tools` section for a list of supported integration methods. +installation of `scib`. You can install optional Python dependencies via pip as follows: + +``` +pip install .[bbknn] # using BBKNN +pip install .[scanorama] # using Scanorama +pip install .[bbknn,scanorama] # Multiple methods in one go +``` + +See `setup.cfg` for a full list of Python dependencies. For a comprehensive list of supported integration methods, +including R packages, check out the `Tools` section. ## Usage The package contains several modules for the different steps of the integration and benchmarking pipeline. Functions for -the integration methods are in `scIB.integration`. The methods can be called using +the integration methods are in `scib.integration` or for short `scib.ig`. The methods can be called using +```py +scib.integration.<method>(adata, batch=<batch>) ``` -scIB.integration.run<method>(adata, batch=<batch>) + +where `<method>` is the name of the integration method and `<batch>` is the name of the batch column in `adata.obs`. +For example, in order to run Scanorama on a dataset with batch key 'batch', call + +```py +scib.integration.scanorama(adata, batch='batch') ``` -where `<method>` is the name of the integration method and `<batch>` is the name of the batch column in `adata.obs`. +> **Warning:** the following notation is deprecated. +> ``` +> scib.integration.run<method>(adata, batch=<batch>) +> ``` +> Please use the snake case naming without the `run` prefix. -Some integration methods (scGEN, SCANVI) also use cell type labels as input. For these, you need to additionally provide +Some integration methods (`scgen`, `scanvi`) also use cell type labels as input. For these, you need to additionally provide the corresponding label column. -``` -runScGen(adata, batch=<batch>, cell_type=<cell_type>) -runScanvi(adata, batch=<batch>, labels=<labels>) +```py +scgen(adata, batch=<batch>, cell_type=<cell_type>) +scanvi(adata, batch=<batch>, labels=<labels>) ``` -`scIB.preprocessing` contains methods for preprocessing of the data such as normalisation, scaling or highly variable -gene selection per batch.
The metrics are located at `scIB.metrics`. To run multiple metrics in one run, use -the `scIB.metrics.metrics()` function. +`scib.preprocessing` (or `scib.pp`) contains functions for normalising, scaling or selecting highly variable genes per batch +The metrics are under `scib.metrics` (or `scib.me`). -### Metrics +## Metrics For a detailed description of the metrics implemented in this package, please see the [manuscript](https://www.biorxiv.org/content/10.1101/2020.05.22.111161v2). -#### Batch removal metrics include: +### Batch removal metrics include: - Principal component regression `pcr_comparison()` - Batch ASW `silhouette()` @@ -95,7 +124,7 @@ the [manuscript](https://www.biorxiv.org/content/10.1101/2020.05.22.111161v2). - Graph connectivity `graph_connectivity()` - Graph iLISI `lisi_graph()` -#### Biological conservation metrics include: +### Biological conservation metrics include: - Normalised mutual information `nmi()` - Adjusted Rand Index `ari()` @@ -107,6 +136,20 @@ the [manuscript](https://www.biorxiv.org/content/10.1101/2020.05.22.111161v2). - Trajectory conservation `trajectory_conservation()` - Graph cLISI `lisi_graph()` +### Metrics Wrapper Functions +We provide wrapper functions to run multiple metrics in one function call. +The `scib.metrics.metrics()` function returns a `pandas.Dataframe` of all metrics specified as parameters. + +```py +scib.metrics.metrics(adata, adata_int, ari=True, nmi=True) +``` + +Furthermore, `scib.metrics.metrics()` is wrapped by convenience functions that only select certain metrics: + ++ `scib.me.metrics_fast()` only computes metrics that require little preprocessing ++ `scib.me.metrics_slim()` includes all functions of `scib.me.metrics_fast()` and adds clustering-based metrics ++ `scib.me.metrics_all()` includes all metrics + ## Tools Tools that are compared include: diff --git a/VERSION.txt b/VERSION.txt new file mode 100644 index 00000000..3eefcb9d --- /dev/null +++ b/VERSION.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..95ce61c6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = [ + "setuptools", + "wheel", +] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +log_cli = 'True' +log_cli_level = 'INFO' +addopts = '-p no:warnings' diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 61af7eb1..00000000 --- a/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -log_cli = True -log_cli_level = INFO -addopts = -p no:warnings \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 501e3ed7..00000000 --- a/requirements.txt +++ /dev/null @@ -1,20 +0,0 @@ -numpy==1.18.1 -pandas -seaborn -matplotlib -numba -scanpy>=1.5 -anndata>=0.7.2 -h5py<3 -rpy2>=3 -anndata2ri -scipy -scikit-learn -scikit-misc -loompy==3.0.6 -louvain -umap-learn -pydot -python-igraph -llvmlite -memory_profiler diff --git a/requirements_extra.txt b/requirements_extra.txt deleted file mode 100644 index 2ae71790..00000000 --- a/requirements_extra.txt +++ /dev/null @@ -1,11 +0,0 @@ -# integration tools -bbknn==1.3.9 -scanorama==1.7.0 -mnnpy==0.1.9.5 -scgen==1.1.5 -scvi==0.6.7 -# trvae==1.1.2 -# trvaep==0.1.0 -# desc==2.0.3 -# keras==2.1 -# tensorflow==1.15 diff --git a/scIB/__init__.py b/scIB/__init__.py deleted file mode 100644 index 44488eee..00000000 --- a/scIB/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import scIB.utils as utils -from . 
import preprocessing, integration, metrics, clustering -pp = preprocessing -ig = integration -me = metrics -cl = clustering - -import seaborn -seaborn.set_context('talk') diff --git a/scib/__init__.py b/scib/__init__.py new file mode 100644 index 00000000..17c4aa2d --- /dev/null +++ b/scib/__init__.py @@ -0,0 +1,35 @@ +try: + from importlib import metadata +except ImportError: # for Python<3.8 + import importlib_metadata as metadata + +__version__ = metadata.version('scib') + +from . import integration, metrics, preprocessing +from . import utils as utils +from ._package_tools import rename_func +from .metrics import clustering + +alias_func_map = { + 'runScanorama': integration.scanorama, + 'runTrVae': integration.trvae, + 'runTrVaep': integration.trvaep, + 'runScGen': integration.scgen, + 'runScvi': integration.scvi, + 'runScanvi': integration.scanvi, + 'runMNN': integration.mnn, + 'runBBKNN': integration.bbknn, + 'runSaucie': integration.saucie, + 'runCombat': integration.combat, + 'runDESC': integration.desc, + 'readSeurat': preprocessing.read_seurat, + 'readConos': preprocessing.read_conos, +} + +for alias, func in alias_func_map.items(): + rename_func(func, alias) + +pp = preprocessing +ig = integration +me = metrics +cl = clustering diff --git a/scib/_package_tools.py b/scib/_package_tools.py new file mode 100644 index 00000000..73c797dd --- /dev/null +++ b/scib/_package_tools.py @@ -0,0 +1,29 @@ +import inspect +import warnings +from functools import wraps + +warnings.simplefilter('default') # or 'always' + + +def wrap_func_naming(func, name): + """ + Decorator that adds a `DeprecationWarning` and a name to `func`. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + warnings.warn( + f"Mixed case function naming is deprecated for '{name}'. " + f"Please use '{func.__name__}' instead.", + DeprecationWarning + ) + return func(*args, **kwargs) + + wrapper.__name__ = name + return wrapper + + +def rename_func(function, new_name): + if callable(function): + function = wrap_func_naming(function, new_name) + setattr(inspect.getmodule(function), new_name, function) diff --git a/scIB/exceptions.py b/scib/exceptions.py similarity index 100% rename from scIB/exceptions.py rename to scib/exceptions.py diff --git a/scIB/integration.py b/scib/integration.py similarity index 82% rename from scIB/integration.py rename to scib/integration.py index a294c72c..c2e95ae8 100644 --- a/scIB/integration.py +++ b/scib/integration.py @@ -1,35 +1,34 @@ -#!/bin/env python - -### D. C. Strobl, M. Müller; 2019-07-23 - -""" This module provides a toolkit for running a large range of single cell data integration methods - as well as tools and metrics to benchmark them. +""" +This module provides a toolkit for running a large range of single cell data integration +methods as well as tools and metrics to benchmark them. """ -import scipy as sp -from scIB.utils import * +import logging import os -import anndata -from scIB.exceptions import IntegrationMethodNotFound +import tempfile +import anndata +import numpy as np import rpy2.rinterface_lib.callbacks -import logging - -rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages +import scanpy as sc +import scipy as sp from scipy.sparse import issparse +from . 
import utils +from .exceptions import IntegrationMethodNotFound + +rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages -# functions for running the methods -def runScanorama(adata, batch, hvg=None): +def scanorama(adata, batch, hvg=None, **kwargs): try: import scanorama except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) - split, categories = splitBatches(adata.copy(), batch, return_categories=True) - corrected = scanorama.correct_scanpy(split, return_dimred=True) + utils.check_sanity(adata, batch, hvg) + split, categories = utils.split_batches(adata.copy(), batch, return_categories=True) + corrected = scanorama.correct_scanpy(split, return_dimred=True, **kwargs) corrected = anndata.AnnData.concatenate( *corrected, batch_key=batch, batch_categories=categories, index_unique=None ) @@ -39,13 +38,13 @@ def runScanorama(adata, batch, hvg=None): return corrected -def runTrVae(adata, batch, hvg=None): +def trvae(adata, batch, hvg=None): try: import trvae except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) + utils.check_sanity(adata, batch, hvg) n_batches = len(adata.obs[batch].cat.categories) train_adata, valid_adata = trvae.utils.train_test_split( @@ -85,13 +84,13 @@ def runTrVae(adata, batch, hvg=None): return adata -def runTrVaep(adata, batch, hvg=None): +def trvaep(adata, batch, hvg=None): try: import trvaep except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) + utils.check_sanity(adata, batch, hvg) n_batches = adata.obs[batch].nunique() # Densify the data matrix @@ -126,7 +125,7 @@ def runTrVaep(adata, batch, hvg=None): return adata -def runScGen(adata, batch, cell_type, epochs=100, hvg=None, model_path='/localscratch'): +def scgen(adata, batch, cell_type, epochs=100, hvg=None, model_path=None, **kwargs): """ Parametrization taken from the tutorial notebook at: https://nbviewer.jupyter.org/github/M0hammadL/scGen_notebooks/blob/master/notebooks/scgen_batch_removal.ipynb @@ -136,30 +135,47 @@ def runScGen(adata, batch, cell_type, epochs=100, hvg=None, model_path='/localsc except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) + utils.check_sanity(adata, batch, hvg) + + if model_path is None: + temp_dir = tempfile.TemporaryDirectory() + model_path = temp_dir.name # Fit the model - network = scgen.VAEArith(x_dimension=adata.shape[1], model_path=model_path) - network.train(train_data=adata, n_epochs=epochs, save=False) - corrected_adata = scgen.batch_removal(network, adata, batch_key=batch, cell_label_key=cell_type) + network = scgen.VAEArith( + x_dimension=adata.shape[1], + model_path=model_path + ) + network.train( + train_data=adata, + n_epochs=epochs, + save=False + ) + corrected_adata = scgen.batch_removal( + network, + adata, + batch_key=batch, + cell_label_key=cell_type, + **kwargs + ) network.sess.close() return corrected_adata -def runScvi(adata, batch, hvg=None): +def scvi(adata, batch, hvg=None): # Use non-normalized (count) data for scvi! 
# Expects data only on HVGs try: - from scvi.models import VAE + from scvi.dataset import AnnDatasetFromAnnData from scvi.inference import UnsupervisedTrainer + from scvi.models import VAE from sklearn.preprocessing import LabelEncoder - from scvi.dataset import AnnDatasetFromAnnData except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) + utils.check_sanity(adata, batch, hvg) # Check for counts data layer if 'counts' not in adata.layers: @@ -209,13 +225,13 @@ def runScvi(adata, batch, hvg=None): return adata -def runScanvi(adata, batch, labels): +def scanvi(adata, batch, labels): # Use non-normalized (count) data for scanvi! try: - from scvi.models import VAE, SCANVI - from scvi.inference import UnsupervisedTrainer, SemiSupervisedTrainer - from sklearn.preprocessing import LabelEncoder from scvi.dataset import AnnDatasetFromAnnData + from scvi.inference import SemiSupervisedTrainer, UnsupervisedTrainer + from scvi.models import SCANVI, VAE + from sklearn.preprocessing import LabelEncoder except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) @@ -297,37 +313,49 @@ def runScanvi(adata, batch, labels): return adata -def runMNN(adata, batch, hvg=None): +def mnn(adata, batch, hvg=None, **kwargs): try: import mnnpy except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) - split, categories = splitBatches(adata, batch, return_categories=True) + utils.check_sanity(adata, batch, hvg) + split, categories = utils.split_batches(adata, batch, return_categories=True) corrected, _, _ = mnnpy.mnn_correct( - *split, var_subset=hvg, batch_key=batch, batch_categories=categories, index_unique=None + *split, + var_subset=hvg, + batch_key=batch, + batch_categories=categories, + index_unique=None, + **kwargs ) return corrected -def runBBKNN(adata, batch, hvg=None): +def bbknn(adata, batch, hvg=None, **kwargs): try: import bbknn except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) - checkSanity(adata, batch, hvg) + utils.check_sanity(adata, batch, hvg) sc.pp.pca(adata, svd_solver='arpack') if adata.n_obs < 1e5: - return bbknn.bbknn(adata, batch_key=batch, copy=True) + return bbknn.bbknn( + adata, batch_key=batch, copy=True, **kwargs) if adata.n_obs >= 1e5: - return bbknn.bbknn(adata, batch_key=batch, neighbors_within_batch=25, copy=True) + return bbknn.bbknn( + adata, + batch_key=batch, + neighbors_within_batch=25, + copy=True, + **kwargs + ) -def runSaucie(adata, batch): +def saucie(adata, batch): """ parametrisation from https://github.com/KrishnaswamyLab/SAUCIE/blob/master/scripts/SAUCIE.py """ @@ -354,13 +382,13 @@ def runSaucie(adata, batch): return ret -def runCombat(adata, batch): +def combat(adata, batch): adata_int = adata.copy() sc.pp.combat(adata_int, key=batch) return adata_int -def runDESC(adata, batch, res=0.8, ncores=None, tmp_dir='/localscratch/tmp_desc/', use_gpu=False): +def desc(adata, batch, res=0.8, ncores=None, tmp_dir=None, use_gpu=False): """ Convenience function to run DESC. 
Parametrization was taken from: https://github.com/eleozzr/desc/issues/28 @@ -371,6 +399,10 @@ def runDESC(adata, batch, res=0.8, ncores=None, tmp_dir='/localscratch/tmp_desc/ except ModuleNotFoundError as e: raise IntegrationMethodNotFound(e) + if tmp_dir is None: + temp_dir = tempfile.TemporaryDirectory() + tmp_dir = temp_dir.name + # Set number of CPUs to all available if ncores is None: ncores = os.cpu_count() diff --git a/scIB/knn_graph/README.md b/scib/knn_graph/README.md similarity index 100% rename from scIB/knn_graph/README.md rename to scib/knn_graph/README.md diff --git a/scIB/knn_graph/knn_graph.cpp b/scib/knn_graph/knn_graph.cpp similarity index 97% rename from scIB/knn_graph/knn_graph.cpp rename to scib/knn_graph/knn_graph.cpp index 46ede024..0a082ee0 100644 --- a/scIB/knn_graph/knn_graph.cpp +++ b/scib/knn_graph/knn_graph.cpp @@ -1,432 +1,432 @@ -#include -#include -#include -#include -#include -#include -#include -//--------------------------------------------------------------------------- -using namespace std; -//--------------------------------------------------------------------------- -namespace { -//--------------------------------------------------------------------------- -// A sparse distance matrix, must be symmetric -class Matrix { - public: - /// An entry - struct Entry { - /// The column - unsigned column; - /// The weight - double weight; - }; - /// An entry range - struct EntryRange { - /// The entry range - const Entry *from, *to; - - /// Empty range? - bool empty() const { return from == to; } - /// First element - const Entry* begin() { return from; } - /// Behind the last element - const Entry* end() { return to; } - }; - - public: - /// The entries - vector entries; - /// The entries offsets - vector entryOffsets; - /// The width of the matrix - unsigned width = 0; - - public: - /// Get the number of rows - unsigned getRowCount() const { return width; } - /// Get all entries in a row - EntryRange getRow(unsigned i) const { return EntryRange{entries.data() + entryOffsets[i - 1], entries.data() + entryOffsets[i]}; } - - /// Read a file - static Matrix readFile(string fileName); -}; -//--------------------------------------------------------------------------- -Matrix Matrix::readFile(string fileName) -// Read a sparse matrix file -{ - ifstream in(fileName); - if (!in.is_open()) { - cerr << "unable to read " << fileName << endl; - exit(1); - } - - Matrix result; - - // Check the header line - string line; - while (getline(in, line)) { - if (line.empty() || (line.front() == '%')) continue; - - stringstream s(line); - unsigned width, height, entries; - s >> width >> height >> entries; - if ((width != height) || (!width) || (!entries)) { - cerr << "matrix must be symmetric and non-empty" << endl; - exit(1); - } - result.width = width; - result.entries.reserve(entries); - result.entryOffsets.reserve(width + 1); - break; - } - - // Read the elements - unsigned currentRow = 0, currentColumn = 0; - while (getline(in, line)) { - if (line.empty() || (line.front() == '%')) continue; - stringstream s(line); - unsigned row, column; - double weight; - if (!(s >> row >> column >> weight)) { - cerr << "malformed matrix line " << line << endl; - exit(1); - } - if ((row > result.width) || (column > result.width)) { - cerr << "malformed matrix format, cell offset out of bounds " << line << endl; - exit(1); - } - if (row < currentRow) { - cerr << "malformed matrix format, row number decreased " << line << endl; - exit(1); - } - if (row == currentRow) { - if (column <= 
currentColumn) { - cerr << "malformed matrix format, column number decreased " << line << endl; - exit(1); - } - currentColumn = column; - } else { - result.entryOffsets.insert(result.entryOffsets.end(), row - currentRow, result.entries.size()); - currentRow = row; - currentColumn = column; - } - result.entries.push_back({column, weight}); - } - result.entryOffsets.insert(result.entryOffsets.end(), result.width + 1 - currentRow, result.entries.size()); - return result; -} -//--------------------------------------------------------------------------- -/// A priority queue of entries. We need both a min heap (best candidate) and a max heap (worst candidate), thus we implement it only once and use a template paramter -template -class PriorityQueue { - public: - /// An entry - struct Entry { - /// The id - unsigned index; - /// The weight - double weight; - }; - - public: - /// The entries - vector entries; - /// The entry lookup - unordered_map entryLookup; - - /// Comparison logic - static bool isLess(double a, double b) { - if (isMin) { - return a < b; - } else { - return a > b; - } - } - - /// Move an element up in the heap until it is at the correct position - void heapifyUp(unsigned slot); - /// Move an element down the heap until it is at the correct position - void heapifyDown(unsigned slot); - - public: - /// Is the heap empty? - bool empty() const { return entries.empty(); } - /// The number of entries in the heap - unsigned size() const { return entries.size(); } - - /// Get the top element - const Entry& front() const { return entries.front(); } - /// Remove the top element from the queue - Entry pop_front(); - /// Add an element - void insert(Entry entry); - - using iterator = const Entry*; - /// Find an entry - iterator find(unsigned index) const; - /// Marker for not found - iterator end() const { return nullptr; } - /// Update an entry - void update(iterator pos, double newWeight); - /// Remove an entry - void erase(iterator pos); -}; -//--------------------------------------------------------------------------- -template -void PriorityQueue::heapifyUp(unsigned slot) -// Move an element up in the heap until it is at the correct position -{ - if ((!slot) || (slot >= entries.size())) return; - auto& currentPos = entryLookup[entries[slot].index]; - - // Bubble up until the heap condition is restored - while (slot > 0) { - unsigned parentSlot = slot / 2; - if (isLess(entries[slot].weight, entries[parentSlot].weight)) { - entryLookup[entries[parentSlot].index] = slot; - currentPos = parentSlot; - swap(entries[slot], entries[parentSlot]); - slot = parentSlot; - } else { - break; - } - } -} -//--------------------------------------------------------------------------- -template -void PriorityQueue::heapifyDown(unsigned slot) -// Move an element down the heap until it is at the correct position -{ - if (slot >= entries.size()) return; - auto& currentPos = entryLookup[entries[slot].index]; - - // Bubble down until the heap condition is restored - while (true) { - unsigned leftChild = 2 * slot, rightChild = leftChild + 1; - unsigned selectedChild; - if (rightChild < entries.size()) { - selectedChild = isLess(entries[leftChild].weight, entries[rightChild].weight) ? 
leftChild : rightChild; - } else if (leftChild < entries.size()) { - selectedChild = leftChild; - } else { - break; - } - if (isLess(entries[selectedChild].weight, entries[slot].weight)) { - entryLookup[entries[selectedChild].index] = slot; - currentPos = selectedChild; - swap(entries[slot], entries[selectedChild]); - slot = selectedChild; - } else { - break; - } - } -} -//--------------------------------------------------------------------------- -template -typename PriorityQueue::Entry PriorityQueue::pop_front() -// Remove the top element from the queue -{ - auto result = entries.front(); - swap(entries.front(), entries.back()); - entryLookup[entries.front().index] = 0; - entryLookup.erase(result.index); - entries.pop_back(); - heapifyDown(0); - - return result; -} -//--------------------------------------------------------------------------- -template -void PriorityQueue::insert(Entry entry) -// Add an element -{ - unsigned slot = entries.size(); - entries.push_back(entry); - entryLookup[entry.index] = slot; - heapifyUp(slot); -} -//--------------------------------------------------------------------------- -template -typename PriorityQueue::iterator PriorityQueue::find(unsigned index) const -// Find an entry -{ - auto iter = entryLookup.find(index); - if (iter == entryLookup.end()) return nullptr; - return entries.data() + iter->second; -} -//--------------------------------------------------------------------------- -template -void PriorityQueue::update(iterator pos, double newWeight) -// Update an entry -{ - unsigned slot = pos - entries.data(); - if (isLess(newWeight, pos->weight)) { - entries[slot].weight = newWeight; - heapifyUp(slot); - } else if (isLess(pos->weight, newWeight)) { - entries[slot].weight = newWeight; - heapifyDown(slot); - } -} -//--------------------------------------------------------------------------- -template -void PriorityQueue::erase(iterator pos) -// Remove an entry -{ - unsigned index = pos->index; - unsigned slot = pos - entries.data(); - double oldWeight = pos->weight, newWeight=entries.back().weight; - swap(entries[slot], entries.back()); - entryLookup[entries[slot].index] = slot; - entryLookup.erase(index); - entries.pop_back(); - if (isLess(oldWeight,newWeight)) { - heapifyDown(slot); - } else { - heapifyUp(slot); - } -} -//--------------------------------------------------------------------------- -using MinHeap = PriorityQueue; -using MaxHeap = PriorityQueue; -//--------------------------------------------------------------------------- -static vector getTopKNeighbors(const Matrix& m, unsigned start, unsigned k) -// Find the top k neighbors for a node -{ - vector result; - result.reserve(k); - - MinHeap minHeap; - MaxHeap maxHeap; - minHeap.insert({start, 0}); - maxHeap.insert({start, 0}); - while (!minHeap.empty()) { - // Remove the next element and add it to result - auto current = minHeap.pop_front(); - if (current.index != start) { - result.push_back({current.index, current.weight}); - if (result.size() >= k) break; - } - - // Examine all outgoing edges - for (auto& e : m.getRow(current.index)) { - // Already in heap? 
- double newDistance = current.weight + e.weight; - auto iter = maxHeap.find(e.column); - if (iter != maxHeap.end()) { - // If an entry is in the max heap but not in the min heap it is already in the result and we can ignore it - auto iter2 = minHeap.find(e.column); - if (iter2 != minHeap.end()) { - // Update if shorter - if (newDistance < iter2->weight) { - minHeap.update(iter2, newDistance); - maxHeap.update(iter, newDistance); - } - } - } else if (maxHeap.size() <= k) { - // As long as we have not seen k candidates add it unconditionally - minHeap.insert({e.column, newDistance}); - maxHeap.insert({e.column, newDistance}); - } else if (newDistance < maxHeap.front().weight) { - // We got a better candidate, remove the old one - auto worst = maxHeap.pop_front(); - minHeap.erase(minHeap.find(worst.index)); - minHeap.insert({e.column, newDistance}); - maxHeap.insert({e.column, newDistance}); - } - } - } - - return result; -} -//--------------------------------------------------------------------------- -} -//--------------------------------------------------------------------------- -int main(int argc, char* argv[]) { - if (argc != 6) { - cout << "usage: " << argv[0] << " matrixfile, output_prefix, k, n_chunks, percent_subsample" << endl; - return 0; - } - - // Read the matrix file - Matrix matrix = Matrix::readFile(argv[1]); - //get output_prefix - string output_prefix = argv[2]; - // The number of neighbors we are interested in - unsigned k = stoi(argv[3]); //convert input char to integer - - unsigned n_chunks = stoi(argv[4]); - unsigned limit = matrix.getRowCount(); - unsigned len_ch; - if (n_chunks <= 1){ - n_chunks = 0; - len_ch = limit; - } - else{ - len_ch = limit/n_chunks; - } - //get percentage to which should be subsampled - unsigned sub = stoi(argv[5]); - - //ininitialize random number generator - random_device rd; //used to obtain seed for random number engine - mt19937 gen(rd()); //standard merseen_twister_engine seeded with rd() - uniform_int_distribution<> dis(1, 100); //uniform int distribution between 0 and 100 - int rand_res; - //sanity check - //double sum = 0; - - //variable declaration - ofstream distances; - ofstream indices; - string dist; - string indi; - unsigned lower; - unsigned upper; - - for (unsigned n_ch = 0; n_ch <= n_chunks; ++n_ch){ - // Find the top k elements for all nodes. 
Computes the sum of all the weights, just to have some result to show - // write all neighbors and weights to two files - dist = output_prefix + "_distances_" + to_string(n_ch) + ".txt"; - indi = output_prefix + "_indices_" + to_string(n_ch) + ".txt"; - - distances.open(dist, ios::out | ios::binary); - indices.open(indi, ios::out | ios::binary); - lower = n_ch * len_ch + 1; - upper = (n_ch+1)*len_ch; - // don't run over upper limit - if (upper > limit) { - upper = limit; - } - for (unsigned row = lower; row <= upper; row++) { - //cout << row << endl; - // Ignore empty rows - if (matrix.getRow(row).empty()) { - //distances << row << endl; //add index to distances file to keep order - //indices << row << endl; //add index to indices file to keep order - continue; - } - // use subsampling - rand_res = dis(gen); //generate random number - if (rand_res> sub){ //skip for 100-sub percent of the data - continue; - } - - // Find the top k neighbors - auto neighbors = getTopKNeighbors(matrix, row, k); - distances << row; //add index of the root index to file - indices << row; //add index of the root index to file - for (auto& n : neighbors) { - distances << ',' << n.weight; - indices << ',' << n.column; - //sum += n.weight; - } - distances << endl; //add end of line after each knn search - indices << endl; //add end of line after each knn search - } - distances.close(); - indices.close(); - } - //cout << "sum of weights for all top " << k << " neighbors " << sum << endl; -} -//--------------------------------------------------------------------------- +#include +#include +#include +#include +#include +#include +#include +//--------------------------------------------------------------------------- +using namespace std; +//--------------------------------------------------------------------------- +namespace { +//--------------------------------------------------------------------------- +// A sparse distance matrix, must be symmetric +class Matrix { + public: + /// An entry + struct Entry { + /// The column + unsigned column; + /// The weight + double weight; + }; + /// An entry range + struct EntryRange { + /// The entry range + const Entry *from, *to; + + /// Empty range? 
+ bool empty() const { return from == to; } + /// First element + const Entry* begin() { return from; } + /// Behind the last element + const Entry* end() { return to; } + }; + + public: + /// The entries + vector entries; + /// The entries offsets + vector entryOffsets; + /// The width of the matrix + unsigned width = 0; + + public: + /// Get the number of rows + unsigned getRowCount() const { return width; } + /// Get all entries in a row + EntryRange getRow(unsigned i) const { return EntryRange{entries.data() + entryOffsets[i - 1], entries.data() + entryOffsets[i]}; } + + /// Read a file + static Matrix readFile(string fileName); +}; +//--------------------------------------------------------------------------- +Matrix Matrix::readFile(string fileName) +// Read a sparse matrix file +{ + ifstream in(fileName); + if (!in.is_open()) { + cerr << "unable to read " << fileName << endl; + exit(1); + } + + Matrix result; + + // Check the header line + string line; + while (getline(in, line)) { + if (line.empty() || (line.front() == '%')) continue; + + stringstream s(line); + unsigned width, height, entries; + s >> width >> height >> entries; + if ((width != height) || (!width) || (!entries)) { + cerr << "matrix must be symmetric and non-empty" << endl; + exit(1); + } + result.width = width; + result.entries.reserve(entries); + result.entryOffsets.reserve(width + 1); + break; + } + + // Read the elements + unsigned currentRow = 0, currentColumn = 0; + while (getline(in, line)) { + if (line.empty() || (line.front() == '%')) continue; + stringstream s(line); + unsigned row, column; + double weight; + if (!(s >> row >> column >> weight)) { + cerr << "malformed matrix line " << line << endl; + exit(1); + } + if ((row > result.width) || (column > result.width)) { + cerr << "malformed matrix format, cell offset out of bounds " << line << endl; + exit(1); + } + if (row < currentRow) { + cerr << "malformed matrix format, row number decreased " << line << endl; + exit(1); + } + if (row == currentRow) { + if (column <= currentColumn) { + cerr << "malformed matrix format, column number decreased " << line << endl; + exit(1); + } + currentColumn = column; + } else { + result.entryOffsets.insert(result.entryOffsets.end(), row - currentRow, result.entries.size()); + currentRow = row; + currentColumn = column; + } + result.entries.push_back({column, weight}); + } + result.entryOffsets.insert(result.entryOffsets.end(), result.width + 1 - currentRow, result.entries.size()); + return result; +} +//--------------------------------------------------------------------------- +/// A priority queue of entries. We need both a min heap (best candidate) and a max heap (worst candidate), thus we implement it only once and use a template paramter +template +class PriorityQueue { + public: + /// An entry + struct Entry { + /// The id + unsigned index; + /// The weight + double weight; + }; + + public: + /// The entries + vector entries; + /// The entry lookup + unordered_map entryLookup; + + /// Comparison logic + static bool isLess(double a, double b) { + if (isMin) { + return a < b; + } else { + return a > b; + } + } + + /// Move an element up in the heap until it is at the correct position + void heapifyUp(unsigned slot); + /// Move an element down the heap until it is at the correct position + void heapifyDown(unsigned slot); + + public: + /// Is the heap empty? 
+ bool empty() const { return entries.empty(); } + /// The number of entries in the heap + unsigned size() const { return entries.size(); } + + /// Get the top element + const Entry& front() const { return entries.front(); } + /// Remove the top element from the queue + Entry pop_front(); + /// Add an element + void insert(Entry entry); + + using iterator = const Entry*; + /// Find an entry + iterator find(unsigned index) const; + /// Marker for not found + iterator end() const { return nullptr; } + /// Update an entry + void update(iterator pos, double newWeight); + /// Remove an entry + void erase(iterator pos); +}; +//--------------------------------------------------------------------------- +template +void PriorityQueue::heapifyUp(unsigned slot) +// Move an element up in the heap until it is at the correct position +{ + if ((!slot) || (slot >= entries.size())) return; + auto& currentPos = entryLookup[entries[slot].index]; + + // Bubble up until the heap condition is restored + while (slot > 0) { + unsigned parentSlot = slot / 2; + if (isLess(entries[slot].weight, entries[parentSlot].weight)) { + entryLookup[entries[parentSlot].index] = slot; + currentPos = parentSlot; + swap(entries[slot], entries[parentSlot]); + slot = parentSlot; + } else { + break; + } + } +} +//--------------------------------------------------------------------------- +template +void PriorityQueue::heapifyDown(unsigned slot) +// Move an element down the heap until it is at the correct position +{ + if (slot >= entries.size()) return; + auto& currentPos = entryLookup[entries[slot].index]; + + // Bubble down until the heap condition is restored + while (true) { + unsigned leftChild = 2 * slot, rightChild = leftChild + 1; + unsigned selectedChild; + if (rightChild < entries.size()) { + selectedChild = isLess(entries[leftChild].weight, entries[rightChild].weight) ? 
leftChild : rightChild; + } else if (leftChild < entries.size()) { + selectedChild = leftChild; + } else { + break; + } + if (isLess(entries[selectedChild].weight, entries[slot].weight)) { + entryLookup[entries[selectedChild].index] = slot; + currentPos = selectedChild; + swap(entries[slot], entries[selectedChild]); + slot = selectedChild; + } else { + break; + } + } +} +//--------------------------------------------------------------------------- +template +typename PriorityQueue::Entry PriorityQueue::pop_front() +// Remove the top element from the queue +{ + auto result = entries.front(); + swap(entries.front(), entries.back()); + entryLookup[entries.front().index] = 0; + entryLookup.erase(result.index); + entries.pop_back(); + heapifyDown(0); + + return result; +} +//--------------------------------------------------------------------------- +template +void PriorityQueue::insert(Entry entry) +// Add an element +{ + unsigned slot = entries.size(); + entries.push_back(entry); + entryLookup[entry.index] = slot; + heapifyUp(slot); +} +//--------------------------------------------------------------------------- +template +typename PriorityQueue::iterator PriorityQueue::find(unsigned index) const +// Find an entry +{ + auto iter = entryLookup.find(index); + if (iter == entryLookup.end()) return nullptr; + return entries.data() + iter->second; +} +//--------------------------------------------------------------------------- +template +void PriorityQueue::update(iterator pos, double newWeight) +// Update an entry +{ + unsigned slot = pos - entries.data(); + if (isLess(newWeight, pos->weight)) { + entries[slot].weight = newWeight; + heapifyUp(slot); + } else if (isLess(pos->weight, newWeight)) { + entries[slot].weight = newWeight; + heapifyDown(slot); + } +} +//--------------------------------------------------------------------------- +template +void PriorityQueue::erase(iterator pos) +// Remove an entry +{ + unsigned index = pos->index; + unsigned slot = pos - entries.data(); + double oldWeight = pos->weight, newWeight=entries.back().weight; + swap(entries[slot], entries.back()); + entryLookup[entries[slot].index] = slot; + entryLookup.erase(index); + entries.pop_back(); + if (isLess(oldWeight,newWeight)) { + heapifyDown(slot); + } else { + heapifyUp(slot); + } +} +//--------------------------------------------------------------------------- +using MinHeap = PriorityQueue; +using MaxHeap = PriorityQueue; +//--------------------------------------------------------------------------- +static vector getTopKNeighbors(const Matrix& m, unsigned start, unsigned k) +// Find the top k neighbors for a node +{ + vector result; + result.reserve(k); + + MinHeap minHeap; + MaxHeap maxHeap; + minHeap.insert({start, 0}); + maxHeap.insert({start, 0}); + while (!minHeap.empty()) { + // Remove the next element and add it to result + auto current = minHeap.pop_front(); + if (current.index != start) { + result.push_back({current.index, current.weight}); + if (result.size() >= k) break; + } + + // Examine all outgoing edges + for (auto& e : m.getRow(current.index)) { + // Already in heap? 
+ double newDistance = current.weight + e.weight; + auto iter = maxHeap.find(e.column); + if (iter != maxHeap.end()) { + // If an entry is in the max heap but not in the min heap it is already in the result and we can ignore it + auto iter2 = minHeap.find(e.column); + if (iter2 != minHeap.end()) { + // Update if shorter + if (newDistance < iter2->weight) { + minHeap.update(iter2, newDistance); + maxHeap.update(iter, newDistance); + } + } + } else if (maxHeap.size() <= k) { + // As long as we have not seen k candidates add it unconditionally + minHeap.insert({e.column, newDistance}); + maxHeap.insert({e.column, newDistance}); + } else if (newDistance < maxHeap.front().weight) { + // We got a better candidate, remove the old one + auto worst = maxHeap.pop_front(); + minHeap.erase(minHeap.find(worst.index)); + minHeap.insert({e.column, newDistance}); + maxHeap.insert({e.column, newDistance}); + } + } + } + + return result; +} +//--------------------------------------------------------------------------- +} +//--------------------------------------------------------------------------- +int main(int argc, char* argv[]) { + if (argc != 6) { + cout << "usage: " << argv[0] << " matrixfile, output_prefix, k, n_chunks, percent_subsample" << endl; + return 0; + } + + // Read the matrix file + Matrix matrix = Matrix::readFile(argv[1]); + //get output_prefix + string output_prefix = argv[2]; + // The number of neighbors we are interested in + unsigned k = stoi(argv[3]); //convert input char to integer + + unsigned n_chunks = stoi(argv[4]); + unsigned limit = matrix.getRowCount(); + unsigned len_ch; + if (n_chunks <= 1){ + n_chunks = 0; + len_ch = limit; + } + else{ + len_ch = limit/n_chunks; + } + //get percentage to which should be subsampled + unsigned sub = stoi(argv[5]); + + //ininitialize random number generator + random_device rd; //used to obtain seed for random number engine + mt19937 gen(rd()); //standard merseen_twister_engine seeded with rd() + uniform_int_distribution<> dis(1, 100); //uniform int distribution between 0 and 100 + int rand_res; + //sanity check + //double sum = 0; + + //variable declaration + ofstream distances; + ofstream indices; + string dist; + string indi; + unsigned lower; + unsigned upper; + + for (unsigned n_ch = 0; n_ch <= n_chunks; ++n_ch){ + // Find the top k elements for all nodes. 
Computes the sum of all the weights, just to have some result to show + // write all neighbors and weights to two files + dist = output_prefix + "_distances_" + to_string(n_ch) + ".txt"; + indi = output_prefix + "_indices_" + to_string(n_ch) + ".txt"; + + distances.open(dist, ios::out | ios::binary); + indices.open(indi, ios::out | ios::binary); + lower = n_ch * len_ch + 1; + upper = (n_ch+1)*len_ch; + // don't run over upper limit + if (upper > limit) { + upper = limit; + } + for (unsigned row = lower; row <= upper; row++) { + //cout << row << endl; + // Ignore empty rows + if (matrix.getRow(row).empty()) { + //distances << row << endl; //add index to distances file to keep order + //indices << row << endl; //add index to indices file to keep order + continue; + } + // use subsampling + rand_res = dis(gen); //generate random number + if (rand_res> sub){ //skip for 100-sub percent of the data + continue; + } + + // Find the top k neighbors + auto neighbors = getTopKNeighbors(matrix, row, k); + distances << row; //add index of the root index to file + indices << row; //add index of the root index to file + for (auto& n : neighbors) { + distances << ',' << n.weight; + indices << ',' << n.column; + //sum += n.weight; + } + distances << endl; //add end of line after each knn search + indices << endl; //add end of line after each knn search + } + distances.close(); + indices.close(); + } + //cout << "sum of weights for all top " << k << " neighbors " << sum << endl; +} +//--------------------------------------------------------------------------- diff --git a/scIB/knn_graph/knn_graph.o b/scib/knn_graph/knn_graph.o similarity index 100% rename from scIB/knn_graph/knn_graph.o rename to scib/knn_graph/knn_graph.o diff --git a/scIB/knn_graph/makefile b/scib/knn_graph/makefile similarity index 100% rename from scIB/knn_graph/makefile rename to scib/knn_graph/makefile diff --git a/scIB/metrics/__init__.py b/scib/metrics/__init__.py similarity index 100% rename from scIB/metrics/__init__.py rename to scib/metrics/__init__.py diff --git a/scIB/metrics/ari.py b/scib/metrics/ari.py similarity index 91% rename from scIB/metrics/ari.py rename to scib/metrics/ari.py index 55844248..ab517ed4 100644 --- a/scIB/metrics/ari.py +++ b/scib/metrics/ari.py @@ -3,7 +3,7 @@ import scipy.special from sklearn.metrics.cluster import adjusted_rand_score -from scIB.utils import checkAdata, checkBatch +from ..utils import check_adata, check_batch def ari(adata, group1, group2, implementation=None): @@ -19,9 +19,9 @@ def ari(adata, group1, group2, implementation=None): otherwise native implementation is taken """ - checkAdata(adata) - checkBatch(group1, adata.obs) - checkBatch(group2, adata.obs) + check_adata(adata) + check_batch(group1, adata.obs) + check_batch(group2, adata.obs) group1 = adata.obs[group1].to_numpy() group2 = adata.obs[group2].to_numpy() diff --git a/scIB/metrics/cell_cycle.py b/scib/metrics/cell_cycle.py similarity index 97% rename from scIB/metrics/cell_cycle.py rename to scib/metrics/cell_cycle.py index 76eaa2c7..63da42da 100644 --- a/scIB/metrics/cell_cycle.py +++ b/scib/metrics/cell_cycle.py @@ -1,9 +1,9 @@ import numpy as np import pandas as pd +from ..preprocessing import score_cell_cycle +from ..utils import check_adata from .pcr import pc_regression -from scIB.utils import checkAdata -from scIB.preprocessing import score_cell_cycle def cell_cycle( @@ -45,8 +45,8 @@ def cell_cycle( A score between 1 and 0. The larger the score, the stronger the cell cycle variance is conserved. 
""" - checkAdata(adata_pre) - checkAdata(adata_post) + check_adata(adata_pre) + check_adata(adata_post) if embed == 'X_pca': embed = None diff --git a/scIB/clustering.py b/scib/metrics/clustering.py similarity index 93% rename from scIB/clustering.py rename to scib/metrics/clustering.py index 2ebb8046..e736ce8b 100644 --- a/scIB/clustering.py +++ b/scib/metrics/clustering.py @@ -1,8 +1,9 @@ -import pandas as pd -import seaborn as sns import matplotlib.pyplot as plt +import pandas as pd import scanpy as sc -from . import metrics +import seaborn as sns + +from .nmi import nmi def opt_louvain(adata, label_key, cluster_key, function=None, resolutions=None, @@ -32,7 +33,7 @@ def opt_louvain(adata, label_key, cluster_key, function=None, resolutions=None, print('Clustering...') if function is None: - function = metrics.nmi + function = nmi if cluster_key in adata.obs.columns: if force: @@ -45,7 +46,7 @@ def opt_louvain(adata, label_key, cluster_key, function=None, resolutions=None, if resolutions is None: n = 20 - resolutions = [2*x/n for x in range(1,n+1)] + resolutions = [2 * x / n for x in range(1, n + 1)] score_max = 0 res_max = resolutions[0] @@ -79,7 +80,7 @@ def opt_louvain(adata, label_key, cluster_key, function=None, resolutions=None, score_all = pd.DataFrame(zip(resolutions, score_all), columns=('resolution', 'score')) if plot: # score vs. resolution profile - sns.lineplot(data= score_all, x='resolution', y='score').set_title('Optimal cluster resolution profile') + sns.lineplot(data=score_all, x='resolution', y='score').set_title('Optimal cluster resolution profile') plt.show() if inplace: diff --git a/scIB/metrics/graph_connectivity.py b/scib/metrics/graph_connectivity.py similarity index 100% rename from scIB/metrics/graph_connectivity.py rename to scib/metrics/graph_connectivity.py diff --git a/scIB/metrics/highly_variable_genes.py b/scib/metrics/highly_variable_genes.py similarity index 94% rename from scIB/metrics/highly_variable_genes.py rename to scib/metrics/highly_variable_genes.py index 9d4c4609..efbaf5a3 100644 --- a/scIB/metrics/highly_variable_genes.py +++ b/scib/metrics/highly_variable_genes.py @@ -1,11 +1,11 @@ import numpy as np import scanpy as sc -from scIB.utils import splitBatches +from ..utils import split_batches def precompute_hvg_batch(adata, batch, features, n_hvg=500, save_hvg=False): - adata_list = splitBatches(adata, batch, hvg=features) + adata_list = split_batches(adata, batch, hvg=features) hvg_dir = {} for i in adata_list: sc.pp.filter_genes(i, min_cells=1) @@ -25,7 +25,7 @@ def precompute_hvg_batch(adata, batch, features, n_hvg=500, save_hvg=False): def hvg_overlap(adata_pre, adata_post, batch, n_hvg=500, verbose=False): hvg_post = adata_post.var_names - adata_post_list = splitBatches(adata_post, batch) + adata_post_list = split_batches(adata_post, batch) overlap = [] if ('hvg_before' in adata_pre.uns_keys()) and (set(hvg_post) == set(adata_pre.var_names)): diff --git a/scIB/metrics/isolated_labels.py b/scib/metrics/isolated_labels.py similarity index 99% rename from scIB/metrics/isolated_labels.py rename to scib/metrics/isolated_labels.py index cbaefe1f..daf2fe15 100644 --- a/scIB/metrics/isolated_labels.py +++ b/scib/metrics/isolated_labels.py @@ -1,7 +1,7 @@ import pandas as pd from sklearn.metrics import f1_score -from scIB.clustering import opt_louvain +from .clustering import opt_louvain from .silhouette import silhouette diff --git a/scIB/metrics/kbet.py b/scib/metrics/kbet.py similarity index 97% rename from scIB/metrics/kbet.py rename to 
scib/metrics/kbet.py index 87013945..8f96e956 100644 --- a/scIB/metrics/kbet.py +++ b/scib/metrics/kbet.py @@ -1,14 +1,15 @@ +import logging + +import anndata2ri import numpy as np -import scipy.sparse import pandas as pd -import logging -import rpy2.robjects as ro import rpy2.rinterface_lib.callbacks -import anndata2ri +import rpy2.robjects as ro import scanpy as sc +import scipy.sparse -from scIB.utils import checkAdata, checkBatch -from .utils import diffusion_conn, diffusion_nn, NeighborsError +from ..utils import check_adata, check_batch +from .utils import NeighborsError, diffusion_conn, diffusion_nn rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages @@ -87,9 +88,9 @@ def kBET( return_df=True: pd.DataFrame with kBET observed rejection rates per cluster for batch """ - checkAdata(adata) - checkBatch(batch_key, adata.obs) - checkBatch(label_key, adata.obs) + check_adata(adata) + check_batch(batch_key, adata.obs) + check_batch(label_key, adata.obs) try: ro.r("library(kBET)") diff --git a/scIB/metrics/lisi.py b/scib/metrics/lisi.py similarity index 98% rename from scIB/metrics/lisi.py rename to scib/metrics/lisi.py index 56f9e1d6..405b2a34 100644 --- a/scIB/metrics/lisi.py +++ b/scib/metrics/lisi.py @@ -1,20 +1,21 @@ +import itertools +import logging +import multiprocessing as mp import os import pathlib -import itertools +import subprocess import tempfile + +import anndata2ri import numpy as np import pandas as pd -import scipy.sparse -from scipy.io import mmwrite -import multiprocessing as mp -import subprocess -import logging -import rpy2.robjects as ro import rpy2.rinterface_lib.callbacks -import anndata2ri +import rpy2.robjects as ro import scanpy as sc +import scipy.sparse +from scipy.io import mmwrite -from scIB.utils import checkAdata, checkBatch +from ..utils import check_adata, check_batch rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages @@ -40,9 +41,9 @@ def lisi( pd.DataFrame with median cLISI and median iLISI scores (following the harmony paper) """ - checkAdata(adata) - checkBatch(batch_key, adata.obs) - checkBatch(label_key, adata.obs) + check_adata(adata) + check_batch(batch_key, adata.obs) + check_batch(label_key, adata.obs) # if type_ != 'knn': # if verbose: @@ -218,8 +219,8 @@ def ilisi_graph( :return: Median of iLISI score """ - checkAdata(adata) - checkBatch(batch_key, adata.obs) + check_adata(adata) + check_batch(batch_key, adata.obs) adata_tmp = recompute_knn(adata, type_) ilisi_score = lisi_graph_py( @@ -275,9 +276,9 @@ def clisi_graph( :return: Median of cLISI score """ - checkAdata(adata) - checkBatch(batch_key, adata.obs) - checkBatch(label_key, adata.obs) + check_adata(adata) + check_batch(batch_key, adata.obs) + check_batch(label_key, adata.obs) adata_tmp = recompute_knn(adata, type_) diff --git a/scIB/metrics/metrics.py b/scib/metrics/metrics.py similarity index 94% rename from scIB/metrics/metrics.py rename to scib/metrics/metrics.py index 0bdb727e..4f56cc5a 100755 --- a/scIB/metrics/metrics.py +++ b/scib/metrics/metrics.py @@ -1,14 +1,15 @@ +import numpy as np import pandas as pd -from scIB.utils import * -from scIB.clustering import opt_louvain +from ..utils import check_adata, check_batch from .ari import ari from .cell_cycle import cell_cycle +from .clustering import opt_louvain from .graph_connectivity import graph_connectivity from .highly_variable_genes import hvg_overlap from .isolated_labels import isolated_labels from .kbet import kBET -from .lisi import ilisi_graph, 
clisi_graph +from .lisi import clisi_graph, ilisi_graph from .nmi import nmi from .pcr import pcr_comparison from .silhouette import silhouette, silhouette_batch @@ -178,13 +179,13 @@ def metrics( Compute of all metrics given unintegrate and integrated anndata object """ - checkAdata(adata) - checkBatch(batch_key, adata.obs) - checkBatch(label_key, adata.obs) + check_adata(adata) + check_batch(batch_key, adata.obs) + check_batch(label_key, adata.obs) - checkAdata(adata_int) - checkBatch(batch_key, adata_int.obs) - checkBatch(label_key, adata_int.obs) + check_adata(adata_int) + check_batch(batch_key, adata_int.obs) + check_batch(label_key, adata_int.obs) # clustering if nmi_ or ari_: @@ -365,7 +366,12 @@ def metrics( if trajectory_: print('Trajectory conservation score...') - trajectory_score = trajectory_conservation(adata, adata_int, label_key=label_key) + trajectory_score = trajectory_conservation( + adata, + adata_int, + label_key=label_key, + # batch_key=batch_key + ) else: trajectory_score = np.nan @@ -402,6 +408,7 @@ def measureTM(*args, **kwargs): """ import cProfile from pstats import Stats + import memory_profiler prof = cProfile.Profile() diff --git a/scIB/metrics/nmi.py b/scib/metrics/nmi.py similarity index 96% rename from scIB/metrics/nmi.py rename to scib/metrics/nmi.py index 4cad60c7..6c4f7d66 100644 --- a/scIB/metrics/nmi.py +++ b/scib/metrics/nmi.py @@ -1,9 +1,9 @@ import os import subprocess -import pandas as pd + from sklearn.metrics.cluster import normalized_mutual_info_score -from scIB.utils import checkAdata, checkBatch +from ..utils import check_adata, check_batch def nmi(adata, group1, group2, method="arithmetic", nmi_dir=None): @@ -26,9 +26,9 @@ def nmi(adata, group1, group2, method="arithmetic", nmi_dir=None): Normalized mutual information NMI value """ - checkAdata(adata) - checkBatch(group1, adata.obs) - checkBatch(group2, adata.obs) + check_adata(adata) + check_batch(group1, adata.obs) + check_batch(group2, adata.obs) group1 = adata.obs[group1].tolist() group2 = adata.obs[group2].tolist() diff --git a/scIB/metrics/pcr.py b/scib/metrics/pcr.py similarity index 98% rename from scIB/metrics/pcr.py rename to scib/metrics/pcr.py index 6b715c1c..cb45657d 100644 --- a/scIB/metrics/pcr.py +++ b/scib/metrics/pcr.py @@ -1,10 +1,10 @@ import numpy as np import pandas as pd -from scipy import sparse import scanpy as sc +from scipy import sparse from sklearn.linear_model import LinearRegression -from scIB.utils import checkAdata, checkBatch +from ..utils import check_adata, check_batch def pcr_comparison( @@ -96,8 +96,8 @@ def pcr( R2Var of regression """ - checkAdata(adata) - checkBatch(covariate, adata.obs) + check_adata(adata) + check_batch(covariate, adata.obs) if verbose: print(f"covariate: {covariate}") diff --git a/scIB/metrics/silhouette.py b/scib/metrics/silhouette.py similarity index 97% rename from scIB/metrics/silhouette.py rename to scib/metrics/silhouette.py index 1422f0b2..d61623e2 100644 --- a/scIB/metrics/silhouette.py +++ b/scib/metrics/silhouette.py @@ -1,5 +1,5 @@ import pandas as pd -from sklearn.metrics.cluster import silhouette_score, silhouette_samples +from sklearn.metrics.cluster import silhouette_samples, silhouette_score def silhouette( diff --git a/scIB/metrics/trajectory.py b/scib/metrics/trajectory.py similarity index 89% rename from scIB/metrics/trajectory.py rename to scib/metrics/trajectory.py index 53d7b702..ea0a531c 100644 --- a/scIB/metrics/trajectory.py +++ b/scib/metrics/trajectory.py @@ -1,10 +1,10 @@ import numpy as np -from 
scipy.sparse.csgraph import connected_components -import scanpy as sc import pandas as pd +import scanpy as sc +from scipy.sparse.csgraph import connected_components +from ..utils import check_batch from .utils import RootCellError -from ..utils import checkBatch def get_root( @@ -98,19 +98,27 @@ def trajectory_conservation( adata_post_ti.obs['dpt_pseudotime'] = adata_post_ti2.obs['dpt_pseudotime'] adata_post_ti.obs['dpt_pseudotime'].fillna(0, inplace=True) - adata_post_ti.obs['batch'] = adata_pre_ti.obs['batch'] - if batch_key == None: pseudotime_before = adata_pre_ti.obs[pseudotime_key] pseudotime_after = adata_post_ti.obs['dpt_pseudotime'] correlation = pseudotime_before.corr(pseudotime_after, 'spearman') return (correlation + 1) / 2 # scaled else: - checkBatch(batch_key, adata_pre.obs) - checkBatch(batch_key, adata_post.obs) + check_batch(batch_key, adata_pre.obs) + check_batch(batch_key, adata_post.obs) + + # check if batches match + if not np.array_equal(adata_post_ti.obs[batch_key], adata_pre_ti.obs[batch_key]): + raise ValueError( + 'Batch columns do not match\n' + f"adata_post_ti.obs['batch']:\n {adata_post_ti.obs[batch_key]}\n" + f"adata_pre_ti.obs['batch']:\n {adata_pre_ti.obs[batch_key]}\n" + ) + corr = pd.Series() for i in adata_pre_ti.obs[batch_key].unique(): pseudotime_before = adata_pre_ti.obs[adata_pre_ti.obs[batch_key] == i][pseudotime_key] pseudotime_after = adata_post_ti.obs[adata_post_ti.obs[batch_key] == i]['dpt_pseudotime'] corr[i] = pseudotime_before.corr(pseudotime_after, 'spearman') + return (corr.mean() + 1) / 2 # scaled diff --git a/scIB/metrics/utils.py b/scib/metrics/utils.py similarity index 99% rename from scIB/metrics/utils.py rename to scib/metrics/utils.py index d694399a..a60606fb 100644 --- a/scIB/metrics/utils.py +++ b/scib/metrics/utils.py @@ -1,7 +1,6 @@ import numpy as np -from scipy import sparse import pandas as pd - +from scipy import sparse # Errors diff --git a/scIB/preprocessing.py b/scib/preprocessing.py similarity index 77% rename from scIB/preprocessing.py rename to scib/preprocessing.py index 713666a7..170da5c0 100644 --- a/scIB/preprocessing.py +++ b/scib/preprocessing.py @@ -1,24 +1,27 @@ -import numpy as np -from matplotlib import pyplot as plt -import seaborn as sns -import scanpy as sc -from scipy import sparse +import logging +import tempfile +import anndata2ri +import numpy as np # rpy2 for running R code import rpy2.rinterface_lib.callbacks -import logging -rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages import rpy2.robjects as ro -import anndata2ri +import scanpy as sc +import seaborn +import seaborn as sns +from matplotlib import pyplot as plt +from scipy import sparse # access to other methods of this module -from scIB.utils import * +from . 
import utils + +rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages +seaborn.set_context('talk') def summarize_counts(adata, count_matrix=None, mt_gene_regex='^MT-'): - - checkAdata(adata) - + utils.check_adata(adata) + if count_matrix is None: count_matrix = adata.X adata.obs['n_counts'] = count_matrix.sum(1) @@ -34,32 +37,32 @@ def summarize_counts(adata, count_matrix=None, mt_gene_regex='^MT-'): if sparse.issparse(adata.X): mt_sum = mt_sum.A1 total_sum = total_sum.A1 - adata.obs['percent_mito'] = mt_sum / total_sum - - #mt_gene_mask = [gene.startswith('mt-') for gene in adata.var_names] - #mt_count = count_matrix[:, mt_gene_mask].sum(1) - #if mt_count.ndim > 1: + adata.obs['percent_mito'] = mt_sum / total_sum + + # mt_gene_mask = [gene.startswith('mt-') for gene in adata.var_names] + # mt_count = count_matrix[:, mt_gene_mask].sum(1) + # if mt_count.ndim > 1: # mt_count = np.squeeze(np.asarray(mt_count)) - #adata.obs['mt_frac'] = mt_count/adata.obs['n_counts'] - -### Quality Control -def plot_QC(adata, color=None, bins=60, legend_loc='right margin', histogram=True, - gene_threshold=(0,np.inf), - gene_filter_threshold=(0,np.inf), - count_threshold=(0,np.inf), - count_filter_threshold=(0,np.inf)): - + # adata.obs['mt_frac'] = mt_count/adata.obs['n_counts'] + + +### Quality Control +def plot_qc(adata, color=None, bins=60, legend_loc='right margin', histogram=True, + gene_threshold=(0, np.inf), + gene_filter_threshold=(0, np.inf), + count_threshold=(0, np.inf), + count_filter_threshold=(0, np.inf)): if count_filter_threshold == (0, np.inf): count_filter_threshold = count_threshold if gene_filter_threshold == (0, np.inf): gene_filter_threshold = gene_threshold - + # 2D scatter plot plot_scatter(adata, color=color, title=color, - gene_threshold=gene_filter_threshold[0], + gene_threshold=gene_filter_threshold[0], count_threshold=count_filter_threshold[0], legend_loc=legend_loc) - + if not histogram: return @@ -67,85 +70,85 @@ def plot_QC(adata, color=None, bins=60, legend_loc='right margin', histogram=Tru print(f"Counts Threshold: {count_filter_threshold}") # count filtering plot_count_filter(adata, obs_col='n_counts', bins=bins, - lower = count_threshold[0], - filter_lower = count_filter_threshold[0], - upper = count_threshold[1], - filter_upper = count_filter_threshold[1]) - + lower=count_threshold[0], + filter_lower=count_filter_threshold[0], + upper=count_threshold[1], + filter_upper=count_filter_threshold[1]) + if gene_filter_threshold != (0, np.inf): print(f"Gene Threshold: {gene_filter_threshold}") # gene filtering plot_count_filter(adata, obs_col='n_genes', bins=bins, - lower = gene_threshold[0], - filter_lower = gene_filter_threshold[0], - upper = gene_threshold[1], - filter_upper = gene_filter_threshold[1]) + lower=gene_threshold[0], + filter_lower=gene_filter_threshold[0], + upper=gene_threshold[1], + filter_upper=gene_filter_threshold[1]) + def plot_scatter(adata, count_threshold=0, gene_threshold=0, color=None, title='', lab_size=15, tick_size=11, legend_loc='right margin', palette=None): - - checkAdata(adata) + utils.check_adata(adata) if color: - checkBatch(color, adata.obs) - + utils.check_batch(color, adata.obs) + ax = sc.pl.scatter(adata, 'n_counts', 'n_genes', color=color, show=False, legend_fontweight=50, legend_loc=legend_loc, palette=palette) ax.set_title(title, fontsize=lab_size) - ax.set_xlabel("Count depth",fontsize=lab_size) - ax.set_ylabel("Number of genes",fontsize=lab_size) + ax.set_xlabel("Count depth", fontsize=lab_size) + 
ax.set_ylabel("Number of genes", fontsize=lab_size) ax.tick_params(labelsize=tick_size) - + if gene_threshold > 0: - ax.axhline(gene_threshold, 0,1, color='red') + ax.axhline(gene_threshold, 0, 1, color='red') if count_threshold > 0: - ax.axvline(count_threshold, 0,1, color='red') - + ax.axvline(count_threshold, 0, 1, color='red') + fig = plt.gcf() cbar_ax = fig.axes[-1] cbar_ax.tick_params(labelsize=tick_size) f1 = ax.get_figure() plt.show() + def plot_count_filter(adata, obs_col='n_counts', bins=60, lower=0, upper=np.inf, filter_lower=0, filter_upper=np.inf): - plot_data = adata.obs[obs_col] - + sns.distplot(plot_data, kde=False, bins=bins) - + if lower > 0: - plt.axvline(lower, linestyle = '--', color = 'g') + plt.axvline(lower, linestyle='--', color='g') if filter_lower > 0: - plt.axvline(filter_lower, linestyle = '-', color = 'r') + plt.axvline(filter_lower, linestyle='-', color='r') if not np.isinf(upper): - plt.axvline(upper, linestyle = '--', color = 'g') + plt.axvline(upper, linestyle='--', color='g') if not np.isinf(upper): - plt.axvline(filter_upper, linestyle = '-', color = 'r') - plt.ticklabel_format(style='sci', axis='x', scilimits=(0,2)) + plt.axvline(filter_upper, linestyle='-', color='r') + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 2)) plt.show() - + # determine lower bound of total, look at points below lower bound if filter_lower > 0: print(f"lower threshold: {filter_lower}") sns.distplot(plot_data[plot_data < lower], kde=False, bins=bins) - plt.axvline(filter_lower, linestyle = '-', color = 'r') - plt.axvline(lower, linestyle = '--', color = 'g') - plt.ticklabel_format(style='sci', axis='x', scilimits=(0,2)) + plt.axvline(filter_lower, linestyle='-', color='r') + plt.axvline(lower, linestyle='--', color='g') + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 2)) plt.show() # determine upper bound of total if not np.isinf(filter_upper) and not np.isinf(upper): print(f"upper threshold: {filter_upper}") sns.distplot(plot_data[plot_data > upper], kde=False, bins=bins) - plt.axvline(filter_upper, linestyle = '-', color = 'r') - plt.axvline(upper, linestyle = '--', color = 'g') - plt.ticklabel_format(style='sci', axis='x', scilimits=(0,2)) + plt.axvline(filter_upper, linestyle='-', color='r') + plt.axvline(upper, linestyle='--', color='g') + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 2)) plt.show() + ### Normalisation -def normalize(adata, min_mean = 0.1, log=True, precluster=True, sparsify=True): - - checkAdata(adata) +def normalize(adata, min_mean=0.1, log=True, precluster=True, sparsify=True): + utils.check_adata(adata) # Check for 0 count cells if np.any(adata.X.sum(axis=1) == 0): @@ -159,22 +162,22 @@ def normalize(adata, min_mean = 0.1, log=True, precluster=True, sparsify=True): if sparsify: # massive speedup when working with sparse matrix - if not sparse.issparse(adata.X): # quick fix: HVG doesn't work on dense matrix + if not sparse.issparse(adata.X): # quick fix: HVG doesn't work on dense matrix adata.X = sparse.csr_matrix(adata.X) anndata2ri.activate() ro.r('library("scran")') - + # keep raw counts adata.layers["counts"] = adata.X.copy() - is_sparse=False + is_sparse = False X = adata.X.T # convert to CSC if possible. 
See https://github.com/MarioniLab/scran/issues/70 if sparse.issparse(X): is_sparse = True - - if X.nnz > 2**31-1: + + if X.nnz > 2 ** 31 - 1: X = X.tocoo() else: X = X.tocsc() @@ -200,10 +203,10 @@ def normalize(adata, min_mean = 0.1, log=True, precluster=True, sparsify=True): else: size_factors = ro.r('sizeFactors(computeSumFactors(SingleCellExperiment(' f'list(counts=data_mat)), min.mean = {min_mean}))') - + # modify adata adata.obs['size_factors'] = size_factors - adata.X /= adata.obs['size_factors'].values[:,None] + adata.X /= adata.obs['size_factors'].values[:, None] if log: print("Note! Performing log1p-transformation after normalization.") sc.pp.log1p(adata) @@ -214,7 +217,7 @@ def normalize(adata, min_mean = 0.1, log=True, precluster=True, sparsify=True): # convert to sparse, bc operation always converts to dense adata.X = sparse.csr_matrix(adata.X) - adata.raw = adata # Store the full data set in 'raw' as log-normalised data for statistical testing + adata.raw = adata # Store the full data set in 'raw' as log-normalised data for statistical testing # Free memory in R ro.r('rm(list=ls())') @@ -231,8 +234,8 @@ def scale_batch(adata, batch): Function to scale the gene expression values of each batch separately. """ - checkAdata(adata) - checkBatch(batch, adata.obs) + utils.check_adata(adata) + utils.check_batch(batch, adata.obs) # Store layers for after merge (avoids vstack error in merge) adata_copy = adata.copy() @@ -241,12 +244,12 @@ def scale_batch(adata, batch): tmp[lay] = adata_copy.layers[lay] del adata_copy.layers[lay] - split = splitBatches(adata_copy, batch) + split = utils.split_batches(adata_copy, batch) for i in split: sc.pp.scale(i) - adata_scaled = merge_adata(split) + adata_scaled = utils.merge_adata(split) # Reorder to original obs_name ordering adata_scaled = adata_scaled[adata.obs_names] @@ -257,12 +260,13 @@ def scale_batch(adata, batch): del tmp del adata_copy - + return adata_scaled - -def hvg_intersect(adata, batch, target_genes=2000, flavor='cell_ranger', n_bins=20, adataOut=False, n_stop=8000, min_genes=500, step_size=1000): -### Feature Selection + +def hvg_intersect(adata, batch, target_genes=2000, flavor='cell_ranger', n_bins=20, adataOut=False, n_stop=8000, + min_genes=500, step_size=1000): + ### Feature Selection """ params: adata: @@ -273,48 +277,47 @@ def hvg_intersect(adata, batch, target_genes=2000, flavor='cell_ranger', n_bins= return: list of highly variable genes less or equal to `target_genes` """ - - checkAdata(adata) - checkBatch(batch, adata.obs) - + + utils.check_adata(adata) + utils.check_batch(batch, adata.obs) + intersect = None enough = False n_hvg = target_genes - - split = splitBatches(adata, batch) + + split = utils.split_batches(adata, batch) hvg_res = [] for i in split: - sc.pp.filter_genes(i, min_cells=1) # remove genes unexpressed (otherwise hvg might break) + sc.pp.filter_genes(i, min_cells=1) # remove genes unexpressed (otherwise hvg might break) hvg_res.append(sc.pp.highly_variable_genes(i, flavor='cell_ranger', n_top_genes=n_hvg, inplace=False)) while not enough: genes = [] for i in range(len(split)): - dispersion_norm = hvg_res[i]['dispersions_norm'] dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] dispersion_norm[::-1].sort() - disp_cut_off = dispersion_norm[n_hvg-1] + disp_cut_off = dispersion_norm[n_hvg - 1] gene_subset = np.nan_to_num(hvg_res[i]['dispersions_norm']) >= disp_cut_off genes.append(set(split[i].var[gene_subset].index)) intersect = genes[0].intersection(*genes[1:]) - if 
len(intersect)>=target_genes: - enough=True + if len(intersect) >= target_genes: + enough = True else: - if n_hvg>n_stop: + if n_hvg > n_stop: if len(intersect) < min_genes: raise Exception(f'Only {len(intersect)} HVGs were found in the intersection.\n' f'This is fewer than {min_genes} HVGs set as the minimum.\n' 'Consider raising `n_stop` or reducing `min_genes`.') break - n_hvg=int(n_hvg+step_size) + n_hvg = int(n_hvg + step_size) if adataOut: - return adata[:,list(intersect)].copy() + return adata[:, list(intersect)].copy() return list(intersect) @@ -328,41 +331,41 @@ def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_ then HVGs in all but one batches are used to fill up. This is continued until HVGs in a single batch are considered. """ - - checkAdata(adata) + + utils.check_adata(adata) if batch_key is not None: - checkBatch(batch_key, adata.obs) - + utils.check_batch(batch_key, adata.obs) + adata_hvg = adata if adataOut else adata.copy() n_batches = len(adata_hvg.obs[batch_key].cat.categories) # Calculate double target genes per dataset sc.pp.highly_variable_genes(adata_hvg, - flavor=flavor, + flavor=flavor, n_top_genes=target_genes, - n_bins=n_bins, + n_bins=n_bins, batch_key=batch_key) nbatch1_dispersions = adata_hvg.var['dispersions_norm'][adata_hvg.var.highly_variable_nbatches > - len(adata_hvg.obs[batch_key].cat.categories)-1] - + len(adata_hvg.obs[batch_key].cat.categories) - 1] + nbatch1_dispersions.sort_values(ascending=False, inplace=True) if len(nbatch1_dispersions) > target_genes: hvg = nbatch1_dispersions.index[:target_genes] - + else: enough = False print(f'Using {len(nbatch1_dispersions)} HVGs from full intersect set') hvg = nbatch1_dispersions.index[:] not_n_batches = 1 - + while not enough: target_genes_diff = target_genes - len(hvg) tmp_dispersions = adata_hvg.var['dispersions_norm'][adata_hvg.var.highly_variable_nbatches == - (n_batches-not_n_batches)] + (n_batches - not_n_batches)] if len(tmp_dispersions) < target_genes_diff: print(f'Using {len(tmp_dispersions)} HVGs from n_batch-{not_n_batches} set') @@ -373,7 +376,7 @@ def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_ print(f'Using {target_genes_diff} HVGs from n_batch-{not_n_batches} set') tmp_dispersions.sort_values(ascending=False, inplace=True) hvg = hvg.append(tmp_dispersions.index[:target_genes_diff]) - enough=True + enough = True print(f'Using {len(hvg)} HVGs') @@ -381,14 +384,14 @@ def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_ del adata_hvg return hvg.tolist() else: - return adata_hvg[:,hvg].copy() + return adata_hvg[:, hvg].copy() ### Feature Reduction def reduce_data(adata, batch_key=None, subset=False, filter=True, flavor='cell_ranger', n_top_genes=2000, n_bins=20, pca=True, pca_comps=50, overwrite_hvg=True, - neighbors=True, use_rep='X_pca', + neighbors=True, use_rep='X_pca', umap=True): """ overwrite_hvg: @@ -397,20 +400,20 @@ def reduce_data(adata, batch_key=None, subset=False, if False, skips HVG computation even if `n_top_genes` is specified and uses pre-existing HVG column for PCA """ - - checkAdata(adata) + + utils.check_adata(adata) if batch_key: - checkBatch(batch_key, adata.obs) - + utils.check_batch(batch_key, adata.obs) + if n_top_genes is not None and overwrite_hvg: print("HVG") - + overwrite_hvg = False - + ## quick fix: HVG doesn't work on dense matrix if not sparse.issparse(adata.X): adata.X = sparse.csr_matrix(adata.X) - + if batch_key is not None: hvg_list = hvg_batch(adata, 
batch_key=batch_key, target_genes=n_top_genes, n_bins=n_bins) adata.var['highly_variable'] = np.in1d(adata.var_names, hvg_list) @@ -424,25 +427,25 @@ def reduce_data(adata, batch_key=None, subset=False, n_hvg = np.sum(adata.var["highly_variable"]) print(f'Computed {n_hvg} highly variable genes') - + if pca: print("PCA") use_hvgs = not overwrite_hvg and "highly_variable" in adata.var sc.tl.pca(adata, - n_comps=pca_comps, - use_highly_variable=use_hvgs, - svd_solver='arpack', + n_comps=pca_comps, + use_highly_variable=use_hvgs, + svd_solver='arpack', return_info=True) - + if neighbors: print("Nearest Neigbours") sc.pp.neighbors(adata, use_rep=use_rep) - + if umap: print("UMAP") sc.tl.umap(adata) - - + + ### Cell Cycle def score_cell_cycle(adata, organism='mouse'): """ @@ -454,7 +457,7 @@ def score_cell_cycle(adata, organism='mouse'): """ import pathlib root = pathlib.Path(__file__).parent - + cc_files = {'mouse': [root / 'resources/s_genes_tirosh.txt', root / 'resources/g2m_genes_tirosh.txt'], 'human': [root / 'resources/s_genes_tirosh_hm.txt', @@ -466,13 +469,13 @@ def score_cell_cycle(adata, organism='mouse'): g2m_genes = [x.strip() for x in f.readlines() if x.strip() in adata.var.index] if (len(s_genes) == 0) or (len(g2m_genes) == 0): - rand_choice = np.random.randint(1,adata.n_vars,10) + rand_choice = np.random.randint(1, adata.n_vars, 10) rand_genes = adata.var_names[rand_choice].tolist() raise ValueError(f"cell cycle genes not in adata\n organism: {organism}\n varnames: {rand_genes}") - + sc.tl.score_genes_cell_cycle(adata, s_genes, g2m_genes) - + def saveSeurat(adata, path, batch, hvgs=None): import re ro.r('library(Seurat)') @@ -489,25 +492,24 @@ def saveSeurat(adata, path, batch, hvgs=None): adata.layers[key].sort_indices() ro.globalenv['adata'] = adata - + ro.r('sobj = as.Seurat(adata, counts="counts", data = "X")') # Fix error if levels are 0 and 1 # ro.r(f'sobj$batch <- as.character(sobj${batch})') ro.r(f'Idents(sobj) = "{batch}"') - ro.r(f'saveRDS(sobj, file="{path}")') + ro.r(f'saveRDS(sobj, file="{path}")') if hvgs is not None: - hvg_out = re.sub('\.RDS$', '', path)+'_hvg.RDS' - #hvg_out = path+'_hvg.rds' - ro.globalenv['hvgs']=hvgs + hvg_out = re.sub('\.RDS$', '', path) + '_hvg.RDS' + # hvg_out = path+'_hvg.rds' + ro.globalenv['hvgs'] = hvgs ro.r('unlist(hvgs)') ro.r(f'saveRDS(hvgs, file="{hvg_out}")') - anndata2ri.deactivate() - - -def readSeurat(path): + + +def read_seurat(path): anndata2ri.activate() ro.r('library(Seurat)') ro.r('library(scater)') @@ -515,28 +517,28 @@ def readSeurat(path): adata = ro.r('as.SingleCellExperiment(sobj)') anndata2ri.deactivate() - #Test for 'X_EMB' + # Test for 'X_EMB' if 'X_EMB' in adata.obsm: if 'X_emb' in adata.obsm: print('overwriting existing `adata.obsm["X_emb"] in the adata object') adata.obsm['X_emb'] = adata.obsm['X_EMB'] del adata.obsm['X_EMB'] - - return(adata) - -def readConos(inPath): - from time import time - from shutil import rmtree - from scipy.io import mmread + + return (adata) + + +def read_conos(inPath, dir_path=None): from os import mkdir, path + from shutil import rmtree + from time import time + import pandas as pd - - dir_path = "/localscratch/conos"+str(int(time())) - while path.isdir(dir_path): - dir_path += '2' - dir_path += '/' - mkdir(dir_path) - + from scipy.io import mmread + + if dir_path is None: + tmpdir = tempfile.TemporaryDirectory() + dir_path = tmpdir.name + '/' + ro.r('library(conos)') ro.r(f'con <- readRDS("{inPath}")') ro.r('meta <- function(sobj) {return(sobj@meta.data)}') @@ -544,6 +546,7 @@ def 
readConos(inPath): ro.r('library(data.table)') ro.r('metaM <- do.call(rbind,unname(metalist))') ro.r(f'saveConosForScanPy(con, output.path="{dir_path}", pseudo.pca=TRUE, pca=TRUE, metadata.df=metaM)') + gene_df = pd.read_csv(dir_path + "genes.csv") metadata = pd.read_csv(dir_path + "metadata.csv") @@ -553,14 +556,13 @@ def readConos(inPath): embedding_df = pd.read_csv(dir_path + "embedding.csv") # Decide between using PCA or pseudo-PCA pseudopca_df = pd.read_csv(dir_path + "pseudopca.csv") - #pca_df = pd.read_csv(dir_path + "pca.csv") + # pca_df = pd.read_csv(dir_path + "pca.csv") graph_conn_mtx = mmread(dir_path + "graph_connectivities.mtx") graph_dist_mtx = mmread(dir_path + "graph_distances.mtx") - - adata = sc.read_mtx(dir_path+ "raw_count_matrix.mtx") - - + + adata = sc.read_mtx(dir_path + "raw_count_matrix.mtx") + adata.var_names = gene_df["gene"].values adata.obs_names = metadata.index.values @@ -577,12 +579,10 @@ def readConos(inPath): adata.uns['neighbors'] = dict(connectivities=graph_conn_mtx.tocsr(), distances=graph_dist_mtx.tocsr()) # Assign raw counts to .raw slot, load in normalised counts - #adata.raw = adata - #adata_temp = sc.read_mtx(DATA_PATH + "count_matrix.mtx") - #adata.X = adata_temp.X + # adata.raw = adata + # adata_temp = sc.read_mtx(DATA_PATH + "count_matrix.mtx") + # adata.X = adata_temp.X rmtree(dir_path) - - return adata - + return adata diff --git a/scIB/resources/g2m_genes_tirosh.txt b/scib/resources/g2m_genes_tirosh.txt similarity index 100% rename from scIB/resources/g2m_genes_tirosh.txt rename to scib/resources/g2m_genes_tirosh.txt diff --git a/scIB/resources/g2m_genes_tirosh_hm.txt b/scib/resources/g2m_genes_tirosh_hm.txt similarity index 100% rename from scIB/resources/g2m_genes_tirosh_hm.txt rename to scib/resources/g2m_genes_tirosh_hm.txt diff --git a/scIB/resources/s_genes_tirosh.txt b/scib/resources/s_genes_tirosh.txt similarity index 100% rename from scIB/resources/s_genes_tirosh.txt rename to scib/resources/s_genes_tirosh.txt diff --git a/scIB/resources/s_genes_tirosh_hm.txt b/scib/resources/s_genes_tirosh_hm.txt similarity index 100% rename from scIB/resources/s_genes_tirosh_hm.txt rename to scib/resources/s_genes_tirosh_hm.txt diff --git a/scIB/trajectory_inference.py b/scib/trajectory_inference.py similarity index 78% rename from scIB/trajectory_inference.py rename to scib/trajectory_inference.py index 509708e0..11ca7cb2 100644 --- a/scIB/trajectory_inference.py +++ b/scib/trajectory_inference.py @@ -1,34 +1,38 @@ -import scanpy as sc import matplotlib.pyplot as plt -from scIB.utils import * +import numpy as np +import scanpy as sc + +from . 
import utils + def paga(adata, groups='louvain'): """ """ - checkAdata(adata) - + utils.check_adata(adata) + sc.pp.neighbors(adata) sc.tl.paga(adata, groups=groups) _ = sc.pl.paga_compare(adata, show=False) - + fig1, ax1 = plt.subplots() sc.pl.umap(adata, size=40, ax=ax1, show=False) - sc.pl.paga(adata, pos=adata.uns['paga']['pos'], + sc.pl.paga(adata, pos=adata.uns['paga']['pos'], show=False, node_size_scale=10, - node_size_power=1, ax=ax1, text_kwds={'alpha':0}) + node_size_power=1, ax=ax1, text_kwds={'alpha': 0}) plt.show() + def dpt(adata, group, root, opt='min', comp=0): - checkAdata() - + utils.check_adata() + # TODO compute diffmap before - + # get root stem_mask = np.isin(adata.obs[group], root) if opt == 'min': - opt_stem_id = np.argmin(adata.obsm['X_diffmap'][stem_mask,comp]) + opt_stem_id = np.argmin(adata.obsm['X_diffmap'][stem_mask, comp]) elif opt == 'max': - opt_stem_id = np.argmax(adata.obsm['X_diffmap'][stem_mask,comp]) + opt_stem_id = np.argmax(adata.obsm['X_diffmap'][stem_mask, comp]) else: raise ("invalid optimum", opt) root_id = np.arange(len(stem_mask))[stem_mask][opt_stem_id] diff --git a/scIB/utils.py b/scib/utils.py similarity index 76% rename from scIB/utils.py rename to scib/utils.py index 34e9c19f..69ccba31 100644 --- a/scIB/utils.py +++ b/scib/utils.py @@ -1,52 +1,54 @@ -import numpy as np import anndata -import scanpy as sc # checker functions for data sanity -def checkAdata(adata): +def check_adata(adata): if type(adata) is not anndata.AnnData: raise TypeError('Input is not a valid AnnData object') -def checkBatch(batch, obs, verbose=False): + +def check_batch(batch, obs, verbose=False): if batch not in obs: raise ValueError(f'column {batch} is not in obs') elif verbose: print(f'Object contains {obs[batch].nunique()} batches.') -def checkHVG(hvg, adata_var): + +def check_hvg(hvg, adata_var): if type(hvg) is not list: raise TypeError('HVG list is not a list') else: if not all(i in adata_var.index for i in hvg): raise ValueError('Not all HVGs are in the adata object') -def checkSanity(adata, batch, hvg): - checkAdata(adata) - checkBatch(batch, adata.obs) + +def check_sanity(adata, batch, hvg): + check_adata(adata) + check_batch(batch, adata.obs) if hvg is not None: - checkHVG(hvg, adata.var) + check_hvg(hvg, adata.var) -def splitBatches(adata, batch, hvg= None, return_categories=False): +def split_batches(adata, batch, hvg=None, return_categories=False): split = [] batch_categories = adata.obs[batch].unique() if hvg is not None: adata = adata[:, hvg] for i in batch_categories: - split.append(adata[adata.obs[batch]==i].copy()) + split.append(adata[adata.obs[batch] == i].copy()) if return_categories: return split, batch_categories return split + def merge_adata(adata_list, sep='-'): """ merge adatas from list and remove duplicated obs and var columns """ - + if len(adata_list) == 1: return adata_list[0] - + adata = adata_list[0].concatenate(*adata_list[1:], index_unique=None, batch_key='tmp') del adata.obs['tmp'] @@ -57,11 +59,11 @@ def merge_adata(adata_list, sep='-'): clean_var = adata.var.loc[:, columns_to_keep] else: clean_var = adata.var - + if len(adata.var.columns) > 0: if sum(adata.var.columns.str.contains(sep)) > 0: - adata.var = clean_var.rename(columns={name : name.split('-')[0] for name in clean_var.columns.values}) - + adata.var = clean_var.rename(columns={name: name.split('-')[0] for name in clean_var.columns.values}) + return adata diff --git a/setup.cfg b/setup.cfg index 3287428d..3a9d774d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,82 @@ -# 
Inside of setup.cfg +[bumpversion] +current_version = 1.0.0 +commit = True +tag = True + +[bumpversion:file:VERSION.txt] +search = {current_version} +replace = {new_version} + [metadata] -description-file = README.md +name = scib +version = file: VERSION.txt +description = Evaluating single-cell data integration methods +long_description = file: README.md +long_description_content_type = text/markdown +author = Malte D. Luecken, Maren Buettner, Daniel C. Strobl, Michaela F. Mueller +author_email = malte.luecken@helmholtz-muenchen.de, michaela.mueller@helmholtz-muenchen.de +license = MIT +url = https://github.com/theislab/scib +project_urls = + Pipeline = https://github.com/theislab/scib-pipeline + Reproducibility = https://theislab.github.io/scib-reproducibility + Bug Tracker = https://github.com/theislab/scib/issues +keywords = + benchmarking + single cell + data integration +classifiers = + Development Status :: 3 - Alpha + Intended Audience :: Developers + Intended Audience :: Science/Research + Topic :: Software Development :: Build Tools + License :: OSI Approved :: MIT License + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + +[bdist_wheel] +build_number = 1 + +[options] +packages = + scib + scib.metrics +python_requires = >=3.7 +install_requires = + numpy==1.18.1 + pandas + seaborn + matplotlib + numba + scanpy>=1.5 + anndata>=0.7.2 + h5py<3 + rpy2>=3 + anndata2ri + scipy + scikit-learn + scikit-misc + louvain + umap-learn + pydot + python-igraph + llvmlite +zip_safe = False + +[options.package_data] +scib = + resources/*.txt + knn_graph/* + +[options.extras_require] +test = pytest; pytest-runner; pytest-icdiff +dev = build; twine; isort; bump2version +bbknn = bbknn ==1.3.9 +scanorama = scanorama ==1.7.0 +mnn = mnnpy ==0.1.9.5 +scgen = scgen ==1.1.5 +scvi = scvi ==0.6.7 +trvae = trvae ==1.1.2 +trvaep = trvaep ==0.1.0 +desc = desc ==2.0.3 diff --git a/setup.py b/setup.py index 467f7e91..60684932 100644 --- a/setup.py +++ b/setup.py @@ -1,33 +1,3 @@ from setuptools import setup -with open("requirements.txt", "r") as f: - requirements = f.read().splitlines() - requirements = [x for x in requirements if not x.startswith("#") and x != ""] - -with open("requirements_extra.txt", "r") as f: - requirements_extra = f.read().splitlines() - requirements_extra = [x for x in requirements_extra if not x.startswith("#") and x != ""] - - -setup(name='scIB', - version='0.1.1', - description='Benchmark tools for single cell data integration', - author='Malte Luecken, Maren Buettner, Daniel Strobl, Michaela Mueller', - author_email='malte.luecken@helmholtz-muenchen.de', - packages=['scIB', 'scIB.metrics'], - package_data={'scIB': ['resources/*.txt', 'knn_graph/*']}, - zip_safe=False, - license='MIT', - url='https://github.com/theislab/scib', - keywords = ['benchmark', 'single cell', 'data integration'], - install_requires=requirements, - extras_require={'integration': requirements_extra}, - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Build Tools', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - ]) - +setup() diff --git a/tests/common.py b/tests/common.py index 827169d5..0839a512 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,5 +1,5 @@ import pytest -import scIB +import scib import scanpy as sc import numpy as np import pandas as pd diff --git a/tests/conftest.py 
b/tests/conftest.py index 9e3de7e2..13e11b4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ def adata_paul15_template(): adata.obs['batch'] = adata.obs['batch'].astype(str) adata.obs['batch'] = adata.obs['batch'].astype("category") adata.layers['counts'] = adata.X - scIB.preprocessing.reduce_data( + scib.preprocessing.reduce_data( adata, pca=False, n_top_genes=None, @@ -73,7 +73,7 @@ def adata(adata_pbmc_template): @pytest.fixture() def adata_pca(adata): adata_obj = adata - scIB.pp.reduce_data( + scib.pp.reduce_data( adata_obj, pca=True, n_top_genes=200, @@ -86,7 +86,7 @@ def adata_pca(adata): @pytest.fixture() def adata_neighbors(adata): adata_obj = adata - scIB.pp.reduce_data( + scib.pp.reduce_data( adata_obj, pca=True, n_top_genes=200, @@ -99,7 +99,7 @@ def adata_neighbors(adata): @pytest.fixture() def adata_clustered(adata_neighbors): adata_obj = adata_neighbors - scIB.cl.opt_louvain( + scib.cl.opt_louvain( adata_obj, cluster_key='cluster', label_key='celltype', diff --git a/tests/metrics/test_all.py b/tests/metrics/test_all.py index bed44c74..6ee84d29 100644 --- a/tests/metrics/test_all.py +++ b/tests/metrics/test_all.py @@ -2,7 +2,7 @@ def test_fast(adata_neighbors): - metrics_df = scIB.me.metrics_fast( + metrics_df = scib.me.metrics_fast( adata_neighbors, adata_neighbors, batch_key='batch', @@ -20,7 +20,7 @@ def test_slim(adata_paul15): sc.pp.neighbors(adata_paul15) sc.tl.dpt(adata_paul15) - metrics_df = scIB.me.metrics_slim( + metrics_df = scib.me.metrics_slim( adata_paul15, adata_paul15, batch_key='batch', @@ -38,7 +38,7 @@ def test_slim(adata_paul15): # sc.pp.neighbors(adata_paul15) # sc.tl.dpt(adata_paul15) # -# metrics_df = scIB.me.metrics_all( +# metrics_df = scib.me.metrics_all( # adata_paul15, # adata_paul15, # batch_key='batch', diff --git a/tests/metrics/test_beyond_label_metrics.py b/tests/metrics/test_beyond_label_metrics.py index e670cba7..8493b670 100644 --- a/tests/metrics/test_beyond_label_metrics.py +++ b/tests/metrics/test_beyond_label_metrics.py @@ -6,7 +6,7 @@ def test_cell_cycle(adata_paul15): adata_int = adata.copy() # only final score - score = scIB.me.cell_cycle( + score = scib.me.cell_cycle( adata, adata_int, batch_key='batch', @@ -23,7 +23,7 @@ def test_cell_cycle_all(adata_paul15): adata_int = adata.copy() # get all intermediate scores - scores_df = scIB.me.cell_cycle( + scores_df = scib.me.cell_cycle( adata, adata_int, batch_key='batch', @@ -40,7 +40,7 @@ def test_cell_cycle_all(adata_paul15): def test_hvg_overlap(adata): adata_int = adata.copy() - score = scIB.me.hvg_overlap( + score = scib.me.hvg_overlap( adata_int, adata, batch='batch', diff --git a/tests/metrics/test_clisi.py b/tests/metrics/test_clisi.py index 1be4f2d3..c8734ef6 100644 --- a/tests/metrics/test_clisi.py +++ b/tests/metrics/test_clisi.py @@ -2,7 +2,7 @@ def test_clisi_full(adata): - score = scIB.me.clisi_graph( + score = scib.me.clisi_graph( adata, batch_key='batch', label_key='celltype', @@ -16,7 +16,7 @@ def test_clisi_full(adata): def test_clisi_embed(adata_neighbors): adata_neighbors.obsm['X_emb'] = adata_neighbors.obsm['X_pca'] - score = scIB.me.clisi_graph( + score = scib.me.clisi_graph( adata_neighbors, batch_key='batch', label_key='celltype', @@ -28,7 +28,7 @@ def test_clisi_embed(adata_neighbors): def test_clisi_knn(adata_neighbors): - score = scIB.me.clisi_graph( + score = scib.me.clisi_graph( adata_neighbors, batch_key='batch', label_key='celltype', diff --git a/tests/metrics/test_cluster_metrics.py b/tests/metrics/test_cluster_metrics.py index 
f2179510..4352a4db 100644 --- a/tests/metrics/test_cluster_metrics.py +++ b/tests/metrics/test_cluster_metrics.py @@ -2,22 +2,22 @@ def test_nmi_trivial(adata): - score = scIB.me.nmi(adata, 'celltype', 'celltype') + score = scib.me.nmi(adata, 'celltype', 'celltype') assert score == 1 def test_ari_trivial(adata): - score = scIB.me.ari(adata, 'celltype', 'celltype') + score = scib.me.ari(adata, 'celltype', 'celltype') assert score == 1 def test_nmi(adata_neighbors): - _, _, nmi_all = scIB.cl.opt_louvain( + _, _, nmi_all = scib.cl.opt_louvain( adata_neighbors, label_key='celltype', cluster_key='cluster', - function=scIB.me.nmi, + function=scib.me.nmi, plot=False, inplace=True, force=True, @@ -29,13 +29,13 @@ def test_nmi(adata_neighbors): def test_ari(adata_clustered): - score = scIB.me.ari(adata_clustered, group1='cluster', group2='celltype') + score = scib.me.ari(adata_clustered, group1='cluster', group2='celltype') LOGGER.info(f"score: {score}") assert 0 <= score <= 1 def test_isolated_labels_F1(adata_neighbors): - score = scIB.me.isolated_labels( + score = scib.me.isolated_labels( adata_neighbors, label_key='celltype', batch_key='batch', diff --git a/tests/metrics/test_graph_connectivity.py b/tests/metrics/test_graph_connectivity.py index ea1514a5..8f7b9730 100644 --- a/tests/metrics/test_graph_connectivity.py +++ b/tests/metrics/test_graph_connectivity.py @@ -2,6 +2,6 @@ def test_graph_connectivity(adata_neighbors): - score = scIB.me.graph_connectivity(adata_neighbors, label_key='celltype') + score = scib.me.graph_connectivity(adata_neighbors, label_key='celltype') LOGGER.info(f"score: {score}") assert score == 0.9670013350457753 diff --git a/tests/metrics/test_ilisi.py b/tests/metrics/test_ilisi.py index b596f0ae..4f2a6469 100644 --- a/tests/metrics/test_ilisi.py +++ b/tests/metrics/test_ilisi.py @@ -2,7 +2,7 @@ def test_ilisi_full(adata): - score = scIB.me.ilisi_graph( + score = scib.me.ilisi_graph( adata, batch_key='batch', scale=True, @@ -15,7 +15,7 @@ def test_ilisi_full(adata): def test_ilisi_embed(adata_neighbors): adata_neighbors.obsm['X_emb'] = adata_neighbors.obsm['X_pca'] - score = scIB.me.ilisi_graph( + score = scib.me.ilisi_graph( adata_neighbors, batch_key='batch', scale=True, @@ -26,7 +26,7 @@ def test_ilisi_embed(adata_neighbors): def test_ilisi_knn(adata_neighbors): - score = scIB.me.ilisi_graph( + score = scib.me.ilisi_graph( adata_neighbors, batch_key='batch', scale=True, diff --git a/tests/metrics/test_kbet.py b/tests/metrics/test_kbet.py index 9012366f..1371ae77 100644 --- a/tests/metrics/test_kbet.py +++ b/tests/metrics/test_kbet.py @@ -4,7 +4,7 @@ def test_kbet(adata_pca): - score = scIB.me.kBET( + score = scib.me.kBET( adata_pca, batch_key='batch', label_key='celltype', diff --git a/tests/metrics/test_pcr_metrics.py b/tests/metrics/test_pcr_metrics.py index 75a736b3..cd6d2312 100644 --- a/tests/metrics/test_pcr_metrics.py +++ b/tests/metrics/test_pcr_metrics.py @@ -2,12 +2,12 @@ def test_pc_regression(adata): - scIB.me.pcr.pc_regression(adata.X, adata.obs["batch"]) + scib.me.pcr.pc_regression(adata.X, adata.obs["batch"]) def test_pcr_batch(adata): # no PCA precomputed - score = scIB.me.pcr_comparison( + score = scib.me.pcr_comparison( adata, adata, covariate='batch', n_comps=50, @@ -18,14 +18,14 @@ def test_pcr_batch(adata): def test_pcr_batch_precomputed(adata_pca): - score = scIB.me.pcr_comparison(adata_pca, adata_pca, covariate='batch', scale=True) + score = scib.me.pcr_comparison(adata_pca, adata_pca, covariate='batch', scale=True) LOGGER.info(f"precomputed 
PCA: {score}") assert 0 <= score < 1e-6 def test_pcr_batch_embedding(adata): # use different embedding - score = scIB.me.pcr_comparison( + score = scib.me.pcr_comparison( adata_pre=adata, adata_post=add_embed(adata, type_='full'), covariate='batch', diff --git a/tests/metrics/test_silhouette_metrics.py b/tests/metrics/test_silhouette_metrics.py index 11f4bb95..ae969dc1 100644 --- a/tests/metrics/test_silhouette_metrics.py +++ b/tests/metrics/test_silhouette_metrics.py @@ -2,7 +2,7 @@ def test_silhouette(adata_pca): - score = scIB.me.silhouette( + score = scib.me.silhouette( adata_pca, group_key='celltype', embed='X_pca', @@ -13,7 +13,7 @@ def test_silhouette(adata_pca): def test_silhouette_batch(adata_pca): - score = scIB.me.silhouette_batch( + score = scib.me.silhouette_batch( adata_pca, batch_key='batch', group_key='celltype', @@ -26,7 +26,7 @@ def test_silhouette_batch(adata_pca): def test_isolated_labels_silhouette(adata_pca): - score = scIB.me.isolated_labels( + score = scib.me.isolated_labels( adata_pca, label_key='celltype', batch_key='batch', diff --git a/tests/metrics/test_trajectory.py b/tests/metrics/test_trajectory.py index 7cd12e4f..cb66ceeb 100644 --- a/tests/metrics/test_trajectory.py +++ b/tests/metrics/test_trajectory.py @@ -8,7 +8,7 @@ def test_trajectory(adata_neighbors): sc.tl.diffmap(adata_neighbors) sc.tl.dpt(adata_neighbors) - score = scIB.me.trajectory_conservation( + score = scib.me.trajectory_conservation( adata_pre=adata_neighbors, adata_post=adata_int, label_key='celltype', @@ -25,7 +25,7 @@ def test_trajectory_batch(adata_neighbors): sc.tl.diffmap(adata_neighbors) sc.tl.dpt(adata_neighbors) - score = scIB.me.trajectory_conservation( + score = scib.me.trajectory_conservation( adata_pre=adata_neighbors, adata_post=adata_int, label_key='celltype', diff --git a/tests/preprocessing/test_clustering.py b/tests/preprocessing/test_clustering.py index e1df5795..5e841751 100644 --- a/tests/preprocessing/test_clustering.py +++ b/tests/preprocessing/test_clustering.py @@ -2,7 +2,7 @@ def test_cluster(adata_neighbors): - _, _, score_all, clustering = scIB.cl.opt_louvain( + _, _, score_all, clustering = scib.cl.opt_louvain( adata_neighbors, label_key='celltype', cluster_key='cluster', diff --git a/tests/preprocessing/test_preprocessing.py b/tests/preprocessing/test_preprocessing.py index 333c9be1..0184a781 100644 --- a/tests/preprocessing/test_preprocessing.py +++ b/tests/preprocessing/test_preprocessing.py @@ -1,11 +1,11 @@ import scanpy as sc import numpy as np -import scIB +import scib def test_scale(): adata = sc.datasets.blobs() - scIB.pp.scale_batch(adata, 'blobs') - split = scIB.utils.splitBatches(adata, 'blobs') + scib.pp.scale_batch(adata, 'blobs') + split = scib.utils.split_batches(adata, 'blobs') for i in split: assert np.allclose(i.X.mean(0), np.zeros((0,adata.n_vars))) diff --git a/tests/requirements.txt b/tests/requirements.txt deleted file mode 100644 index 9155a309..00000000 --- a/tests/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -pytest -pytest-runner -pytest-icdiff
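
Usage sketch (illustrative, not part of the patch): the renamed snake_case entry points exercised by the test suite above can be driven end to end roughly as follows. It assumes the small public dataset `sc.datasets.pbmc68k_reduced()` and treats its `bulk_labels` column as the cell-type label and its `phase` column as a stand-in batch key (both stand-ins, mirroring the `batch`/`celltype` columns in tests/conftest.py); only functions that appear in this patch (`scib.pp.reduce_data`, `scib.me.metrics_fast`) are called.

    import scanpy as sc
    import scib

    # Small public dataset; 'bulk_labels' serves as the cell-type label and
    # 'phase' as a stand-in batch column (assumptions for illustration only).
    adata = sc.datasets.pbmc68k_reduced()
    adata.obs['celltype'] = adata.obs['bulk_labels']
    adata.obs['batch'] = adata.obs['phase']

    # Renamed preprocessing entry point (scIB.preprocessing -> scib.pp):
    # recompute PCA and the kNN graph without re-selecting HVGs.
    scib.pp.reduce_data(adata, pca=True, n_top_genes=None, neighbors=True, umap=False)

    # Fast metric subset on a (here trivially identical) "integrated" object,
    # following tests/metrics/test_all.py.
    adata_int = adata.copy()
    metrics_df = scib.me.metrics_fast(
        adata,
        adata_int,
        batch_key='batch',
        label_key='celltype',
    )
    print(metrics_df)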