From 35b22e66931e3028738be9fecc463abd1a2bccac Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Tue, 2 Jul 2024 19:49:50 +0000 Subject: [PATCH 01/27] chore: Update matplotlib colormap registration --- dynamo/configuration.py | 20 ++++++++++---------- dynamo/vectorfield/utils.py | 6 +++--- requirements.txt | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dynamo/configuration.py b/dynamo/configuration.py index 907459e14..51043c7c0 100755 --- a/dynamo/configuration.py +++ b/dynamo/configuration.py @@ -432,25 +432,25 @@ def update_data_store_mode(mode: str) -> None: with warnings.catch_warnings(): warnings.simplefilter("ignore") if "zebrafish" not in matplotlib.colormaps(): - plt.register_cmap("zebrafish", zebrafish_cmap) + matplotlib.colormaps.register(name="zebrafish", cmap=zebrafish_cmap) if "fire" not in matplotlib.colormaps(): - plt.register_cmap("fire", fire_cmap) + matplotlib.colormaps.register(name="fire", cmap=fire_cmap) if "darkblue" not in matplotlib.colormaps(): - plt.register_cmap("darkblue", darkblue_cmap) + matplotlib.colormaps.register(name="darkblue", cmap=darkblue_cmap) if "darkgreen" not in matplotlib.colormaps(): - plt.register_cmap("darkgreen", darkgreen_cmap) + matplotlib.colormaps.register(name="darkgreen", cmap=darkgreen_cmap) if "darkred" not in matplotlib.colormaps(): - plt.register_cmap("darkred", darkred_cmap) + matplotlib.colormaps.register(name="darkred", cmap=darkred_cmap) if "darkpurple" not in matplotlib.colormaps(): - plt.register_cmap("darkpurple", darkpurple_cmap) + matplotlib.colormaps.register(name="darkpurple", cmap=darkpurple_cmap) if "div_blue_black_red" not in matplotlib.colormaps(): - plt.register_cmap("div_blue_black_red", div_blue_black_red_cmap) + matplotlib.colormaps.register(name="div_blue_black_red", cmap=div_blue_black_red_cmap) if "div_blue_red" not in matplotlib.colormaps(): - plt.register_cmap("div_blue_red", div_blue_red_cmap) + matplotlib.colormaps.register(name="div_blue_red", cmap=div_blue_red_cmap) if "glasbey_white" not in matplotlib.colormaps(): - plt.register_cmap("glasbey_white", glasbey_white_cmap) + matplotlib.colormaps.register(name="glasbey_white", cmap=glasbey_white_cmap) if "glasbey_dark" not in matplotlib.colormaps(): - plt.register_cmap("glasbey_dark", glasbey_dark_cmap) + matplotlib.colormaps.register(name="glasbey_dark", cmap=glasbey_dark_cmap) _themes = { diff --git a/dynamo/vectorfield/utils.py b/dynamo/vectorfield/utils.py index 15af7da91..a017c6cd8 100644 --- a/dynamo/vectorfield/utils.py +++ b/dynamo/vectorfield/utils.py @@ -916,7 +916,7 @@ def compute_curl(f_jac, X): # --------------------------------------------------------------------------------------------------- # ranking related utilies -def get_metric_gene_in_rank(mat: np.mat, genes: list, neg: bool = False) -> Tuple[np.ndarray, np.ndarray]: +def get_metric_gene_in_rank(mat: np.matrix, genes: list, neg: bool = False) -> Tuple[np.ndarray, np.ndarray]: """Calculate ranking of genes based on mean value of matrix. Args: @@ -936,7 +936,7 @@ def get_metric_gene_in_rank(mat: np.mat, genes: list, neg: bool = False) -> Tupl def get_metric_gene_in_rank_by_group( - mat: np.mat, genes: list, groups: np.array, selected_group, neg: bool = False + mat: np.matrix, genes: list, groups: np.array, selected_group, neg: bool = False ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Calculate ranking of genes based on mean value of matrix, grouped by selected group. @@ -993,7 +993,7 @@ def get_sorted_metric_genes_df(df: pd.DataFrame, genes: list, neg: bool = False) return sorted_metric, sorted_genes -def rank_vector_calculus_metrics(mat: np.mat, genes: list, group, groups: list, uniq_group: list) -> Tuple: +def rank_vector_calculus_metrics(mat: np.matrix, genes: list, group, groups: list, uniq_group: list) -> Tuple: """Calculate ranking of genes based on vector calculus metric. Args: diff --git a/requirements.txt b/requirements.txt index d93af2548..077ed667c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ scipy>=1.4.1 scikit-learn>=0.19.1,<1.5.0 anndata>=0.8.0 loompy>=3.0.5 -matplotlib>=3.5.3,<3.9.0 +matplotlib>=3.9.0 setuptools numdifftools>=0.9.39 umap-learn>=0.5.1 @@ -15,7 +15,7 @@ seaborn>=0.9.0 colorcet>=2.0.1 tqdm igraph>=0.7.1 -louvain==0.8.0 +louvain>=0.8.0 pynndescent>=0.5.2 pre-commit networkx>=2.6 From bd2e9bf27a0413fc9b15e338b8374dfc3a7d26bb Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 20:01:13 +0000 Subject: [PATCH 02/27] fix format issues --- docs/source/_ext/pdfembed.py | 23 +- docs/source/conf.py | 2 +- dynamo/__init__.py | 17 +- dynamo/configuration.py | 31 ++- dynamo/data_io.py | 19 +- dynamo/dynamo_logger.py | 21 +- dynamo/estimation/csc/velocity.py | 67 ++++- dynamo/estimation/deprecated.py | 4 +- dynamo/estimation/tsc/ODEs.py | 23 +- dynamo/estimation/tsc/estimation_kinetic.py | 7 +- dynamo/estimation/tsc/utils_moments.py | 1 - dynamo/external/hodge.py | 8 +- dynamo/external/scribe.py | 6 +- dynamo/get_version.py | 2 + dynamo/movie/fate.py | 22 +- dynamo/plot/__init__.py | 2 +- dynamo/plot/connectivity.py | 2 - dynamo/plot/dynamics.py | 253 ++++++++++++------ dynamo/plot/ezplots.py | 2 +- dynamo/plot/heatmaps.py | 4 +- dynamo/plot/networks.py | 1 - dynamo/plot/preprocess.py | 7 +- dynamo/plot/pseudotime.py | 37 ++- dynamo/plot/scVectorField.py | 20 +- dynamo/plot/scatters.py | 244 +++++++++-------- dynamo/plot/sctransform.py | 44 ++- dynamo/plot/state_graph.py | 4 +- dynamo/plot/time_series.py | 10 +- dynamo/plot/topography.py | 22 +- dynamo/plot/utils.py | 82 +++--- dynamo/plot/vector_calculus.py | 6 +- dynamo/prediction/fate.py | 8 +- dynamo/prediction/least_action_path.py | 5 +- dynamo/prediction/perturbation.py | 3 +- dynamo/prediction/trajectory.py | 18 +- dynamo/prediction/tscRNA_seq.py | 1 - dynamo/prediction/utils.py | 4 +- dynamo/preprocessing/Preprocessor.py | 9 +- dynamo/preprocessing/QC.py | 3 +- dynamo/preprocessing/__init__.py | 29 +- dynamo/preprocessing/deprecated.py | 1 + dynamo/preprocessing/dynast.py | 4 +- dynamo/preprocessing/external/__init__.py | 2 +- dynamo/preprocessing/external/integration.py | 4 +- .../external/pearson_residual_recipe.py | 1 + dynamo/preprocessing/external/sctransform.py | 1 + dynamo/preprocessing/gene_selection.py | 10 +- dynamo/preprocessing/normalization.py | 26 +- dynamo/preprocessing/utils.py | 4 +- dynamo/sample_data.py | 7 +- dynamo/shiny.py | 2 +- dynamo/shiny/lap.py | 230 ++++++++++------ dynamo/shiny/perturbation.py | 42 +-- dynamo/shiny/utils.py | 2 +- dynamo/simulation/ODE.py | 10 +- dynamo/simulation/bif_os_inclusive_sim.py | 2 + dynamo/simulation/simulate_anndata.py | 9 + dynamo/simulation/utils.py | 3 + dynamo/tools/DDRTree_graph.py | 38 +-- dynamo/tools/Markov.py | 35 +-- dynamo/tools/__init__.py | 21 +- dynamo/tools/cell_velocities.py | 15 +- dynamo/tools/clustering.py | 2 +- dynamo/tools/deprecated.py | 15 +- dynamo/tools/dimension_reduction.py | 4 +- dynamo/tools/dynamics.py | 20 +- dynamo/tools/graph_calculus.py | 33 +-- dynamo/tools/graph_operators.py | 2 +- dynamo/tools/growth.py | 9 +- dynamo/tools/markers.py | 35 ++- dynamo/tools/metric_velocity.py | 1 - dynamo/tools/moments.py | 14 +- dynamo/tools/pseudotime.py | 55 ++-- dynamo/tools/sampling.py | 2 +- dynamo/tools/utils.py | 75 ++++-- dynamo/tools/velocyto_scvelo.py | 27 +- dynamo/utils.py | 1 + dynamo/vectorfield/Ao.py | 3 +- dynamo/vectorfield/Bhattacharya.py | 24 +- dynamo/vectorfield/Tang.py | 1 + dynamo/vectorfield/VectorField.py | 13 +- dynamo/vectorfield/__init__.py | 11 +- dynamo/vectorfield/cell_vectors.py | 2 +- dynamo/vectorfield/clustering.py | 3 +- dynamo/vectorfield/scPotential.py | 62 +++-- dynamo/vectorfield/scVectorField.py | 8 +- dynamo/vectorfield/stochastic_process.py | 6 +- dynamo/vectorfield/topography.py | 2 +- dynamo/vectorfield/vector_calculus.py | 19 +- setup.py | 3 +- tests/test_data_io.py | 3 +- tests/test_pipeline.py | 6 +- tests/test_pl.py | 44 ++- tests/test_prediction.py | 26 +- tests/test_preprocess.py | 53 ++-- tests/test_tl.py | 31 ++- tests/test_vf.py | 40 ++- 97 files changed, 1349 insertions(+), 848 deletions(-) diff --git a/docs/source/_ext/pdfembed.py b/docs/source/_ext/pdfembed.py index e24a19361..4b0c12da6 100644 --- a/docs/source/_ext/pdfembed.py +++ b/docs/source/_ext/pdfembed.py @@ -3,6 +3,7 @@ from docutils import nodes + def pdfembed_html(pdfembed_specs): """ Build the iframe code for the pdf file, @@ -20,27 +21,31 @@ def pdfembed_html(pdfembed_specs): align="%s"> """ - return ( html_base_code % (pdfembed_specs['src' ], - pdfembed_specs['height'], - pdfembed_specs['width' ], - pdfembed_specs['align' ]) ) + return html_base_code % ( + pdfembed_specs["src"], + pdfembed_specs["height"], + pdfembed_specs["width"], + pdfembed_specs["align"], + ) + def pdfembed_role(typ, rawtext, text, lineno, inliner, options={}, content=[]): """ Get iframe specifications and generate the associate HTML code for the pdf iframe. """ # parse and init variables - text = text.replace(' ', '') + text = text.replace(" ", "") pdfembed_specs = {} # read specs - for component in text.split(','): - pdfembed_specs[component.split(':')[0]] = component.split(':')[1] + for component in text.split(","): + pdfembed_specs[component.split(":")[0]] = component.split(":")[1] # build node from pdf iframe html code - node = nodes.raw('', pdfembed_html(pdfembed_specs), format='html') + node = nodes.raw("", pdfembed_html(pdfembed_specs), format="html") return [node], [] + def setup(app): """ Set up the app with the extension function """ - app.add_role('pdfembed', pdfembed_role) \ No newline at end of file + app.add_role("pdfembed", pdfembed_role) diff --git a/docs/source/conf.py b/docs/source/conf.py index 230c741e0..e0df949f6 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -114,7 +114,7 @@ "sphinxcontrib.bibtex", "sphinx_gallery.load_style", # pdf embed - 'pdfembed', + "pdfembed", ] # Mappings for sphinx.ext.intersphinx. Projects have to have Sphinx-generated doc! (.inv file) diff --git a/dynamo/__init__.py b/dynamo/__init__.py index ecdfe4471..d5735723c 100755 --- a/dynamo/__init__.py +++ b/dynamo/__init__.py @@ -1,7 +1,7 @@ """Mapping Vector Field of Single Cells """ -from .get_version import get_version, get_dynamo_version +from .get_version import get_dynamo_version, get_version __version__ = get_version(__file__) del get_version @@ -10,20 +10,7 @@ # # __version__ = get_dynamo_version() -from . import pp -from . import est -from . import tl -from . import vf -from . import pd -from . import pl -from . import mv -from . import shiny -from . import sim -from .data_io import * -from . import sample_data -from . import configuration -from . import ext - +from . import configuration, est, ext, mv, pd, pl, pp, sample_data, shiny, sim, tl, vf from .data_io import * from .dynamo_logger import ( Logger, diff --git a/dynamo/configuration.py b/dynamo/configuration.py index 51043c7c0..033c5c484 100755 --- a/dynamo/configuration.py +++ b/dynamo/configuration.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, List, Generator, Optional, Tuple, Union +from typing import Any, Generator, List, Optional, Tuple, Union import colorcet import matplotlib @@ -15,6 +15,7 @@ class DynamoAdataKeyManager: """A class to manage the keys used in anndata object for dynamo.""" + VAR_GENE_MEAN_KEY = "pp_gene_mean" VAR_GENE_VAR_KEY = "pp_gene_variance" VAR_GENE_HIGHLY_VARIABLE_KEY = "gene_highly_variable" @@ -41,8 +42,8 @@ class DynamoAdataKeyManager: RAW = "raw" def _select_layer_cell_chunked_data( - mat: np.ndarray, - chunk_size: int, + mat: np.ndarray, + chunk_size: int, ) -> Generator: """Select layer data in cell chunks based on chunk_size.""" start = 0 @@ -55,8 +56,8 @@ def _select_layer_cell_chunked_data( yield (mat[start:n, :], start, n) def _select_layer_gene_chunked_data( - mat: np.ndarray, - chunk_size: int, + mat: np.ndarray, + chunk_size: int, ) -> Generator: """Select layer data in gene chunks based on chunk_size.""" start = 0 @@ -131,8 +132,11 @@ def select_layer_chunked_data( elif layer == DynamoAdataKeyManager.RAW: return DynamoAdataKeyManager._select_layer_cell_chunked_data(adata.raw.X, chunk_size=chunk_size) elif layer == DynamoAdataKeyManager.PROTEIN_LAYER: - return DynamoAdataKeyManager._select_layer_cell_chunked_data( - adata.obsm["protein"], chunk_size=chunk_size) if "protein" in adata.obsm_keys() else None + return ( + DynamoAdataKeyManager._select_layer_cell_chunked_data(adata.obsm["protein"], chunk_size=chunk_size) + if "protein" in adata.obsm_keys() + else None + ) else: return DynamoAdataKeyManager._select_layer_cell_chunked_data(adata.layers[layer], chunk_size=chunk_size) elif chunk_mode == "gene": @@ -141,8 +145,11 @@ def select_layer_chunked_data( elif layer == DynamoAdataKeyManager.RAW: return DynamoAdataKeyManager._select_layer_gene_chunked_data(adata.raw.X, chunk_size=chunk_size) elif layer == DynamoAdataKeyManager.PROTEIN_LAYER: - return DynamoAdataKeyManager._select_layer_gene_chunked_data( - adata.obsm["protein"], chunk_size=chunk_size) if "protein" in adata.obsm_keys() else None + return ( + DynamoAdataKeyManager._select_layer_gene_chunked_data(adata.obsm["protein"], chunk_size=chunk_size) + if "protein" in adata.obsm_keys() + else None + ) else: return DynamoAdataKeyManager._select_layer_gene_chunked_data(adata.layers[layer], chunk_size=chunk_size) else: @@ -172,7 +179,10 @@ def check_if_layer_exist(adata: AnnData, layer: str) -> bool: return layer in adata.layers def get_available_layer_keys( - adata: AnnData, layers: str = "all", remove_pp_layers: bool = True, include_protein: bool = True, + adata: AnnData, + layers: str = "all", + remove_pp_layers: bool = True, + include_protein: bool = True, ) -> List[str]: """Get the list of available layers' keys. If `layers` is set to all, return a list of all available layers; if `layers` is set to a list, then the intersetion of available layers and `layers` will be returned.""" @@ -278,6 +288,7 @@ def aggregate_layers_into_total( class DynamoVisConfig: """Dynamo visualization config class holding static variables to change behaviors of functions globally.""" + def set_default_mode(background="white"): """Set the default mode for dynamo visualization.""" set_figure_params("dynamo", background=background) diff --git a/dynamo/data_io.py b/dynamo/data_io.py index b868e3285..2822ce590 100755 --- a/dynamo/data_io.py +++ b/dynamo/data_io.py @@ -78,7 +78,11 @@ def convert2float(adata: AnnData, columns: List, var: bool = False) -> None: def load_NASC_seq( - dir: str, type: str = "TPM", delimiter: str = "_", colnames: Optional[List] = None, dropna: bool = False, + dir: str, + type: str = "TPM", + delimiter: str = "_", + colnames: Optional[List] = None, + dropna: bool = False, ) -> AnnData: """Function to create an anndata object from NASC-seq pipeline. @@ -323,7 +327,10 @@ def cleanup(adata: AnnData, del_prediction: bool = False, del_2nd_moments: bool def export_rank_xlsx( - adata: AnnData, path: str = "rank_info.xlsx", ext: str = "excel", rank_prefix: str = "rank", + adata: AnnData, + path: str = "rank_info.xlsx", + ext: str = "excel", + rank_prefix: str = "rank", ) -> None: import pandas as pd @@ -373,15 +380,16 @@ def export_h5ad(adata: AnnData, path: str = "data/processed_data.h5ad") -> None: for i in fate_keys: if i is not None: if "prediction" in adata.uns[i].keys(): - adata.uns[i]["prediction"] = {str(index): array for index, array in - enumerate(adata.uns[i]["prediction"])} + adata.uns[i]["prediction"] = { + str(index): array for index, array in enumerate(adata.uns[i]["prediction"]) + } if "t" in adata.uns[i].keys(): adata.uns[i]["t"] = {str(index): array for index, array in enumerate(adata.uns[i]["t"])} adata.write_h5ad(path) -def import_h5ad(path: str ="data/processed_data.h5ad") -> AnnData: +def import_h5ad(path: str = "data/processed_data.h5ad") -> AnnData: """Import a Dynamo h5ad object into anndata.""" adata = read_h5ad(path) @@ -397,4 +405,3 @@ def import_h5ad(path: str ="data/processed_data.h5ad") -> AnnData: adata.uns[i]["t"] = [adata.uns[i]["t"][index] for index in adata.uns[i]["t"]] return adata - diff --git a/dynamo/dynamo_logger.py b/dynamo/dynamo_logger.py index 9d766a469..0bcb776b6 100644 --- a/dynamo/dynamo_logger.py +++ b/dynamo/dynamo_logger.py @@ -1,10 +1,9 @@ -from typing import Iterable, Optional - import functools import logging import sys import time from contextlib import contextmanager +from typing import Iterable, Optional def silence_logger(name: str) -> None: @@ -176,7 +175,13 @@ def error(self, message: str, indent_level: int = 1, *args, **kwargs) -> None: return self.logger.error(message, *args, **kwargs) def info_insert_adata( - self, key: str, adata_attr: str = "obsm", log_level: int = logging.NOTSET, indent_level: int = 1, *args, **kwargs + self, + key: str, + adata_attr: str = "obsm", + log_level: int = logging.NOTSET, + indent_level: int = 1, + *args, + **kwargs, ) -> None: """Log a message for inserting data into an AnnData object.""" message = " %s to %s in AnnData Object." % (key, adata_attr) @@ -314,7 +319,10 @@ def get_temp_timer_logger() -> Logger: @staticmethod def progress_logger( - generator: Iterable, logger: Optional[Logger] = None, progress_name: str = "", indent_level: int = 1, + generator: Iterable, + logger: Optional[Logger] = None, + progress_name: str = "", + indent_level: int = 1, ) -> Iterable: """A generator that logs the progress of another generator.""" if logger is None: @@ -362,7 +370,10 @@ def main_critical(message: str, indent_level: int = 1) -> None: def main_tqdm( - generator: Iterable, desc: str = "", indent_level: int = 1, logger: LoggerManager = LoggerManager().main_logger, + generator: Iterable, + desc: str = "", + indent_level: int = 1, + logger: LoggerManager = LoggerManager().main_logger, ) -> Iterable: """a TQDM style wrapper for logging something like a loop. diff --git a/dynamo/estimation/csc/velocity.py b/dynamo/estimation/csc/velocity.py index 7cf27afac..3130faf3c 100755 --- a/dynamo/estimation/csc/velocity.py +++ b/dynamo/estimation/csc/velocity.py @@ -23,6 +23,7 @@ class Velocity: """The class that computes RNA/protein velocity given unknown parameters.""" + def __init__( self, alpha: Optional[np.ndarray] = None, @@ -163,7 +164,9 @@ def vel_u( return V - def vel_s(self, U: Union[csr_matrix, np.ndarray], S: Union[csr_matrix, np.ndarray]) -> Union[csr_matrix, np.ndarray]: + def vel_s( + self, U: Union[csr_matrix, np.ndarray], S: Union[csr_matrix, np.ndarray] + ) -> Union[csr_matrix, np.ndarray]: """Calculate the unspliced mRNA velocity. Args: @@ -227,7 +230,9 @@ def vel_s(self, U: Union[csr_matrix, np.ndarray], S: Union[csr_matrix, np.ndarra V = np.nan return V - def vel_p(self, S: Union[csr_matrix, np.ndarray], P: Union[csr_matrix, np.ndarray]) -> Union[csr_matrix, np.ndarray]: + def vel_p( + self, S: Union[csr_matrix, np.ndarray], P: Union[csr_matrix, np.ndarray] + ) -> Union[csr_matrix, np.ndarray]: """Calculate the protein velocity. Args: @@ -657,7 +662,14 @@ def fit( bs, bf, ) = zip(*res) - (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( + ( + gamma, + gamma_intercept, + gamma_r2, + gamma_logLL, + bs, + bf, + ) = ( np.array(gamma), np.array(gamma_intercept), np.array(gamma_r2), @@ -743,7 +755,14 @@ def fit( bs, bf, ) = zip(*res) - (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( + ( + gamma, + gamma_intercept, + gamma_r2, + gamma_logLL, + bs, + bf, + ) = ( np.array(gamma), np.array(gamma_intercept), np.array(gamma_r2), @@ -831,7 +850,11 @@ def fit( uu_m, uu_v, _ = calc_12_mom_labeling(self.data["uu"], self.t) if cores == 1: for i in tqdm(range(n_genes), desc="estimating alpha"): - (alpha[i], alpha_b[i], alpha_r2[i],) = fit_alpha_degradation( + ( + alpha[i], + alpha_b[i], + alpha_r2[i], + ) = fit_alpha_degradation( t_uniq, uu_m[i], self.parameters["gamma"][i], @@ -1023,7 +1046,10 @@ def fit( total, ), ) - (self.aux_param["total0"], self.parameters["gamma"],) = ( + ( + self.aux_param["total0"], + self.parameters["gamma"], + ) = ( total0, gamma, ) @@ -1060,7 +1086,14 @@ def fit( if issparse(self.data["ul"]) else np.zeros_like(self.data["ul"].shape) ) - (t_uniq, gamma, gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( + ( + t_uniq, + gamma, + gamma_k, + gamma_intercept, + gamma_r2, + gamma_logLL, + ) = ( np.unique(self.t), np.zeros(n_genes), np.zeros(n_genes), @@ -1108,7 +1141,12 @@ def fit( _, gamma_logLL, ) = zip(*res1) - (gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( + ( + gamma_k, + gamma_intercept, + gamma_r2, + gamma_logLL, + ) = ( np.array(gamma_k), np.array(gamma_intercept), np.array(gamma_r2), @@ -1462,7 +1500,11 @@ def fit( # gamma_3 = solve_gamma(np.max(self.t), self.data['uu'][i, self.t == np.max(self.t)], tmp) # sci-fate gamma[i] = gamma_2 # print('Steady state, stimulation, sci-fate like gamma values are ', gamma_1, '; ', gamma_2, '; ', gamma_3) - (self.parameters["gamma"], self.aux_param["U0"], self.parameters["beta"],) = ( + ( + self.parameters["gamma"], + self.aux_param["U0"], + self.parameters["beta"], + ) = ( gamma, U, np.ones(gamma.shape), @@ -1479,7 +1521,12 @@ def fit( if self.asspt_prot.lower() == "ss" and n_genes > 0: self.parameters["eta"] = np.ones(n_genes) - (delta, delta_intercept, delta_r2, delta_logLL,) = ( + ( + delta, + delta_intercept, + delta_r2, + delta_logLL, + ) = ( np.zeros(n_genes), np.zeros(n_genes), np.zeros(n_genes), diff --git a/dynamo/estimation/deprecated.py b/dynamo/estimation/deprecated.py index 23d629206..373b0968f 100644 --- a/dynamo/estimation/deprecated.py +++ b/dynamo/estimation/deprecated.py @@ -1,12 +1,12 @@ -import warnings import functools +import warnings from numba import float32 # import the types from numpy import * from scipy.optimize import least_squares -from .tsc.utils_moments import moments from ..tools.sampling import lhsclassic +from .tsc.utils_moments import moments class estimation: diff --git a/dynamo/estimation/tsc/ODEs.py b/dynamo/estimation/tsc/ODEs.py index 9dff4d051..0fe521811 100755 --- a/dynamo/estimation/tsc/ODEs.py +++ b/dynamo/estimation/tsc/ODEs.py @@ -9,6 +9,7 @@ class LinearODE: """A general class for linear odes.""" + def __init__(self, n_species: int, x0: Optional[np.ndarray] = None): """Initialize the LinearODE object. @@ -137,6 +138,7 @@ def integrate_matrix(self, t: np.ndarray, x0: Optional[np.ndarray] = None) -> np class MixtureModels: """The base class for mixture models.""" + def __init__(self, models: LinearODE, param_distributor: List): """Initialize the MixtureModels class. @@ -158,7 +160,9 @@ def __init__(self, models: LinearODE, param_distributor: List): self.methods = ["numerical", "matrix"] self.default_method = "matrix" - def integrate(self, t: np.ndarray, x0: Optional[np.ndarray] = None, method: Optional[Union[str, List]] = None) -> None: + def integrate( + self, t: np.ndarray, x0: Optional[np.ndarray] = None, method: Optional[Union[str, List]] = None + ) -> None: """Integrate with time values for all models. Args: @@ -224,8 +228,9 @@ def set_params(self, *params: Tuple) -> None: class LambdaModels_NoSwitching(MixtureModels): """Linear ODEs for the lambda mixture model. The order of params is: - parameter order: alpha, lambda, (beta), gamma - distributor order: alpha_1, alpha_2, (beta), gamma""" + parameter order: alpha, lambda, (beta), gamma + distributor order: alpha_1, alpha_2, (beta), gamma""" + def __init__(self, model1: LinearODE, model2: LinearODE): """Initialize the LambdaModels_NoSwitching class. @@ -261,6 +266,7 @@ def param_mixer(self, *params) -> np.ndarray: class Moments(LinearODE): """The class simulates the dynamics of first and second moments of a transcription-splicing system with promoter switching.""" + def __init__( self, a: Optional[np.ndarray] = None, @@ -527,6 +533,7 @@ def computeKnp(self) -> Tuple[np.ndarray, np.ndarray]: class Moments_Nosplicing(LinearODE): """The class simulates the dynamics of first and second moments of a transcription-splicing system with promoter switching.""" + def __init__( self, a: Optional[np.ndarray] = None, @@ -534,7 +541,7 @@ def __init__( alpha_a: Optional[np.ndarray] = None, alpha_i: Optional[np.ndarray] = None, gamma: Optional[np.ndarray] = None, - x0: Optional[np.ndarray] = None + x0: Optional[np.ndarray] = None, ): """Initialize the Moments_Nosplicing object. @@ -602,7 +609,9 @@ def fbar(self, x_a: np.ndarray, x_i: np.ndarray) -> np.ndarray: """ return self.b / (self.a + self.b) * x_a + self.a / (self.a + self.b) * x_i - def set_params(self, a: np.ndarray, b: np.ndarray, alpha_a: np.ndarray, alpha_i: np.ndarray, gamma: np.ndarray) -> None: + def set_params( + self, a: np.ndarray, b: np.ndarray, alpha_a: np.ndarray, alpha_i: np.ndarray, gamma: np.ndarray + ) -> None: """Set the parameters. Args: @@ -686,6 +695,7 @@ def computeKnp(self) -> Tuple[np.ndarray, np.ndarray]: class Moments_NoSwitching(LinearODE): """The class simulates the dynamics of first and second moments of a transcription-splicing system without promoter switching.""" + def __init__( self, alpha: Optional[np.ndarray] = None, @@ -875,6 +885,7 @@ def computeKnp(self) -> Tuple[np.ndarray, np.ndarray]: class Moments_NoSwitchingNoSplicing(LinearODE): """The class simulates the dynamics of first and second moments of a transcription system without promoter switching.""" + def __init__( self, alpha: Optional[np.ndarray] = None, @@ -997,6 +1008,7 @@ def computeKnp(self) -> Tuple[np.ndarray, np.ndarray]: class Deterministic(LinearODE): """This class simulates the deterministic dynamics of a transcription-splicing system.""" + def __init__( self, alpha: Optional[np.ndarray] = None, @@ -1131,6 +1143,7 @@ def integrate_analytical(self, t: np.ndarray, x0: Optional[np.ndarray] = None) - class Deterministic_NoSplicing(LinearODE): """The class simulates the deterministic dynamics of a transcription-splicing system.""" + def __init__( self, alpha: Optional[np.ndarray] = None, diff --git a/dynamo/estimation/tsc/estimation_kinetic.py b/dynamo/estimation/tsc/estimation_kinetic.py index 7ca198737..d973700dc 100755 --- a/dynamo/estimation/tsc/estimation_kinetic.py +++ b/dynamo/estimation/tsc/estimation_kinetic.py @@ -1,5 +1,5 @@ -from typing import Dict, List, Optional, Tuple, Union import warnings +from typing import Dict, List, Optional, Tuple, Union import numpy as np from scipy.optimize import least_squares @@ -350,6 +350,7 @@ def test_chi2( class Estimation_Degradation(kinetic_estimation): """The base parameters, estimation class for degradation experiments.""" + def __init__(self, ranges: np.ndarray, x0: np.ndarray, simulator: LinearODE): """Initialize the Estimation_Degradation object. @@ -1381,6 +1382,7 @@ def export_model(self, reinstantiate: bool = True) -> Union[LambdaModels_NoSwitc class Estimation_KineticChase(kinetic_estimation): """An estimation class for kinetic chase experiment.""" + def __init__( self, alpha: Optional[np.ndarray] = None, @@ -1502,6 +1504,7 @@ class GoodnessOfFit: This class provides methods for assessing the quality of predictions, using various metrics including Gaussian likelihood, Gaussian log-likelihood, and mean squared deviation. """ + def __init__( self, simulator: LinearODE, @@ -1600,7 +1603,7 @@ def normalize( return x_data_norm, x_model_norm def calc_gaussian_likelihood(self) -> float: - """ Calculate the Gaussian likelihood between model predictions and observations. + """Calculate the Gaussian likelihood between model predictions and observations. Returns: Gaussian likelihood value. diff --git a/dynamo/estimation/tsc/utils_moments.py b/dynamo/estimation/tsc/utils_moments.py index 00ced8c36..ae8003b68 100755 --- a/dynamo/estimation/tsc/utils_moments.py +++ b/dynamo/estimation/tsc/utils_moments.py @@ -10,7 +10,6 @@ from numpy import * from scipy.integrate import odeint - spec = [ ("a", float32), ("b", float32), diff --git a/dynamo/external/hodge.py b/dynamo/external/hodge.py index b30fa7ecf..4eb359fbd 100644 --- a/dynamo/external/hodge.py +++ b/dynamo/external/hodge.py @@ -21,7 +21,7 @@ div, potential, )""" -from ..tools.connectivity import generate_neighbor_keys, check_and_recompute_neighbors +from ..tools.connectivity import check_and_recompute_neighbors, generate_neighbor_keys def ddhodge( @@ -218,7 +218,11 @@ def func(x): W.dot(ddhodge_div), W.dot(potential_), ) - (adata.obs[prefix + "ddhodge_sampled"], adata.obs[prefix + "ddhodge_div"], adata.obs[prefix + "potential"],) = ( + ( + adata.obs[prefix + "ddhodge_sampled"], + adata.obs[prefix + "ddhodge_div"], + adata.obs[prefix + "potential"], + ) = ( False, 0, 0, diff --git a/dynamo/external/scribe.py b/dynamo/external/scribe.py index 587140d2f..41831f98a 100644 --- a/dynamo/external/scribe.py +++ b/dynamo/external/scribe.py @@ -86,11 +86,7 @@ def scribe( str_format = ( "upper" if adata.var_names[0].isupper() - else "lower" - if adata.var_names[0].islower() - else "title" - if adata.var_names[0].istitle() - else "other" + else "lower" if adata.var_names[0].islower() else "title" if adata.var_names[0].istitle() else "other" ) motifAnnotations_hgnc = pd.read_csv(motif_ref, sep="\t") diff --git a/dynamo/get_version.py b/dynamo/get_version.py index 9ce2daf1b..516f652e3 100755 --- a/dynamo/get_version.py +++ b/dynamo/get_version.py @@ -2,6 +2,7 @@ A minimalistic version helper in the spirit of versioneer, that is able to run without build step using pkg_resources. Developed by P Angerer, see https://github.com/flying-sheep/get_version. """ + import logging import os import re @@ -28,6 +29,7 @@ def match_groups(regex: str, target: str) -> List[str]: class Version(NamedTuple): """A parsed version string.""" + release: str dev: Optional[str] labels: List[str] diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 814b654f7..76c263549 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -25,6 +25,7 @@ class BaseAnim: vector field. Thus, it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell fate commitment in action. """ + def __init__( self, adata: AnnData, @@ -258,7 +259,9 @@ def __init__( self.fig = fig self.ax = ax - (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if dims is not None and len(dims) == 3 else self.ax.plot([], [], "ro") + (self.ln,) = ( + self.ax.plot([], [], "ro", zs=[]) if dims is not None and len(dims) == 3 else self.ax.plot([], [], "ro") + ) def init_background(self): """Initialize background of the animation.""" @@ -305,6 +308,7 @@ def update(self, frame): class StreamFuncAnim3D(StreamFuncAnim): """The class of 3D animation instance for matplotlib FuncAnimation function.""" + def update(self, frame): """The function to call at each frame. Update the position of the line object in the animation.""" init_states = self.init_states @@ -457,6 +461,7 @@ def animate_fates( class PyvistaAnim(BaseAnim): """The class for animating cell fate commitment prediction with pyvista.""" + def __init__( self, adata: AnnData, @@ -567,6 +572,7 @@ def animate(self): class PlotlyAnim(BaseAnim): """The class for animating cell fate commitment prediction with plotly.""" + def __init__( self, adata: AnnData, @@ -665,11 +671,10 @@ def animate(self): fig = go.Figure( data=self.pl, - layout=go.Layout(title="Moving Frenet Frame Along a Planar Curve", - updatemenus=[dict(type="buttons", - buttons=[dict(label="Play", - method="animate", - args=[None])])]), + layout=go.Layout( + title="Moving Frenet Frame Along a Planar Curve", + updatemenus=[dict(type="buttons", buttons=[dict(label="Play", method="animate", args=[None])])], + ), frames=[ go.Frame( data=[ @@ -684,8 +689,9 @@ def animate(self): ), ) ] - ) for k in range(1, self.n_steps) - ] + ) + for k in range(1, self.n_steps) + ], ) fig.show() diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index 585de3bc0..8001f95da 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -34,7 +34,7 @@ from .pseudotime import plot_dim_reduced_direct_graph from .scatters import scatters, scatters_interactive from .scPotential import show_landscape -from .sctransform import sctransform_plot_fit, plot_residual_var +from .sctransform import plot_residual_var, sctransform_plot_fit from .scVectorField import ( # , plot_LIC_gray cell_wise_vectors, cell_wise_vectors_3d, diff --git a/dynamo/plot/connectivity.py b/dynamo/plot/connectivity.py index ca7b59025..48d73509d 100755 --- a/dynamo/plot/connectivity.py +++ b/dynamo/plot/connectivity.py @@ -10,7 +10,6 @@ 6. others """ - from typing import Any, Dict, List, Optional, Union try: @@ -564,7 +563,6 @@ def nneighbors( return save_show_ret("nneighbors", save_show_or_return, save_kwargs, plt.gcf()) - def plot_connectivity( adata: AnnData, graph: Union[csr_matrix, csc_matrix, np.ndarray], diff --git a/dynamo/plot/dynamics.py b/dynamo/plot/dynamics.py index 676aa6191..6f10babf5 100755 --- a/dynamo/plot/dynamics.py +++ b/dynamo/plot/dynamics.py @@ -15,7 +15,15 @@ from ..dynamo_logger import main_warning from ..estimation.csc.velocity import sol_s, sol_u, solve_first_order_deg from ..estimation.tsc.utils_moments import moments -from ..tools.utils import get_mapper, get_valid_bools, get_vel_params, index_gene, log1p_, update_dict, update_vel_params +from ..tools.utils import ( + get_mapper, + get_valid_bools, + get_vel_params, + index_gene, + log1p_, + update_dict, + update_vel_params, +) from .scatters import scatters from .utils import ( _datashade_points, @@ -537,23 +545,29 @@ def phase_portraits( discrete_cmap, discrete_color_key_cmap, discrete_background = ( _themes[discrete_theme]["cmap"] if discrete_continous_div_cmap is None else discrete_continous_div_cmap[0], - _themes[discrete_theme]["color_key_cmap"] - if discrete_continous_div_color_key_cmap is None - else discrete_continous_div_color_key_cmap[0], + ( + _themes[discrete_theme]["color_key_cmap"] + if discrete_continous_div_color_key_cmap is None + else discrete_continous_div_color_key_cmap[0] + ), _themes[discrete_theme]["background"], ) continous_cmap, continous_color_key_cmap, continous_background = ( _themes[continous_theme]["cmap"] if discrete_continous_div_cmap is None else discrete_continous_div_cmap[1], - _themes[continous_theme]["color_key_cmap"] - if discrete_continous_div_color_key_cmap is None - else discrete_continous_div_color_key_cmap[1], + ( + _themes[continous_theme]["color_key_cmap"] + if discrete_continous_div_color_key_cmap is None + else discrete_continous_div_color_key_cmap[1] + ), _themes[continous_theme]["background"], ) divergent_cmap, divergent_color_key_cmap, divergent_background = ( _themes[divergent_theme]["cmap"] if discrete_continous_div_cmap is None else discrete_continous_div_cmap[2], - _themes[divergent_theme]["color_key_cmap"] - if discrete_continous_div_color_key_cmap is None - else discrete_continous_div_color_key_cmap[2], + ( + _themes[divergent_theme]["color_key_cmap"] + if discrete_continous_div_color_key_cmap is None + else discrete_continous_div_color_key_cmap[2] + ), _themes[divergent_theme]["background"], ) @@ -585,9 +599,11 @@ def phase_portraits( if cur_pd.color.isna().all(): if cur_pd.shape[0] <= figsize[0] * figsize[1] * 1000000: ax1, color = _matplotlib_points( - cur_pd.loc[:, ["S", "U"]].values - if vkey == "velocity_S" - else cur_pd.loc[:, ["total", "new"]].values, + ( + cur_pd.loc[:, ["S", "U"]].values + if vkey == "velocity_S" + else cur_pd.loc[:, ["total", "new"]].values + ), ax=ax1, labels=None, values=cur_pd.loc[:, "expression"].values, @@ -603,9 +619,11 @@ def phase_portraits( ) else: ax1, color = _datashade_points( - cur_pd.loc[:, ["S", "U"]].values - if vkey == "velocity_S" - else cur_pd.loc[:, ["total", "new"]].values, + ( + cur_pd.loc[:, ["S", "U"]].values + if vkey == "velocity_S" + else cur_pd.loc[:, ["total", "new"]].values + ), ax=ax1, labels=None, values=cur_pd.loc[:, "expression"].values, @@ -622,9 +640,11 @@ def phase_portraits( else: if cur_pd.shape[0] <= figsize[0] * figsize[1] * 1000000: ax1, color = _matplotlib_points( - cur_pd.loc[:, ["S", "U"]].values - if vkey == "velocity_S" - else cur_pd.loc[:, ["total", "new"]].values, + ( + cur_pd.loc[:, ["S", "U"]].values + if vkey == "velocity_S" + else cur_pd.loc[:, ["total", "new"]].values + ), ax=ax1, labels=cur_pd.loc[:, "color"], values=None, @@ -640,9 +660,11 @@ def phase_portraits( ) else: ax1, color = _datashade_points( - cur_pd.loc[:, ["S", "U"]].values - if vkey == "velocity_S" - else cur_pd.loc[:, ["total", "new"]].values, + ( + cur_pd.loc[:, ["S", "U"]].values + if vkey == "velocity_S" + else cur_pd.loc[:, ["total", "new"]].values + ), ax=ax1, labels=cur_pd.loc[:, "color"], values=None, @@ -1607,25 +1629,44 @@ def dynamics( mom.integrate(t) mom_data = mom.get_all_central_moments() if has_splicing else mom.get_nosplice_central_moments() if true_param_prefix is not None: - (true_a, true_b, true_alpha_a, true_alpha_i, true_beta, true_gamma,) = ( - vel_params_df.loc[gene_name, true_param_prefix + "a"] - if true_param_prefix + "a" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "b"] - if true_param_prefix + "b" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "alpha_a"] - if true_param_prefix + "alpha_a" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "alpha_i"] - if true_param_prefix + "alpha_i" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "beta"] - if true_param_prefix + "beta" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + true_a, + true_b, + true_alpha_a, + true_alpha_i, + true_beta, + true_gamma, + ) = ( + ( + vel_params_df.loc[gene_name, true_param_prefix + "a"] + if true_param_prefix + "a" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "b"] + if true_param_prefix + "b" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha_a"] + if true_param_prefix + "alpha_a" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha_i"] + if true_param_prefix + "alpha_i" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "beta"] + if true_param_prefix + "beta" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_params = { @@ -1842,7 +1883,15 @@ def dynamics( np.log1p(sl), ) - (alpha, beta, gamma, ul0, sl0, uu0, half_life,) = vel_params_df.loc[ + ( + alpha, + beta, + gamma, + ul0, + sl0, + uu0, + half_life, + ) = vel_params_df.loc[ gene_name, [ prefix + "alpha", @@ -1868,15 +1917,21 @@ def dynamics( l = sol_s(t, sl0, ul0, 0, beta, gamma) if true_param_prefix is not None: true_alpha, true_beta, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "beta"] - if true_param_prefix + "beta" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "beta"] + if true_param_prefix + "beta" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_u = sol_u(t, uu0, true_alpha, true_beta) @@ -1931,12 +1986,16 @@ def dynamics( title_ = ["(unlabeled)", "(labeled)"] if true_param_prefix is not None: true_alpha, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_u = sol_u(t, uu0, true_alpha, true_gamma) true_w = sol_u(t, ul0, 0, true_gamma) @@ -2147,15 +2206,21 @@ def dynamics( l = sol_s(t, 0, 0, alpha, beta, gamma) if true_param_prefix is not None: true_alpha, true_beta, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "beta"] - if true_param_prefix + "beta" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "beta"] + if true_param_prefix + "beta" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_u = sol_u(t, uu0, 0, true_beta) true_s = sol_s(t, su0, uu0, 0, true_beta, true_gamma) @@ -2207,12 +2272,16 @@ def dynamics( l = None # sol_s(t, 0, 0, alpha, 1, gamma) if true_param_prefix is not None: true_alpha, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_u = sol_u(t, uu0, 0, true_gamma) true_w = sol_u(t, 0, true_alpha, true_gamma) @@ -2354,15 +2423,21 @@ def dynamics( L = sl + ul if true_param_prefix is not None: true_alpha, true_beta, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "beta"] - if true_param_prefix + "beta" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "beta"] + if true_param_prefix + "beta" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_l = sol_u(t, 0, true_alpha, true_beta) + sol_s( t, 0, 0, true_alpha, true_beta, true_gamma @@ -2401,12 +2476,16 @@ def dynamics( L = ul if true_param_prefix is not None: true_alpha, true_gamma = ( - vel_params_df.loc[gene_name, true_param_prefix + "alpha"] - if true_param_prefix + "alpha" in vel_params_df.columns - else -np.inf, - vel_params_df.loc[gene_name, true_param_prefix + "gamma"] - if true_param_prefix + "gamma" in vel_params_df.columns - else -np.inf, + ( + vel_params_df.loc[gene_name, true_param_prefix + "alpha"] + if true_param_prefix + "alpha" in vel_params_df.columns + else -np.inf + ), + ( + vel_params_df.loc[gene_name, true_param_prefix + "gamma"] + if true_param_prefix + "gamma" in vel_params_df.columns + else -np.inf + ), ) true_l = sol_u(t, 0, true_alpha, true_gamma) # sol_s(t, 0, 0, alpha, 1, gamma) diff --git a/dynamo/plot/ezplots.py b/dynamo/plot/ezplots.py index afe2a3ac6..d0a71aebf 100644 --- a/dynamo/plot/ezplots.py +++ b/dynamo/plot/ezplots.py @@ -13,9 +13,9 @@ from matplotlib.colors import ListedColormap from matplotlib.figure import Figure -from .utils import save_show_ret from ..tools.utils import flatten, index_gene, velocity_on_grid from ..utils import areinstance, isarray +from .utils import save_show_ret # from ..tools.Markov import smoothen_drift_on_grid diff --git a/dynamo/plot/heatmaps.py b/dynamo/plot/heatmaps.py index ccad56b78..b6aee78e0 100644 --- a/dynamo/plot/heatmaps.py +++ b/dynamo/plot/heatmaps.py @@ -362,8 +362,8 @@ def response( x, y_ori = x[valid_ids], y_ori[valid_ids] if log: - x, y_ori = x if sum(x < 0) else np.log(np.array(x) + 1), y_ori if sum(y_ori) < 0 else np.log( - np.array(y_ori) + 1 + x, y_ori = x if sum(x < 0) else np.log(np.array(x) + 1), ( + y_ori if sum(y_ori) < 0 else np.log(np.array(y_ori) + 1) ) if delay != 0: diff --git a/dynamo/plot/networks.py b/dynamo/plot/networks.py index b84b4cebd..2c6087766 100644 --- a/dynamo/plot/networks.py +++ b/dynamo/plot/networks.py @@ -404,7 +404,6 @@ def circosPlotDeprecated( save_kwargs: Dict[str, Any] = {}, **kwargs, ) -> Optional[Any]: - """Deprecated. A wrapper of `dynamo.pl.networks.nxvizPlot` to plot Circos graph. See the `nxvizPlot` for more information. diff --git a/dynamo/plot/preprocess.py b/dynamo/plot/preprocess.py index 9119d9102..fa3ec7c3e 100755 --- a/dynamo/plot/preprocess.py +++ b/dynamo/plot/preprocess.py @@ -583,8 +583,7 @@ def feature_genes( variance_key = layer + "_gini" if variance_key not in adata.var.columns: - raise ValueError( - "Looks like you have not run gene selection yet, try run necessary preprocessing first.") + raise ValueError("Looks like you have not run gene selection yet, try run necessary preprocessing first.") mean = DynamoAdataKeyManager.select_layer_data(adata, layer).mean(0)[0] table = adata.var.loc[:, [variance_key]] @@ -604,9 +603,7 @@ def feature_genes( else: table = adata.var.loc[:, [mean_key, variance_key]] - table = table.loc[ - np.isfinite(table[mean_key]) & np.isfinite(table[variance_key]) - ] + table = table.loc[np.isfinite(table[mean_key]) & np.isfinite(table[variance_key])] x_min, x_max = ( np.nanmin(table[mean_key]), np.nanmax(table[mean_key]), diff --git a/dynamo/plot/pseudotime.py b/dynamo/plot/pseudotime.py index 0f6d0573e..21c37af40 100755 --- a/dynamo/plot/pseudotime.py +++ b/dynamo/plot/pseudotime.py @@ -6,10 +6,10 @@ except ImportError: from typing_extensions import Literal -import pandas as pd import matplotlib.pyplot as plt import networkx as nx import numpy as np +import pandas as pd from anndata import AnnData from scipy.sparse import csr_matrix @@ -34,19 +34,21 @@ def _calculate_cells_mapping( cells_mapping_size = np.bincount(cell_proj_closest_vertex) centroids_index = range(len(cells_mapping_size)) - cell_type_info = pd.DataFrame({ - "class": adata.obs[group_key].values, - "centroid": cell_proj_closest_vertex, - }) + cell_type_info = pd.DataFrame( + { + "class": adata.obs[group_key].values, + "centroid": cell_proj_closest_vertex, + } + ) cell_color_map = get_color_map_from_labels(adata.obs[group_key].values) - cell_type_info = cell_type_info.groupby(['centroid', 'class']).size().unstack() + cell_type_info = cell_type_info.groupby(["centroid", "class"]).size().unstack() cell_type_info = cell_type_info.reindex(centroids_index, fill_value=0) cells_mapping_percentage = cell_type_info.div(cells_mapping_size, axis=0) cells_mapping_percentage = np.nan_to_num(cells_mapping_percentage.values) - cells_mapping_size = (cells_mapping_size / len(cell_proj_closest_vertex)) + cells_mapping_size = cells_mapping_size / len(cell_proj_closest_vertex) cells_mapping_size = [0.05 if s < 0.05 else s for s in cells_mapping_size] return cells_mapping_size, cells_mapping_percentage, cell_color_map @@ -168,7 +170,13 @@ def plot_dim_reduced_direct_graph( max_idx = np.argmax(attributes) dominate_colors.append(cells_colors[max_idx]) - nx.draw_networkx_nodes(G, pos=pos_dict, node_color=dominate_colors, node_size=[s * len(cells_size) * 300 for s in cells_size], ax=ax) + nx.draw_networkx_nodes( + G, + pos=pos_dict, + node_color=dominate_colors, + node_size=[s * len(cells_size) * 300 for s in cells_size], + ax=ax, + ) g = nx.draw_networkx_edges( G, pos=pos_dict, @@ -180,11 +188,14 @@ def plot_dim_reduced_direct_graph( ) cells_color_map["None"] = np.array([0, 0, 0, 1]) - plt.legend(handles=[plt.Line2D([0], [0], marker="o", color='w', label=label, - markerfacecolor=color) for label, color in cells_color_map.items()], - loc="best", - fontsize="medium", - ) + plt.legend( + handles=[ + plt.Line2D([0], [0], marker="o", color="w", label=label, markerfacecolor=color) + for label, color in cells_color_map.items() + ], + loc="best", + fontsize="medium", + ) return save_show_ret("plot_dim_reduced_direct_graph", save_show_or_return, save_kwargs, g) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index d16e5c9ff..cb38ff61f 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -25,17 +25,17 @@ velocity_on_grid, ) from ..tools.utils import update_dict -from ..vectorfield.VectorField import VectorField from ..vectorfield.utils import vecfld_from_adata +from ..vectorfield.VectorField import VectorField from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _get_adata_color_vec, default_quiver_args, quiver_autoscaler, retrieve_plot_save_path, - save_show_ret, save_plotly_figure, save_pyvista_plotter, + save_show_ret, set_arrow_alpha, set_stream_line_alpha, ) @@ -75,7 +75,7 @@ def cell_wise_vectors_3d( save_show_or_return: str = "show", save_kwargs: Dict[str, Any] = {}, quiver_3d_kwargs: Dict[str, Any] = { - "linewidth": 1, + "linewidth": 1, "edgecolors": "white", "alpha": 1, "length": 8, @@ -349,7 +349,7 @@ def add_axis_label(ax, labels): ) point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) - point_cloud['vectors'] = np.column_stack((v0.values, v1.values, v2.values)) + point_cloud["vectors"] = np.column_stack((v0.values, v1.values, v2.values)) r, c = pl.shape[0], pl.shape[1] subplot_indices = [[i, j] for i in range(r) for j in range(c)] @@ -359,7 +359,7 @@ def add_axis_label(ax, labels): point_cloud.point_data["colors"] = np.stack(colors_list[i]) arrows = point_cloud.glyph( - orient='vectors', + orient="vectors", factor=3.5, ) @@ -367,7 +367,7 @@ def add_axis_label(ax, labels): pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) cur_subplot += 1 - pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) + pl.add_mesh(arrows, scalars="colors", preference="point", rgb=True) return save_pyvista_plotter( pl=pl, @@ -422,7 +422,8 @@ def add_axis_label(ax, labels): sizemode="absolute", sizeref=1, ), - row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + row=subplot_indices[cur_subplot][0] + 1, + col=subplot_indices[cur_subplot][1] + 1, ) # TODO: implement customized color for individual cone cur_subplot += 1 @@ -607,7 +608,7 @@ def line_integral_conv( Exception: _description_ Returns: - None would be returned by default. If `save_show_or_return` is set to "return" or "all", the generated + None would be returned by default. If `save_show_or_return` is set to "return" or "all", the generated `yt.SlicePlot` will be returned. """ @@ -966,7 +967,6 @@ def cell_wise_vectors( "alpha": 1, "length": 8, "arrow_length_ratio": scale, - } axes_list, color_list, _ = scatters( @@ -1039,7 +1039,7 @@ def cell_wise_vectors( ) ax.set_facecolor(background) - return save_show_ret("cell_wise_vector", save_show_or_return, save_kwargs, axes_list, tight = projection != "3d") + return save_show_ret("cell_wise_vector", save_show_or_return, save_kwargs, axes_list, tight=projection != "3d") @docstrings.with_indent(4) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index fd2902908..83d16b9ea 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -15,8 +15,8 @@ from anndata import AnnData from matplotlib import patches, rcParams from matplotlib.axes import Axes -from matplotlib.lines import Line2D from matplotlib.colors import rgb2hex, to_hex +from matplotlib.lines import Line2D from pandas.api.types import is_categorical_dtype from scipy.sparse import issparse @@ -25,7 +25,13 @@ from ..dynamo_logger import main_debug, main_info, main_warning from ..preprocessing.utils import affine_transform, gen_rotation_2d from ..tools.moments import calc_1nd_moment -from ..tools.utils import flatten, get_mapper, get_vel_params, update_dict, update_vel_params +from ..tools.utils import ( + flatten, + get_mapper, + get_vel_params, + update_dict, + update_vel_params, +) from .utils import ( _datashade_points, _get_adata_color_vec, @@ -40,9 +46,9 @@ is_layer_keys, is_list_of_lists, retrieve_plot_save_path, - save_show_ret, save_plotly_figure, save_pyvista_plotter, + save_show_ret, ) docstrings = DocstringProcessor() @@ -555,12 +561,16 @@ def _plot_basis_layer(cur_b, cur_l): elif is_gene_name(_adata, cur_x) and is_gene_name(_adata, cur_y): points = pd.DataFrame( { - cur_x: _adata.obs_vector(k=cur_x, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), - cur_y: _adata.obs_vector(k=cur_y, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), + cur_x: ( + _adata.obs_vector(k=cur_x, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed) + ), + cur_y: ( + _adata.obs_vector(k=cur_y, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed) + ), } ) # points = points.loc[(points > 0).sum(1) > 1, :] @@ -582,9 +592,11 @@ def _plot_basis_layer(cur_b, cur_l): points = pd.DataFrame( { cur_x: _adata.obs_vector(cur_x), - cur_y: _adata.obs_vector(k=cur_y, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), + cur_y: ( + _adata.obs_vector(k=cur_y, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed) + ), } ) # points = points.loc[points.iloc[:, 1] > 0, :] @@ -596,9 +608,11 @@ def _plot_basis_layer(cur_b, cur_l): elif is_gene_name(_adata, cur_x) and is_cell_anno_column(_adata, cur_y): points = pd.DataFrame( { - cur_x: _adata.obs_vector(k=cur_x, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), + cur_x: ( + _adata.obs_vector(k=cur_x, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed) + ), cur_y: _adata.obs_vector(cur_y), } ) @@ -633,9 +647,11 @@ def _plot_basis_layer(cur_b, cur_l): list(_adata.obs[aggregate].unique()), ) group_color, group_median = ( - np.zeros((1, len(uniq_grp))).flatten() - if isinstance(_color[0], Number) - else np.zeros((1, len(uniq_grp))).astype("str").flatten(), + ( + np.zeros((1, len(uniq_grp))).flatten() + if isinstance(_color[0], Number) + else np.zeros((1, len(uniq_grp))).astype("str").flatten() + ), np.zeros((len(uniq_grp), 2)), ) @@ -882,7 +898,9 @@ def _plot_basis_layer(cur_b, cur_l): return_value = (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) else: return_value = axes_list if total_panels > 1 else ax - return save_show_ret("scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, background=background) + return save_show_ret( + "scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, background=background + ) def map_to_points( @@ -925,9 +943,11 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: nonlocal gene_title, anno_title if is_gene_name(_adata, cur): - points_df_data = (_adata.obs_vector(k=cur, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_df_data = ( + _adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed) + ) points_column = cur + " (" + cur_l_smoothed + ")" gene_title.append(cur) elif is_cell_anno_column(_adata, cur): @@ -960,20 +980,22 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: return points, cur_title elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ - anndata._core.views.ArrayView, - np.ndarray, - ]: + anndata._core.views.ArrayView, + np.ndarray, + ]: points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y), "x": flatten(axis_z)}) points.columns = ["x", "y", "z"] else: x_points_df_data, x_points_column = _map_cur_axis(axis_x) y_points_df_data, y_points_column = _map_cur_axis(axis_y) z_points_df_data, z_points_column = _map_cur_axis(axis_z) - points = pd.DataFrame({ - axis_x: x_points_df_data, - axis_y: y_points_df_data, - axis_z: z_points_df_data, - }) + points = pd.DataFrame( + { + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + } + ) points.columns = [x_points_column, y_points_column, z_points_column] if len(gene_title) != 0: @@ -1098,12 +1120,12 @@ def scatters_interactive_legacy( # make x, y, z lists of list, where each list corresponds to one coordinate set if ( - type(x) in [anndata._core.views.ArrayView, np.ndarray] - and type(y) in [anndata._core.views.ArrayView, np.ndarray] - and type(z) in [anndata._core.views.ArrayView, np.ndarray] - and len(x) == adata.n_obs - and len(y) == adata.n_obs - and len(z) == adata.n_obs + type(x) in [anndata._core.views.ArrayView, np.ndarray] + and type(y) in [anndata._core.views.ArrayView, np.ndarray] + and type(z) in [anndata._core.views.ArrayView, np.ndarray] + and len(x) == adata.n_obs + and len(y) == adata.n_obs + and len(z) == adata.n_obs ): x, y, z = [x], [y], [z] @@ -1148,8 +1170,9 @@ def scatters_interactive_legacy( pl = ( pv.Plotter(shape=(nrow, ncol)) if plot_method == "pv" - else - make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + else make_subplots( + rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)] + ) ) def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: @@ -1276,7 +1299,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: pvdataset = pv.PolyData(points.values) pvdataset.point_data["colors"] = np.stack(colors) - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + pl.add_points(pvdataset, scalars="colors", preference="point", rgb=True, cmap=_cmap, **kwargs) if color_type == "labels": type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} @@ -1301,14 +1324,13 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: text=_labels if color_type == "labels" else _values, **kwargs, ), - row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + row=subplot_indices[cur_subplot][0] + 1, + col=subplot_indices[cur_subplot][1] + 1, ) pl.update_layout( scene=dict( - xaxis_title=points.columns[0], - yaxis_title=points.columns[1], - zaxis_title=points.columns[2] + xaxis_title=points.columns[0], yaxis_title=points.columns[1], zaxis_title=points.columns[2] ), ) @@ -1320,16 +1342,20 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: main_debug("colors: %s" % (str(color))) _plot_basis_layer_pv(cur_b, cur_l) - return save_pyvista_plotter( - pl=pl, - colors_list=colors_list, - save_show_or_return=save_show_or_return, - save_kwargs=save_kwargs, - ) if plot_method == "pv" else save_plotly_figure( - pl=pl, - colors_list=colors_list, - save_show_or_return=save_show_or_return, - save_kwargs=save_kwargs, + return ( + save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + if plot_method == "pv" + else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) ) @@ -1461,8 +1487,9 @@ def scatters_interactive( pl = ( pv.Plotter(shape=(nrow, ncol)) if plot_method == "pv" - else - make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + else make_subplots( + rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)] + ) ) def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: str, cur_z: str) -> None: @@ -1515,14 +1542,11 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: title=cur_title, ) - if smooth and not is_not_continuous: main_debug("smooth and not continuous") knn = adata.obsp["moments_con"] _values = ( - calc_1nd_moment(_values, knn)[0] - if smooth in [1, True] - else calc_1nd_moment(_values, knn**smooth)[0] + calc_1nd_moment(_values, knn)[0] if smooth in [1, True] else calc_1nd_moment(_values, knn**smooth)[0] ) colors, color_type, _ = calculate_colors( @@ -1544,7 +1568,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: pvdataset = pv.PolyData(points.values) pvdataset.point_data["colors"] = np.stack(colors) - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + pl.add_points(pvdataset, scalars="colors", preference="point", rgb=True, cmap=_cmap, **kwargs) if color_type == "labels": type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} @@ -1569,15 +1593,12 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: text=_labels if color_type == "labels" else _values, **kwargs, ), - row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + row=subplot_indices[cur_subplot][0] + 1, + col=subplot_indices[cur_subplot][1] + 1, ) pl.update_layout( - scene=dict( - xaxis_title=points.columns[0], - yaxis_title=points.columns[1], - zaxis_title=points.columns[2] - ), + scene=dict(xaxis_title=points.columns[0], yaxis_title=points.columns[1], zaxis_title=points.columns[2]), ) cur_subplot += 1 @@ -1614,17 +1635,20 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: for cur_x, cur_y, cur_z in zip(x, y, z): _plot_basis_layer_pv(cur_b, cur_l, cur_c, cur_x, cur_y, cur_z) - - return save_pyvista_plotter( - pl=pl, - colors_list=colors_list, - save_show_or_return=save_show_or_return, - save_kwargs=save_kwargs, - ) if plot_method == "pv" else save_plotly_figure( - pl=pl, - colors_list=colors_list, - save_show_or_return=save_show_or_return, - save_kwargs=save_kwargs, + return ( + save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + if plot_method == "pv" + else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) ) @@ -2019,8 +2043,9 @@ def scatters( return_value = (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) else: return_value = axes_list if total_panels > 1 else ax - return save_show_ret("scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, - background=background) + return save_show_ret( + "scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, background=background + ) def scatters_single_input( @@ -2271,16 +2296,16 @@ def scatters_single_input( list(adata.obs[aggregate].unique()), ) group_color, group_median = ( - np.zeros((1, len(uniq_grp))).flatten() - if isinstance(_color[0], Number) - else np.zeros((1, len(uniq_grp))).astype("str").flatten(), + ( + np.zeros((1, len(uniq_grp))).flatten() + if isinstance(_color[0], Number) + else np.zeros((1, len(uniq_grp))).astype("str").flatten() + ), np.zeros((len(uniq_grp), 2)), ) grp_size = adata.obs[aggregate].value_counts()[uniq_grp].values - scatter_kwargs = ( - {"s": grp_size} if scatter_kwargs is None else update_dict(scatter_kwargs, {"s": grp_size}) - ) + scatter_kwargs = {"s": grp_size} if scatter_kwargs is None else update_dict(scatter_kwargs, {"s": grp_size}) for ind, cur_grp in enumerate(uniq_grp): group_median[ind, :] = np.nanmedian( @@ -2312,11 +2337,7 @@ def scatters_single_input( if smooth and not is_not_continuous: main_debug("smooth and not continuous") knn = adata.obsp["moments_con"] - values = ( - calc_1nd_moment(values, knn)[0] - if smooth in [1, True] - else calc_1nd_moment(values, knn ** smooth)[0] - ) + values = calc_1nd_moment(values, knn)[0] if smooth in [1, True] else calc_1nd_moment(values, knn**smooth)[0] if affine_transform_A is None or affine_transform_b is None: point_coords = points.values @@ -2404,8 +2425,7 @@ def scatters_single_input( update_vel_params(adata, params_df=vel_params_df) ax.plot( xnew, - xnew * adata[:, basis].var.loc[:, k_name].unique() - + adata[:, basis].var.loc[:, "gamma_b"].unique(), + xnew * adata[:, basis].var.loc[:, k_name].unique() + adata[:, basis].var.loc[:, "gamma_b"].unique(), dashes=[6, 2], c=font_color, ) @@ -2429,8 +2449,8 @@ def scatters_single_input( group_points.iloc[:, 0].max() * 0.90, ) group_ynew = ( - group_xnew * group_adata[:, basis].var.loc[:, group_k_name].unique() - + group_adata[:, basis].var.loc[:, group_b_key].unique() + group_xnew * group_adata[:, basis].var.loc[:, group_k_name].unique() + + group_adata[:, basis].var.loc[:, group_b_key].unique() ) ax.annotate(group + "_" + cur_group, xy=(group_xnew[-1], group_ynew[-1])) vel_params_df = get_vel_params(group_adata) @@ -2532,6 +2552,7 @@ def _validate_parameters( return basis, background, color, layer, x, y, z, nrow, ncol, total_panels + def _get_basis_key( adata: AnnData, basis: str, @@ -2571,6 +2592,7 @@ def _get_basis_key( return basis_key, cur_l_smoothed, cmap, sym_c + def _get_color_parameters( adata: AnnData, color: str, @@ -2628,8 +2650,7 @@ def _get_theme_for_values(): is_numeric_color = np.issubdtype(_color.dtype, np.number) if not is_numeric_color: main_info( - "skip filtering %s by stack threshold when stacking color because it is not a numeric type" - % color, + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" % color, indent_level=2, ) @@ -2739,9 +2760,11 @@ def _map_cur_axis_to_title( nonlocal gene_title, anno_title if is_gene_name(_adata, cur): - points_df_data = (_adata.obs_vector(k=cur, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_df_data = ( + _adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed) + ) points_column = cur + " (" + cur_l_smoothed + ")" gene_title.append(cur) elif is_cell_anno_column(_adata, cur): @@ -2784,11 +2807,13 @@ def _map_cur_axis_to_title( x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) z_points_df_data, z_points_column = _map_cur_axis_to_title(axis_z, _adata, cur_b, cur_l_smoothed) - points = pd.DataFrame({ - axis_x: x_points_df_data, - axis_y: y_points_df_data, - axis_z: z_points_df_data, - }) + points = pd.DataFrame( + { + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + } + ) points.columns = [x_points_column, y_points_column, z_points_column] if len(gene_title) != 0: @@ -2824,10 +2849,13 @@ def _map_cur_axis_to_title( y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) x_points_df_data = x_points_df_data.A.flatten() if issparse(x_points_df_data) else x_points_df_data y_points_df_data = y_points_df_data.A.flatten() if issparse(y_points_df_data) else y_points_df_data - points = pd.DataFrame({ - axis_x: x_points_df_data, - axis_y: y_points_df_data, - }, index=_adata.obs_names) + points = pd.DataFrame( + { + axis_x: x_points_df_data, + axis_y: y_points_df_data, + }, + index=_adata.obs_names, + ) points.columns = [axis_x, axis_y] if len(gene_title) != 0: diff --git a/dynamo/plot/sctransform.py b/dynamo/plot/sctransform.py index 3a4b9a646..3413eac23 100644 --- a/dynamo/plot/sctransform.py +++ b/dynamo/plot/sctransform.py @@ -1,11 +1,12 @@ from typing import Optional -from anndata import AnnData -from matplotlib.axes import Axes -from matplotlib.figure import Figure import matplotlib.pyplot as plt import numpy as np import pandas as pd +from anndata import AnnData +from matplotlib.axes import Axes +from matplotlib.figure import Figure + def sctransform_plot_fit( adata: AnnData, @@ -24,23 +25,21 @@ def sctransform_plot_fit( """ if fig is None: fig = plt.figure(figsize=(12, 3)) - gene_names = adata.var['genes_step1_sct'][ - ~adata.var['genes_step1_sct'].isna()].index + gene_names = adata.var["genes_step1_sct"][~adata.var["genes_step1_sct"].isna()].index genes_log10_mean = adata.var["log10_gmean_sct"] genes_log_gmean = genes_log10_mean[~genes_log10_mean.isna()] - model_params_fit = pd.concat( - [adata.var["log_umi_sct"], adata.var["Intercept_sct"], adata.var["theta_sct"]], axis=1) + model_params_fit = pd.concat([adata.var["log_umi_sct"], adata.var["Intercept_sct"], adata.var["theta_sct"]], axis=1) model_params = pd.concat( - [adata.var["log_umi_step1_sct"], adata.var["Intercept_step1_sct"], adata.var["model_pars_theta_step1"]], - axis=1) + [adata.var["log_umi_step1_sct"], adata.var["Intercept_step1_sct"], adata.var["model_pars_theta_step1"]], axis=1 + ) model_params_fit = model_params_fit.rename( - columns={"log_umi_sct": "log_umi", "Intercept_sct": "Intercept", "theta_sct": "theta"}) + columns={"log_umi_sct": "log_umi", "Intercept_sct": "Intercept", "theta_sct": "theta"} + ) model_params = model_params.rename( - columns={"log_umi_step1_sct": "log_umi", - "Intercept_step1_sct": "Intercept", - "model_pars_theta_step1": "theta"}) + columns={"log_umi_step1_sct": "log_umi", "Intercept_step1_sct": "Intercept", "model_pars_theta_step1": "theta"} + ) model_params = model_params.loc[gene_names] @@ -92,6 +91,7 @@ def sctransform_plot_fit( _ = fig.tight_layout() return fig + def plot_residual_var( adata: AnnData, topngenes: int = 10, @@ -111,6 +111,7 @@ def plot_residual_var( Returns: The Figure object if `ax` is not given, else None. """ + def vars(a, axis=None): """Helper function to calculate variance of sparse matrix by equation: var = mean(a**2) - mean(a)**2""" a_squared = a.copy() @@ -122,28 +123,21 @@ def vars(a, axis=None): else: fig = None - gene_attr = pd.DataFrame(adata.var['log10_gmean_sct']) + gene_attr = pd.DataFrame(adata.var["log10_gmean_sct"]) # gene_attr = gene_attr.loc[gene_names] gene_attr["var"] = vars(adata.X, axis=0).tolist()[0] gene_attr["mean"] = adata.X.mean(axis=0).tolist()[0] - gene_attr_sorted = gene_attr.sort_values( - "var", ascending=False - ).reset_index() + gene_attr_sorted = gene_attr.sort_values("var", ascending=False).reset_index() topn = gene_attr_sorted.iloc[:topngenes] gene_attr = gene_attr_sorted.iloc[topngenes:] ax.set_xscale("log") - ax.scatter( - gene_attr["mean"], gene_attr["var"], s=1.5, color="black" - ) + ax.scatter(gene_attr["mean"], gene_attr["var"], s=1.5, color="black") ax.scatter(topn["mean"], topn["var"], s=1.5, color="deeppink") ax.axhline(1, linestyle="dashed", color="red") ax.set_xlabel("mean") ax.set_ylabel("var") if label_genes: - texts = [ - plt.text(row["mean"], row["var"], row["index"]) - for index, row in topn.iterrows() - ] + texts = [plt.text(row["mean"], row["var"], row["index"]) for index, row in topn.iterrows()] fig.tight_layout() - return fig \ No newline at end of file + return fig diff --git a/dynamo/plot/state_graph.py b/dynamo/plot/state_graph.py index ff8faa795..509af0d30 100755 --- a/dynamo/plot/state_graph.py +++ b/dynamo/plot/state_graph.py @@ -322,4 +322,6 @@ def state_graph( plt.axis("off") - return save_show_ret("state_graph", save_show_or_return, save_kwargs, (axes_list, color_list, font_color), adjust = show_legend) + return save_show_ret( + "state_graph", save_show_or_return, save_kwargs, (axes_list, color_list, font_color), adjust=show_legend + ) diff --git a/dynamo/plot/time_series.py b/dynamo/plot/time_series.py index 1892f93cf..1ca8892e0 100755 --- a/dynamo/plot/time_series.py +++ b/dynamo/plot/time_series.py @@ -449,7 +449,7 @@ def kinetic_heatmap( vline_kwargs = update_dict({"linestyles": "dashdot"}, vlines_kwargs) sns_heatmap.ax_heatmap.vlines(vline_cols, *sns_heatmap.ax_heatmap.get_ylim(), **vline_kwargs) - return save_show_ret("kinetic_heatmap", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) + return save_show_ret("kinetic_heatmap", save_show_or_return, save_kwargs, sns_heatmap, adjust=show_colorbar) def _half_max_ordering(exprs, time, mode, interpolate=False, spaced_num=100): @@ -675,7 +675,7 @@ def jacobian_kinetics( Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the generated seaborn ClusterGrid would be returned. - + Examples: >>> import dynamo as dyn >>> adata = dyn.sample_data.hgForebrainGlutamatergic() @@ -803,7 +803,7 @@ def jacobian_kinetics( if not show_colorbar: sns_heatmap.cax.set_visible(False) - return save_show_ret("jacobian_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) + return save_show_ret("jacobian_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust=show_colorbar) @docstrings.with_indent(4) @@ -872,7 +872,7 @@ def sensitivity_kinetics( Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the generated seaborn ClusterGrid would be returned. - + Examples: >>> import dynamo as dyn >>> adata = dyn.sample_data.hgForebrainGlutamatergic() @@ -1000,4 +1000,4 @@ def sensitivity_kinetics( if not show_colorbar: sns_heatmap.cax.set_visible(False) - return save_show_ret("sensitivity_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) + return save_show_ret("sensitivity_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust=show_colorbar) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index f491ed022..bfbe86a75 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -19,13 +19,11 @@ from ..tools.cell_velocities import cell_velocities from ..tools.utils import nearest_neighbors, update_dict from ..vectorfield.scVectorField import BaseVectorField -from ..vectorfield.topography import ( # , compute_separatrices - Topography2D, -) +from ..vectorfield.topography import Topography2D # , compute_separatrices from ..vectorfield.topography import topography as _topology # , compute_separatrices from ..vectorfield.utils import vecfld_from_adata -from ..vectorfield.VectorField import VectorField from ..vectorfield.vector_calculus import curl, divergence +from ..vectorfield.VectorField import VectorField from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _plot_traj, @@ -33,9 +31,9 @@ default_quiver_args, quiver_autoscaler, retrieve_plot_save_path, - save_show_ret, save_plotly_figure, save_pyvista_plotter, + save_show_ret, set_arrow_alpha, set_stream_line_alpha, ) @@ -534,7 +532,8 @@ def plot_fixed_points( ), **kwargs, ), - row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + row=subplot_indices[cur_subplot][0] + 1, + col=subplot_indices[cur_subplot][1] + 1, ) return save_plotly_figure( @@ -592,7 +591,7 @@ def plot_traj( ax: Optional[Axes] = None, ) -> Optional[Axes]: """Plots a trajectory on a phase portrait. - + Code adapted from: http://be150.caltech.edu/2017/handouts/dynamical_systems_approaches.html Args: @@ -1306,7 +1305,9 @@ def topography( **quiver_kwargs, ) # color='red', facecolors='gray' - return save_show_ret("topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0]) + return save_show_ret( + "topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0] + ) # TODO: Implement more `terms` like streamline and trajectory for 3D topography @@ -1567,7 +1568,6 @@ def topography_3D( max_[2] + (max_[2] - min_[2]) * 0.1, ] - if init_cells is not None: if init_states is None: intersect_cell_names = list(set(init_cells).intersection(adata.obs_names)) @@ -1751,4 +1751,6 @@ def topography_3D( cmap=marker_cmap, ) - return save_show_ret("topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0]) + return save_show_ret( + "topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0] + ) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 941ad41e8..d4caedc09 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -243,20 +243,20 @@ def calculate_colors( _vmin = ( np.nanmin(values) if vmin is None - else np.nanpercentile(values, vmin * 100) - if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmin) - if (vmin + vmax == 100 and 0 <= vmin < vmax) - else vmin + else ( + np.nanpercentile(values, vmin * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmin) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmin + ) ) _vmax = ( np.nanmax(values) if vmax is None - else np.nanpercentile(values, vmax * 100) - if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmax) - if (vmin + vmax == 100 and 0 <= vmin < vmax) - else vmax + else ( + np.nanpercentile(values, vmax * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmax) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmax + ) ) if sym_c and _vmin < 0 and _vmax > 0: @@ -264,7 +264,6 @@ def calculate_colors( bounds = bounds * np.array([-1, 1]) _vmin, _vmax = bounds - if "norm" in kwargs: norm = kwargs["norm"] else: @@ -424,9 +423,15 @@ def _matplotlib_points( if ax is None: dpi = plt.rcParams["figure.dpi"] fig = plt.figure(figsize=(width / dpi, height / dpi)) - ax = fig.add_subplot( - 111, projection=projection, computed_zorder=False, - ) if projection == "3d" else fig.add_subplot(111, projection=projection) + ax = ( + fig.add_subplot( + 111, + projection=projection, + computed_zorder=False, + ) + if projection == "3d" + else fig.add_subplot(111, projection=projection) + ) ax.set_facecolor(background) @@ -655,20 +660,20 @@ def _matplotlib_points( _vmin = ( np.nanmin(values) if vmin is None - else np.nanpercentile(values, vmin * 100) - if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmin) - if (vmin + vmax == 100 and 0 <= vmin < vmax) - else vmin + else ( + np.nanpercentile(values, vmin * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmin) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmin + ) ) _vmax = ( np.nanmax(values) if vmax is None - else np.nanpercentile(values, vmax * 100) - if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmax) - if (vmin + vmax == 100 and 0 <= vmin < vmax) - else vmax + else ( + np.nanpercentile(values, vmax * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmax) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmax + ) ) if sym_c and _vmin < 0 and _vmax > 0: @@ -929,7 +934,10 @@ def _datashade_points( data["label"] == "other", ) reorder_data = data.copy(deep=True) - (reorder_data.iloc[: sum(background_ids), :], reorder_data.iloc[sum(background_ids) :, :],) = ( + ( + reorder_data.iloc[: sum(background_ids), :], + reorder_data.iloc[sum(background_ids) :, :], + ) = ( data.loc[background_ids, :], data.loc[highlight_ids, :], ) @@ -1663,31 +1671,31 @@ def save_show_ret( prefix: str, save_show_or_return: Literal["save", "show", "return", "both", "all"], save_kwargs: Dict[str, Any], - ret_value = None, + ret_value=None, tight: bool = True, adjust: bool = False, background: Optional[str] = None, ): """ - Helper function that performs actions based on the variable save_show_or_return. + Helper function that performs actions based on the variable save_show_or_return. Should always have at least 3 inputs (prefix, save_show__or_return, save_kwargs). Args: prefix: Prefix added to name of figure that will be saved. See the `s_kwargs` variable. - save_show_or_return: Whether the figure should be saved, shown, or returned. + save_show_or_return: Whether the figure should be saved, shown, or returned. "both" means that the figure would be shown and saved but not returned. Defaults to "show". - save_kwargs: A dictionary that will be passed to the save_fig() function. + save_kwargs: A dictionary that will be passed to the save_fig() function. The save_fig() function will use { - "path": None, - "prefix": [prefix input], - "dpi": None, + "path": None, + "prefix": [prefix input], + "dpi": None, "ext": 'pdf', - "transparent": True, - "close": True, + "transparent": True, + "close": True, "verbose": True - } + } as its parameters. `save_kwargs` modifies those keys according to your needs. Defaults to {}. ret_value: Value to be returned if `save_show_or_return` equals "return" or "all". tight: Toggles whether plt.tight_layout() is called. @@ -1722,7 +1730,7 @@ def save_show_ret( plt.subplots_adjust(right=0.85) if tight: - #Do note that warnings should not be ignored in the future. + # Do note that warnings should not be ignored in the future. with warnings.catch_warnings(): warnings.simplefilter("ignore") plt.tight_layout() @@ -1812,7 +1820,7 @@ def save_pyvista_plotter( "path": None, "prefix": "scatters_pv", "ext": "pdf", - "title": 'PyVista Export', + "title": "PyVista Export", "raster": True, "painter": True, } diff --git a/dynamo/plot/vector_calculus.py b/dynamo/plot/vector_calculus.py index 516922f46..185545326 100644 --- a/dynamo/plot/vector_calculus.py +++ b/dynamo/plot/vector_calculus.py @@ -405,7 +405,7 @@ def jacobian( Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the matplotlib `GridSpec` of the figure would be returned. - + Examples: >>> import dynamo as dyn >>> adata = dyn.sample_data.hgForebrainGlutamatergic() @@ -797,7 +797,7 @@ def sensitivity( Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the matplotlib `GridSpec` of the figure would be returned. - + Examples: >>> import dynamo as dyn >>> adata = dyn.sample_data.hgForebrainGlutamatergic() @@ -1017,7 +1017,7 @@ def sensitivity_heatmap( Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the matplotlib `GridSpec` of the figure would be returned. - + Examples: >>> import dynamo as dyn >>> adata = dyn.sample_data.hgForebrainGlutamatergic() diff --git a/dynamo/prediction/fate.py b/dynamo/prediction/fate.py index 9c2a4deb2..bf82b8384 100755 --- a/dynamo/prediction/fate.py +++ b/dynamo/prediction/fate.py @@ -16,9 +16,13 @@ main_info_insert_adata, main_warning, ) -from ..utils import pca_to_expr -from ..tools.connectivity import construct_mapper_umap, correct_hnsw_neighbors, k_nearest_neighbors +from ..tools.connectivity import ( + construct_mapper_umap, + correct_hnsw_neighbors, + k_nearest_neighbors, +) from ..tools.utils import fetch_states, getTseq +from ..utils import pca_to_expr from ..vectorfield import vector_field_function from ..vectorfield.utils import vecfld_from_adata, vector_transformation from .utils import integrate_vf_ivp diff --git a/dynamo/prediction/least_action_path.py b/dynamo/prediction/least_action_path.py index 430f3c649..b2e2c9abf 100644 --- a/dynamo/prediction/least_action_path.py +++ b/dynamo/prediction/least_action_path.py @@ -15,7 +15,7 @@ vector_field_function_transformation, vector_transformation, ) -from .trajectory import arclength_sampling_n, GeneTrajectory, Trajectory +from .trajectory import GeneTrajectory, Trajectory, arclength_sampling_n from .utils import find_elbow @@ -138,6 +138,7 @@ class GeneLeastActionPath(GeneTrajectory): t: Array of time values. action: Array of action values. """ + def __init__( self, adata: AnnData, @@ -652,7 +653,7 @@ def least_action( A = [] path_ind = 0 - for (init_state, target_state) in LoggerManager.progress_logger( + for init_state, target_state in LoggerManager.progress_logger( pairs, progress_name=f"iterating through {len(pairs)} pairs" ): logger.info( diff --git a/dynamo/prediction/perturbation.py b/dynamo/prediction/perturbation.py index ab26f6571..7f551a048 100644 --- a/dynamo/prediction/perturbation.py +++ b/dynamo/prediction/perturbation.py @@ -9,14 +9,13 @@ from ..tools.cell_velocities import cell_velocities from ..utils import expr_to_pca, pca_to_expr from ..vectorfield import SvcVectorField +from ..vectorfield.rank_vf import rank_cell_groups, rank_genes from ..vectorfield.scVectorField import KOVectorField, vector_field_function_knockout from ..vectorfield.vector_calculus import ( jacobian, vecfld_from_adata, vector_transformation, ) - -from ..vectorfield.rank_vf import rank_cell_groups, rank_genes from .utils import z_score, z_score_inv diff --git a/dynamo/prediction/trajectory.py b/dynamo/prediction/trajectory.py index 1b3cbadf3..5819d82ed 100644 --- a/dynamo/prediction/trajectory.py +++ b/dynamo/prediction/trajectory.py @@ -15,6 +15,7 @@ class Trajectory: """Base class for handling trajectory interpolation, resampling, etc.""" + def __init__(self, X: np.ndarray, t: Union[None, np.ndarray] = None, sort: bool = True) -> None: """Initializes a Trajectory object. @@ -156,7 +157,7 @@ def archlength_sampling( # idx = dup_osc_idx_iter(x) x = x[:idx] _, arclen, _ = remove_redundant_points_trajectory(x, tol=1e-4, output_discard=True) - cur_Y, alen, self.t = arclength_sampling_n(x, num=interpolation_num+1, t=tau[:idx]) + cur_Y, alen, self.t = arclength_sampling_n(x, num=interpolation_num + 1, t=tau[:idx]) self.t = self.t[1:] cur_Y = cur_Y[:, 1:] @@ -301,6 +302,7 @@ def calc_msd(self, decomp_dim: bool = True, ref: int = 0) -> Union[float, np.nda class VectorFieldTrajectory(Trajectory): """Class for handling trajectory data with a differentiable vector field.""" + def __init__(self, X: np.ndarray, t: np.ndarray, vecfld: DifferentiableVectorField) -> None: """Initializes a VectorFieldTrajectory object. @@ -414,6 +416,7 @@ def calc_vector_msd(self, key: str, decomp_dim: bool = True, ref: int = 0) -> Un class GeneTrajectory(Trajectory): """Class for handling gene expression trajectory data.""" + def __init__( self, adata: AnnData, @@ -524,7 +527,10 @@ def save(self, save_key: str = "gene_trajectory") -> None: self.adata.varm[save_key][self.genes_to_mask(), :] = self.X.T def select_gene( - self, genes: Union[np.ndarray, list], arr: Optional[np.ndarray] = None, axis: Optional[int] = None, + self, + genes: Union[np.ndarray, list], + arr: Optional[np.ndarray] = None, + axis: Optional[int] = None, ) -> np.ndarray: """Selects the gene expression data for the specified genes. @@ -560,7 +566,9 @@ def select_gene( def arclength_sampling_n( - X: np.ndarray, num: int, t: Optional[np.ndarray] = None, + X: np.ndarray, + num: int, + t: Optional[np.ndarray] = None, ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]: """Uniformly sample data points on an arc curve that generated from vector field predictions. @@ -585,7 +593,9 @@ def arclength_sampling_n( def remove_redundant_points_trajectory( - X: np.ndarray, tol: float = 1e-4, output_discard: bool = False, + X: np.ndarray, + tol: float = 1e-4, + output_discard: bool = False, ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]: """Remove consecutive data points that are too close to each other. diff --git a/dynamo/prediction/tscRNA_seq.py b/dynamo/prediction/tscRNA_seq.py index aa611070d..7d41437ee 100644 --- a/dynamo/prediction/tscRNA_seq.py +++ b/dynamo/prediction/tscRNA_seq.py @@ -18,7 +18,6 @@ def get_pulse_r0( add_init_r0_key: str = "init_r0_pulse", copy: bool = False, ) -> Union[anndata.AnnData, None]: - """Get the total RNA at the initial time point for a kinetic experiment with the formula: :math:`r_0 = \frac{(r - l)}{(1 - k)}`, where :math: `k = 1 - e^{- \gamma t} diff --git a/dynamo/prediction/utils.py b/dynamo/prediction/utils.py index ba46ce10e..b1a8f0650 100644 --- a/dynamo/prediction/utils.py +++ b/dynamo/prediction/utils.py @@ -7,13 +7,13 @@ from scipy.integrate import solve_ivp from tqdm import tqdm -from .trajectory import Trajectory from ..dynamo_logger import main_warning from ..tools.utils import log1p_, nearest_neighbors from ..utils import isarray, normalize # import scipy.sparse as sp from ..vectorfield.topography import dup_osc_idx_iter +from .trajectory import Trajectory # --------------------------------------------------------------------------------------------------- # initial state related @@ -253,7 +253,7 @@ def integrate_vf_ivp( t = [t] * n_cell subarray_width = Y.shape[1] // n_cell - Y = [Y[:, i * subarray_width: (i + 1) * subarray_width] for i in range(n_cell)] + Y = [Y[:, i * subarray_width : (i + 1) * subarray_width] for i in range(n_cell)] return t, Y diff --git a/dynamo/preprocessing/Preprocessor.py b/dynamo/preprocessing/Preprocessor.py index 614d56131..87abfffe5 100644 --- a/dynamo/preprocessing/Preprocessor.py +++ b/dynamo/preprocessing/Preprocessor.py @@ -27,10 +27,10 @@ from .gene_selection import select_genes_by_seurat_recipe, select_genes_monocle from .normalization import calc_sz_factor, normalize from .pca import pca -from .QC import basic_stats +from .QC import basic_stats, filter_cells_by_highly_variable_genes from .QC import filter_cells_by_outliers as monocle_filter_cells_by_outliers from .QC import filter_genes_by_outliers as monocle_filter_genes_by_outliers -from .QC import regress_out_parallel, filter_cells_by_highly_variable_genes +from .QC import regress_out_parallel from .transform import Freeman_Tukey, log, log1p, log2 from .utils import ( _infer_labeling_experiment_type, @@ -255,7 +255,6 @@ def standardize_adata(self, adata: AnnData, tkey: str, experiment_type: str) -> main_debug("applying convert_gene_name function...") self.convert_gene_name(adata) - self.basic_stats(adata) self.add_experiment_info(adata, tkey, experiment_type) main_info_insert_adata("tkey=%s" % tkey, "uns['pp']", indent_level=2) @@ -300,7 +299,9 @@ def _filter_cells_by_highly_variable_genes(self, adata: AnnData) -> None: if self.filter_cells_by_highly_variable_genes: main_debug("filtering cells by highly variable genes...") - main_debug("filter_cells_by_highly_variable_genes kwargs:" + str(self.filter_cells_by_highly_variable_genes_kwargs)) + main_debug( + "filter_cells_by_highly_variable_genes kwargs:" + str(self.filter_cells_by_highly_variable_genes_kwargs) + ) self.filter_cells_by_highly_variable_genes(adata, **self.filter_cells_by_highly_variable_genes_kwargs) def _calc_size_factor(self, adata: AnnData) -> None: diff --git a/dynamo/preprocessing/QC.py b/dynamo/preprocessing/QC.py index 118a7588d..7f2a113c5 100644 --- a/dynamo/preprocessing/QC.py +++ b/dynamo/preprocessing/QC.py @@ -291,8 +291,7 @@ def filter_cells_by_highly_variable_genes( """ if high_var_genes_key not in adata.var.keys(): raise ValueError( - "The key %s is not found in adata.var. Please run genes selection methods first." - % high_var_genes_key + "The key %s is not found in adata.var. Please run genes selection methods first." % high_var_genes_key ) if obs_store_key not in adata.obs.keys(): diff --git a/dynamo/preprocessing/__init__.py b/dynamo/preprocessing/__init__.py index 3e3365cc6..35935bedb 100755 --- a/dynamo/preprocessing/__init__.py +++ b/dynamo/preprocessing/__init__.py @@ -2,6 +2,14 @@ """ from .cell_cycle import cell_cycle_scores +from .deprecated import ( + calc_sz_factor_legacy, + cook_dist, + filter_cells_legacy, + normalize_cell_expr_by_size_factors, + recipe_monocle, + recipe_velocyto, +) from .dynast import lambda_correction from .external import ( harmony_debatch, @@ -11,15 +19,15 @@ select_genes_by_pearson_residuals, ) from .normalization import calc_sz_factor, normalize +from .pca import pca, top_pca_genes from .QC import ( basic_stats, - filter_genes_by_clusters, - filter_cells_by_outliers, filter_cells_by_highly_variable_genes, + filter_cells_by_outliers, + filter_genes_by_clusters, filter_genes_by_outliers, filter_genes_by_pattern, ) -from .pca import pca, top_pca_genes from .transform import log1p, log1p_adata_layer from .utils import ( compute_gene_exp_fraction, @@ -29,14 +37,6 @@ relative2abs, scale, ) -from .deprecated import ( - cook_dist, - calc_sz_factor_legacy, - normalize_cell_expr_by_size_factors, - filter_cells_legacy, - recipe_monocle, - recipe_velocyto, -) filter_cells = filter_cells_by_outliers filter_genes = filter_genes_by_outliers @@ -44,7 +44,12 @@ normalize_cells = normalize from .CnmfPreprocessor import CnmfPreprocessor -from .gene_selection import calc_Gini, calc_dispersion_by_svr, highest_frac_genes, select_genes_monocle +from .gene_selection import ( + calc_dispersion_by_svr, + calc_Gini, + highest_frac_genes, + select_genes_monocle, +) from .Preprocessor import Preprocessor __all__ = [ diff --git a/dynamo/preprocessing/deprecated.py b/dynamo/preprocessing/deprecated.py index f9fca44b9..122b39a1c 100644 --- a/dynamo/preprocessing/deprecated.py +++ b/dynamo/preprocessing/deprecated.py @@ -72,6 +72,7 @@ def wrapper(*args, **kwargs): # from __future__ import division, print_function + # https://stats.stackexchange.com/questions/356053/the-identity-link-function-does-not-respect-the-domain-of-the-gamma- # family def _weight_matrix_legacy(fitted_model: sm.Poisson) -> np.ndarray: diff --git a/dynamo/preprocessing/dynast.py b/dynamo/preprocessing/dynast.py index d71f0a1f0..4aca30b10 100644 --- a/dynamo/preprocessing/dynast.py +++ b/dynamo/preprocessing/dynast.py @@ -69,9 +69,7 @@ def lambda_correction( elif np.count_nonzero([has_l, has_n]): datatype = "labeling" else: - raise ValueError( - "the adata object has to include labeling layers." - ) + raise ValueError("the adata object has to include labeling layers.") logger.info(f"the data type identified is {datatype}", indent_level=2) diff --git a/dynamo/preprocessing/external/__init__.py b/dynamo/preprocessing/external/__init__.py index d6551b301..bb73ed945 100644 --- a/dynamo/preprocessing/external/__init__.py +++ b/dynamo/preprocessing/external/__init__.py @@ -11,4 +11,4 @@ "select_genes_by_pearson_residuals", "harmony_debatch", "integrate", -] \ No newline at end of file +] diff --git a/dynamo/preprocessing/external/integration.py b/dynamo/preprocessing/external/integration.py index b7529b38a..e2b4da139 100644 --- a/dynamo/preprocessing/external/integration.py +++ b/dynamo/preprocessing/external/integration.py @@ -7,6 +7,7 @@ # Convert sparse matrix to dense matrix. to_dense_matrix = lambda X: np.array(X.todense()) if isspmatrix(X) else np.asarray(X) + def integrate( adatas: List[AnnData], batch_key: str = "slices", @@ -78,6 +79,7 @@ def integrate( return integrated_adata + def harmony_debatch( adata: AnnData, key: str, @@ -125,4 +127,4 @@ def harmony_debatch( adata.obsm[adjusted_basis] = adjusted_matrix - return adata if copy else None \ No newline at end of file + return adata if copy else None diff --git a/dynamo/preprocessing/external/pearson_residual_recipe.py b/dynamo/preprocessing/external/pearson_residual_recipe.py index 8ae353054..b4361306c 100644 --- a/dynamo/preprocessing/external/pearson_residual_recipe.py +++ b/dynamo/preprocessing/external/pearson_residual_recipe.py @@ -28,6 +28,7 @@ main_logger = LoggerManager.main_logger + # TODO: Use compute_pearson_residuals function to calculate residuals def _highly_variable_pearson_residuals( adata: AnnData, diff --git a/dynamo/preprocessing/external/sctransform.py b/dynamo/preprocessing/external/sctransform.py index b54c90dfd..3bbc12fae 100644 --- a/dynamo/preprocessing/external/sctransform.py +++ b/dynamo/preprocessing/external/sctransform.py @@ -218,6 +218,7 @@ def sctransform_core( """ import multiprocessing import sys + try: from KDEpy import FFTKDE except ImportError: diff --git a/dynamo/preprocessing/gene_selection.py b/dynamo/preprocessing/gene_selection.py index cacf33945..12fbe1db1 100644 --- a/dynamo/preprocessing/gene_selection.py +++ b/dynamo/preprocessing/gene_selection.py @@ -247,7 +247,11 @@ def calc_dispersion_by_svr( adata_ori.uns[key] = {"mean": mean, "cv": cv, "svr_gamma": svr_gamma} prefix = "" if layer == "X" else layer + "_" - (adata.var[prefix + "log_m"], adata.var[prefix + "log_cv"], adata.var[prefix + "score"],) = ( + ( + adata.var[prefix + "log_m"], + adata.var[prefix + "log_cv"], + adata.var[prefix + "score"], + ) = ( np.nan, np.nan, -np.inf, @@ -602,8 +606,8 @@ def select_genes_by_seurat_recipe( chunked_mean, chunked_var = seurat_get_mean_var(layer_mat) - mean[mat_data[1]:mat_data[2]] = chunked_mean - variance[mat_data[1]:mat_data[2]] = chunked_var + mean[mat_data[1] : mat_data[2]] = chunked_mean + variance[mat_data[1] : mat_data[2]] = chunked_var mean, variance, highly_variable_mask = select_genes_by_seurat_dispersion( mean=mean, diff --git a/dynamo/preprocessing/normalization.py b/dynamo/preprocessing/normalization.py index 3e097569e..f6d9d8e7f 100644 --- a/dynamo/preprocessing/normalization.py +++ b/dynamo/preprocessing/normalization.py @@ -20,9 +20,7 @@ main_info_insert_adata_obsm, main_warning, ) -from .utils import ( - merge_adata_attrs, -) +from .utils import merge_adata_attrs def calc_sz_factor( @@ -76,7 +74,9 @@ def calc_sz_factor( """ if initial_dtype is None: - initial_dtype = adata_ori.X.dtype if adata_ori.X.dtype == np.float32 or adata_ori.X.dtype == np.float64 else np.float32 + initial_dtype = ( + adata_ori.X.dtype if adata_ori.X.dtype == np.float32 or adata_ori.X.dtype == np.float64 else np.float32 + ) if use_all_genes_cells: # let us ignore the `inplace` parameter in pandas.Categorical.remove_unused_categories warning. @@ -288,8 +288,10 @@ def normalize( layers = DKM.get_available_layer_keys(adata, layers) if "X" in layers and transform_int_to_float and adata.X.dtype == "int": - main_warning("Transforming adata.X from int to float32 for normalization. If you want to disable this, set " - "`transform_int_to_float` to False.") + main_warning( + "Transforming adata.X from int to float32 for normalization. If you want to disable this, set " + "`transform_int_to_float` to False." + ) adata.X = adata.X.astype("float32") if recalc_sz: @@ -357,8 +359,8 @@ def normalize( for CM_data in CMs_data: CM = CM_data[0] - CM = size_factor_normalize(CM, szfactors[CM_data[1]:CM_data[2]]) - adata.X[CM_data[1]:CM_data[2]] = CM + CM = size_factor_normalize(CM, szfactors[CM_data[1] : CM_data[2]]) + adata.X[CM_data[1] : CM_data[2]] = CM else: main_info_insert_adata_layer("X_" + layer) @@ -370,8 +372,8 @@ def normalize( for CM_data in CMs_data: CM = CM_data[0] - CM = size_factor_normalize(CM, szfactors[CM_data[1]:CM_data[2]]) - adata.layers["X_" + layer][CM_data[1]:CM_data[2]] = CM + CM = size_factor_normalize(CM, szfactors[CM_data[1] : CM_data[2]]) + adata.layers["X_" + layer][CM_data[1] : CM_data[2]] = CM return adata @@ -439,7 +441,7 @@ def sz_util( total_layers: List[str] = None, CM: pd.DataFrame = None, scale_to: Union[float, None] = None, - initial_dtype: type=np.float32, + initial_dtype: type = np.float32, ) -> Tuple[pd.Series, pd.Series]: """Calculate the size factor for a given layer. @@ -496,7 +498,7 @@ def sz_util( chunk_cell_total = CM.sum(axis=1).A1 if issparse(CM) else CM.sum(axis=1) chunk_cell_total += chunk_cell_total == 0 # avoid infinity value after log (0) - cell_total[CM_data[1]:CM_data[2]] = chunk_cell_total + cell_total[CM_data[1] : CM_data[2]] = chunk_cell_total cell_total = cell_total.astype(int) if np.all(cell_total % 1 == 0) else cell_total diff --git a/dynamo/preprocessing/utils.py b/dynamo/preprocessing/utils.py index 7c428b65e..789811306 100755 --- a/dynamo/preprocessing/utils.py +++ b/dynamo/preprocessing/utils.py @@ -703,9 +703,7 @@ def default_layer(adata: anndata.AnnData) -> str: default_layer = ( "M_s" if "M_s" in adata.layers.keys() - else "X_spliced" - if "X_spliced" in adata.layers.keys() - else "spliced" + else "X_spliced" if "X_spliced" in adata.layers.keys() else "spliced" ) else: default_layer = ( diff --git a/dynamo/sample_data.py b/dynamo/sample_data.py index 6260c413b..b0c8fef1e 100755 --- a/dynamo/sample_data.py +++ b/dynamo/sample_data.py @@ -1,8 +1,7 @@ -from typing import Optional - import ntpath import os from pathlib import Path +from typing import Optional from urllib.request import urlretrieve import pandas as pd @@ -50,7 +49,7 @@ def get_adata(url: str, filename: Optional[str] = None) -> Optional[AnnData]: main_info("REPORT THIS: Unknown filetype (" + file_path + ")") adata.var_names_make_unique() - except OSError: + except OSError: # Usually occurs when download is stopped before completion then attempted again. main_info("Corrupted file. Deleting " + file_path + " then redownloading...") # Half-downloaded file cannot be read due to corruption so it's better to delete it. @@ -210,7 +209,7 @@ def BM( def pancreatic_endocrinogenesis( - url: str ="https://github.com/theislab/scvelo_notebooks/raw/master/data/Pancreas/endocrinogenesis_day15.h5ad", + url: str = "https://github.com/theislab/scvelo_notebooks/raw/master/data/Pancreas/endocrinogenesis_day15.h5ad", filename: Optional[str] = None, ) -> AnnData: """Pancreatic endocrinogenesis. Data from scvelo. diff --git a/dynamo/shiny.py b/dynamo/shiny.py index 7e97592ef..922133c1a 100644 --- a/dynamo/shiny.py +++ b/dynamo/shiny.py @@ -1,3 +1,3 @@ """Shiny interactive web application.""" -from .shiny import * \ No newline at end of file +from .shiny import * diff --git a/dynamo/shiny/lap.py b/dynamo/shiny/lap.py index 7261a806c..6ca85c5b2 100644 --- a/dynamo/shiny/lap.py +++ b/dynamo/shiny/lap.py @@ -1,28 +1,28 @@ +import json +import random +from functools import reduce +from pathlib import Path from typing import List, Optional -import json import matplotlib.pyplot as plt import numpy as np import pandas as pd -import random import seaborn as sns from anndata import AnnData -from functools import reduce -from pathlib import Path -from sklearn.metrics import roc_curve, auc +from sklearn.metrics import auc, roc_curve -from .utils import filter_fig -from ..prediction import GeneTrajectory, least_action, least_action_path from ..plot import kinetic_heatmap, streamline_plot from ..plot.utils import get_color_map_from_labels, map2color +from ..prediction import GeneTrajectory, least_action, least_action_path from ..tools import neighbors from ..tools.utils import nearest_neighbors, select_cell from ..vectorfield import rank_genes - +from .utils import filter_fig css_path = Path(__file__).parent / "styles.css" -def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): + +def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData] = None): """The shiny web application of most probable path predictions analyses. The process is equivalent to this tutorial: https://dynamo-release.readthedocs.io/en/latest/notebooks/lap_tutorial/lap_tutorial.html @@ -34,7 +34,7 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): try: import shiny.experimental as x from htmltools import HTML, TagList, div - from shiny import App, Inputs, Outputs, reactive, Session, render, ui + from shiny import App, Inputs, Outputs, Session, reactive, render, ui from shiny.plotutils import brushed_points, near_points except ImportError: raise ImportError("Please install shiny and htmltools before running the web application!") @@ -47,30 +47,36 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ui.panel_main( div("Most probable path predictions", class_="bold-title"), div(HTML("

")), - div("The least action path (LAP) is a principled method that has previously been used in " + div( + "The least action path (LAP) is a principled method that has previously been used in " "theoretical efforts to predict the most probable path a cell will follow during fate " "transition. Specifically, the optimal path between any two cell states (e.g. the fixed " "point of HSCs and that of megakaryocytes) is searched by variating the continuous path " "connecting the source state to the target while minimizing its action and updating the " "associated transition time. The resultant least action path has the highest transition " - "probability and is associated with a particular transition time.", class_="explanation"), + "probability and is associated with a particular transition time.", + class_="explanation", + ), div(HTML("

")), ui.div( x.ui.card( div("Initialization", class_="bold-sectiontitle"), - div("Given the group information and basis, we can visualize the projected velocity " + div( + "Given the group information and basis, we can visualize the projected velocity " "information. The velocity provides us with fundamental insights into cell fate " - "transitions.", class_="explanation"), + "transitions.", + class_="explanation", + ), ui.row( ui.column(6, ui.output_ui("selectize_cells_type_key")), ui.column(6, ui.output_ui("selectize_streamline_basis")), ), x.ui.output_plot("base_streamline_plot"), - div("In the scatter plot, we can choose the fixed points to initialize the LAP analyses. ", - class_="explanation"), - x.ui.output_plot("initialize_searching", click=True, dblclick=True, hover=True, - brush=True), - + div( + "In the scatter plot, we can choose the fixed points to initialize the LAP analyses. ", + class_="explanation", + ), + x.ui.output_plot("initialize_searching", click=True, dblclick=True, hover=True, brush=True), ui.row( ui.column( 6, @@ -91,7 +97,8 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ui.output_table("fixed_points"), ui.input_action_button("reset_fixed_points", "Reset", class_="btn-primary"), ui.input_action_button( - "activate_lap", "Run LAP analyses with identified points", + "activate_lap", + "Run LAP analyses with identified points", class_="btn-primary", ), ), @@ -99,14 +106,21 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ), x.ui.card( div("LAP results", class_="bold-sectiontitle"), - div("After calculating LAPs for all possible cell type transition pairs, the results will " - "be visualized in this section.", class_="explanation"), - div("Barplot of genes' ranking based on the mean squared displacement of the path.", - class_="bold-subtitle"), + div( + "After calculating LAPs for all possible cell type transition pairs, the results will " + "be visualized in this section.", + class_="explanation", + ), + div( + "Barplot of genes' ranking based on the mean squared displacement of the path.", + class_="bold-subtitle", + ), ui.row( ui.column( 6, - ui.input_slider("top_n_genes", "Top N genes to visualize: ", min=0, max=20, value=10), + ui.input_slider( + "top_n_genes", "Top N genes to visualize: ", min=0, max=20, value=10 + ), ), ui.column(6, ui.output_ui("selectize_gene_barplot_transition")), ), @@ -118,7 +132,10 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ui.input_slider( "n_lap_visualize_transition", "Number of transitions to visualize: ", - min=1, max=20, value=1), + min=1, + max=20, + value=1, + ), ui.output_ui("selectize_lap_visualize_transition"), ), ui.column( @@ -128,24 +145,32 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ), div("Barplot of the LAP time starting from given cell type", class_="bold-subtitle"), ui.input_switch("if_global_lap_time_rank", "Display global LAP time", value=False), - div("Note: If enabled, the rank of all transitions will be displayed. Else, will rank " - "the transitions with given starting cell type.", class_="explanation"), + div( + "Note: If enabled, the rank of all transitions will be displayed. Else, will rank " + "the transitions with given starting cell type.", + class_="explanation", + ), ui.output_ui("selectize_barplot_start_genes"), x.ui.output_plot("tfs_barplot"), div( "Heatmap of LAP actions (left) and LAP time (right) matrices of pairwise cell fate conversions", - class_="bold-subtitle" + class_="bold-subtitle", ), x.ui.output_plot("pairwise_cell_fate_heatmap"), - div("Kinetics heatmap of gene expression dynamics along the LAP", - class_="bold-subtitle"), + div("Kinetics heatmap of gene expression dynamics along the LAP", class_="bold-subtitle"), ui.row( ui.column( 3, ui.output_ui("selectize_kinetic_heatmap_transition"), ui.output_ui("selectize_lap_heatmap_basis"), ui.output_ui("selectize_lap_heatmap_adj_key"), - ui.input_slider("heatmap_n_genes", "number of genes to visualize in the plot: ", min=1, max=200, value=50), + ui.input_slider( + "heatmap_n_genes", + "number of genes to visualize in the plot: ", + min=1, + max=200, + value=50, + ), ), ui.column( 9, @@ -156,15 +181,17 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ), ), ), - ui.nav( "Evaluate TF rankings based on LAP analyses", ui.panel_main( div("Evaluate TF rankings based on LAP analyses.", class_="bold-title"), div(HTML("

")), - div("After we obtained the TFs ranking based on the mean square displacement, we are able to " + div( + "After we obtained the TFs ranking based on the mean square displacement, we are able to " "evaluate rankings by comparing with known transcription factors that enable the successful " - "cell fate conversion.", class_="explanation"), + "cell fate conversion.", + class_="explanation", + ), div(HTML("

")), ui.div( x.ui.card( @@ -173,21 +200,28 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): "Visualization of transition information and known TFs", class_="bold-subtitle", ), - div("Here we need to manually add known TFs and transition type to all possible transition " - "pairs.", class_="explanation"), + div( + "Here we need to manually add known TFs and transition type to all possible transition " + "pairs.", + class_="explanation", + ), ui.row( ui.column( 3, div("Add known TFs", class_="bold-subtitle"), div("First, choose target transition and input known TFs.", class_="explanation"), ui.output_ui("selectize_known_tf_transition"), - ui.input_text("known_tf", "Known TFs: ", - placeholder="e.g. GATA1,GATA2,ZFPM1,GFI1B,FLI1,NFE2"), - div("Next, specify the keys to extract and save the TFs and rank. The TFs and rank " + ui.input_text( + "known_tf", "Known TFs: ", placeholder="e.g. GATA1,GATA2,ZFPM1,GFI1B,FLI1,NFE2" + ), + div( + "Next, specify the keys to extract and save the TFs and rank. The TFs and rank " "will be saved in dictionary[main key][TF key] and " "dictionary[main key][TF rank key]. We don't need to change the default value " "unless there are more than one set of known genes to analyze in one " - "transition.", class_="explanation"), + "transition.", + class_="explanation", + ), ui.input_text("known_tf_key", "Key to save TFs: ", value="TFs"), ui.input_text("known_tf_rank_key", "Key to save TFs rank: ", value="TFs_rank"), ui.output_ui("input_reprog_mat_main_key"), @@ -197,14 +231,20 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): ), ui.output_ui("selectize_reprog_mat_type"), ui.input_action_button( - "activate_add_reprog_info", "Add transition info", class_="btn-primary", + "activate_add_reprog_info", + "Add transition info", + class_="btn-primary", ), - div("The known TF dictionary will be visualized on the right. After we add known " + div( + "The known TF dictionary will be visualized on the right. After we add known " "TFs for all transitions, we can click the following button to start the " - "analyses", class_="explanation"), + "analyses", + class_="explanation", + ), ui.input_action_button( - "activate_plot_priority_scores_and_ROC", "Analyze with current TFs", - class_="btn-primary" + "activate_plot_priority_scores_and_ROC", + "Analyze with current TFs", + class_="btn-primary", ), ), ui.column( @@ -219,15 +259,21 @@ def lap_web_app(input_adata: AnnData, tfs_data: Optional[AnnData]=None): "Plotting priority scores of known TFs for specific transition type", class_="bold-subtitle", ), - div("The ranking of known TFs will be converted to a priority score, simply defined as " - "1 - ( rank / number of TFs ).", class_="explanation"), + div( + "The ranking of known TFs will be converted to a priority score, simply defined as " + "1 - ( rank / number of TFs ).", + class_="explanation", + ), ui.output_ui("selectize_reprog_query_type"), x.ui.output_plot("plot_priority_scores"), div("ROC curve analyses of TF priorization of the LAP predictions", class_="bold-subtitle"), - div("We can evaluate the TF ranking through ROC of LAP TF prioritization predictions using " - "all known genes of all known transitions as the gold standard.", class_="explanation"), + div( + "We can evaluate the TF ranking through ROC of LAP TF prioritization predictions using " + "all known genes of all known transitions as the gold standard.", + class_="explanation", + ), ui.input_text("roc_tf_key", "Key of TFs for ROC plot: ", value="TFs"), - x.ui.output_plot("tf_roc_curve") + x.ui.output_plot("tf_roc_curve"), ), ), ), @@ -262,11 +308,11 @@ def server(input: Inputs, output: Outputs, session: Session): @render.ui def selectize_cells_type_key(): return ui.input_selectize( - "cells_type_key", - "Key representing the group information, most of the time it is related to cell type: ", - choices=list(adata.obs.keys()), - selected="cell_type", - ) + "cells_type_key", + "Key representing the group information, most of the time it is related to cell type: ", + choices=list(adata.obs.keys()), + selected="cell_type", + ) @output @render.ui @@ -371,12 +417,15 @@ def base_streamline_plot(): save_show_or_return="return", ) - df = pd.DataFrame({ - "x": adata.obsm["X_" + input.streamline_basis()][:, 0], - "y": adata.obsm["X_" + input.streamline_basis()][:, 1], - "Cell_Type": adata.obs[input.cells_type_key().split(",")[0]], - "Cell Names": adata.obs_names, - }, index=adata.obs_names) + df = pd.DataFrame( + { + "x": adata.obsm["X_" + input.streamline_basis()][:, 0], + "y": adata.obsm["X_" + input.streamline_basis()][:, 1], + "Cell_Type": adata.obs[input.cells_type_key().split(",")[0]], + "Cell Names": adata.obs_names, + }, + index=adata.obs_names, + ) coordinates_df.set(df) @@ -478,11 +527,11 @@ def _(): @render.ui def selectize_gene_barplot_transition(): return ui.input_selectize( - "gene_barplot_transition", - "Specific transition to visualize:", - choices=list(transition_graph().keys()), - selected=list(transition_graph().keys())[0], - ) + "gene_barplot_transition", + "Specific transition to visualize:", + choices=list(transition_graph().keys()), + selected=list(transition_graph().keys())[0], + ) @output @render.plot() @@ -491,7 +540,7 @@ def genes_barplot(): sns.barplot( y="all", x="all_values", - data=transition_graph()[input.gene_barplot_transition()]["ranking"][:input.top_n_genes()], + data=transition_graph()[input.gene_barplot_transition()]["ranking"][: input.top_n_genes()], dodge=False, ).set( title="Genes rank for transition: " + input.gene_barplot_transition(), @@ -521,7 +570,10 @@ def selectize_lap_visualize_transition(): @render.plot() def plot_lap(): if input.activate_lap() > 0: - paths = [getattr(input, "lap_visualize_transition_" + str(i))() for i in range(input.n_lap_visualize_transition())] + paths = [ + getattr(input, "lap_visualize_transition_" + str(i))() + for i in range(input.n_lap_visualize_transition()) + ] fig, ax = plt.subplots(figsize=(5, 4)) ax_list = streamline_plot( adata, @@ -644,15 +696,14 @@ def lap_kinetic_heatmap(): _adata = adata.copy() _adata.uns["LAP_umap"] = transition_graph()[path]["LAP_umap"] _adata.uns["LAP_pca"] = transition_graph()[path]["LAP_pca"] - is_human_tfs = [gene in tfs_names for gene in - _adata.var_names[_adata.var.use_for_transition]] + is_human_tfs = [gene in tfs_names for gene in _adata.var_names[_adata.var.use_for_transition]] human_genes = _adata.var_names[_adata.var.use_for_transition][is_human_tfs] sns.set(font_scale=0.8) sns_heatmap = kinetic_heatmap( _adata, basis=input.lap_heatmap_basis(), mode="lap", - genes=human_genes[:input.heatmap_n_genes()], + genes=human_genes[: input.heatmap_n_genes()], project_back_to_high_dim=True, save_show_or_return="return", color_map="bwr", @@ -705,7 +756,7 @@ def selectize_reprog_mat_type(): @output @render.ui def input_reprog_mat_main_key(): - return ui.input_text("reprog_mat_main_key", "Main Key: ", value=input.known_tf_transition()), + return (ui.input_text("reprog_mat_main_key", "Main Key: ", value=input.known_tf_transition()),) @output @render.text @@ -721,8 +772,8 @@ def add_reprog_info(): cur_transition_graph[transition][input.known_tf_key()] = input.known_tf().split(",") cur_transition_graph[transition][input.known_tf_rank_key()] = [ - all_tfs.index(key) if key in true_tf_list else -1 for key in - cur_transition_graph[transition][input.known_tf_key()] + all_tfs.index(key) if key in true_tf_list else -1 + for key in cur_transition_graph[transition][input.known_tf_key()] ] transition_graph.set(cur_transition_graph) @@ -760,14 +811,17 @@ def plot_priority_scores(): all_genes = reduce(lambda a, b: a + b, reprogramming_mat_df.loc["genes", :]) all_rank = reduce(lambda a, b: a + b, reprogramming_mat_df.loc["rank", :]) all_keys = np.repeat( - np.array(list(reprogramming_mat_dict().keys())), [len(i) for i in reprogramming_mat_df.loc["genes", :]] + np.array(list(reprogramming_mat_dict().keys())), + [len(i) for i in reprogramming_mat_df.loc["genes", :]], ) all_types = np.repeat( np.array([v["type"] for v in reprogramming_mat_dict().values()]), [len(i) for i in reprogramming_mat_df.loc["genes", :]], ) - reprogramming_mat_df_p = pd.DataFrame({"genes": all_genes, "rank": all_rank, "transition": all_keys, "type": all_types}) + reprogramming_mat_df_p = pd.DataFrame( + {"genes": all_genes, "rank": all_rank, "transition": all_keys, "type": all_types} + ) reprogramming_mat_df_p = reprogramming_mat_df_p.query("rank > -1") reprogramming_mat_df_p["rank"] /= 133 reprogramming_mat_df_p["rank"] = 1 - reprogramming_mat_df_p["rank"] @@ -796,8 +850,12 @@ def plot_priority_scores(): for i in range(reprogramming_mat_df_p_subset.shape[0]): annote_text = genes[i] # STK_ID ax.annotate( - annote_text, xy=(rank[i], transition[i]), xytext=(0, 3), textcoords="offset points", ha="center", - va="bottom" + annote_text, + xy=(rank[i], transition[i]), + xytext=(0, 3), + textcoords="offset points", + ha="center", + va="bottom", ) plt.axvline(0.8, linestyle="--", lw=0.5) @@ -833,14 +891,14 @@ def tf_roc_curve(): all_ranks_df = pd.concat([rank_dict for rank_dict in all_ranks_dict.values()]) target_ranking = all_ranks_dict[ - list(transition_graph().keys())[0].split("->")[0] + - "_" + - list(transition_graph().keys())[0].split("->")[1] + - "_ranking" - ] + list(transition_graph().keys())[0].split("->")[0] + + "_" + + list(transition_graph().keys())[0].split("->")[1] + + "_ranking" + ] all_ranks_df["priority_score"] = ( - 1 - np.tile(np.arange(target_ranking.shape[0]), len(all_ranks_dict)) / target_ranking.shape[0] + 1 - np.tile(np.arange(target_ranking.shape[0]), len(all_ranks_dict)) / target_ranking.shape[0] ) cls = all_ranks_df["known_TF"].astype(int) @@ -864,4 +922,4 @@ def tf_roc_curve(): return filter_fig(plt.gcf()) app = App(app_ui, server, debug=True) - app.run() \ No newline at end of file + app.run() diff --git a/dynamo/shiny/perturbation.py b/dynamo/shiny/perturbation.py index c9031a7dd..e86132ca8 100644 --- a/dynamo/shiny/perturbation.py +++ b/dynamo/shiny/perturbation.py @@ -1,12 +1,12 @@ +from pathlib import Path + import matplotlib.pyplot as plt import pandas as pd from anndata import AnnData -from pathlib import Path -from .utils import filter_fig from ..plot import streamline_plot from ..prediction import perturbation - +from .utils import filter_fig css_path = Path(__file__).parent / "styles.css" @@ -21,7 +21,7 @@ def perturbation_web_app(input_adata: AnnData): try: import shiny.experimental as x from htmltools import HTML, div - from shiny import App, Inputs, Outputs, reactive, Session, render, ui + from shiny import App, Inputs, Outputs, Session, reactive, render, ui except ImportError: raise ImportError("Please install shiny and htmltools before running the web application!") @@ -33,9 +33,7 @@ def perturbation_web_app(input_adata: AnnData): div("Perturbation Setting", class_="bold-subtitle"), ui.input_slider("n_genes", "Number of genes to perturb:", min=1, max=5, value=1), ui.output_ui("selectize_genes"), - ui.input_action_button( - "activate_perturbation", "Run perturbation", class_="btn-primary" - ), + ui.input_action_button("activate_perturbation", "Run perturbation", class_="btn-primary"), value="Perturbation", ), x.ui.accordion_panel( @@ -52,14 +50,16 @@ def perturbation_web_app(input_adata: AnnData): ui.div( div("in silico perturbation", class_="bold-title"), div(HTML("

")), - div("Perturbation function in Dynamo can be used to either upregulating or suppressing a single or " + div( + "Perturbation function in Dynamo can be used to either upregulating or suppressing a single or " "multiple genes in a particular cell or across all cells to perform in silico genetic perturbation. " "Dynamo first calculates the perturbation velocity vector from the input expression value and " "the analytical Jacobian from our vector field function Because Jacobian encodes the instantaneous " "changes of velocity of any genes after increasing any other gene, the output vector will produce the " "perturbation effect vector after propagating the genetic perturbation through the gene regulatory " "network. Then Dynamo projects the perturbation vector to low dimensional space.", - class_="explanation"), + class_="explanation", + ), div(HTML("

")), x.ui.card( div("Streamline Plot", class_="bold-subtitle"), @@ -95,11 +95,11 @@ def selectize_color(): @render.ui def selectize_basis(): return ui.input_selectize( - "streamline_basis", - "The perturbation output as the basis of plot: ", - choices=[b[2:] if b.startswith("X_") else b for b in list(adata.obsm.keys())], - selected="umap", - ) + "streamline_basis", + "The perturbation output as the basis of plot: ", + choices=[b[2:] if b.startswith("X_") else b for b in list(adata.obsm.keys())], + selected="umap", + ) @output @render.ui @@ -116,10 +116,11 @@ def selectize_genes(): ui.input_slider( "expression_" + str(i), "Expression value to encode the genetic perturbation: ", - min=-200, max=200, value=-100, + min=-200, + max=200, + value=-100, ), ), - ) return ui_list @@ -128,8 +129,9 @@ def selectize_genes(): def base_plot(): color = [getattr(input, "base_color_" + str(i))() for i in range(input.n_colors())] - axes_list = streamline_plot(adata, color=color, basis=input.streamline_basis(), - save_show_or_return="return") + axes_list = streamline_plot( + adata, color=color, basis=input.streamline_basis(), save_show_or_return="return" + ) return filter_fig(plt.gcf()) @@ -146,7 +148,9 @@ def activate_perturbation(): def perturbation_plot(): if input.activate_perturbation() > 0: color = [getattr(input, "base_color_" + str(i))() for i in range(input.n_colors())] - axes_list = streamline_plot(adata, color=color, basis=input.streamline_basis() + "_perturbation", save_show_or_return="return") + axes_list = streamline_plot( + adata, color=color, basis=input.streamline_basis() + "_perturbation", save_show_or_return="return" + ) return filter_fig(plt.gcf()) diff --git a/dynamo/shiny/utils.py b/dynamo/shiny/utils.py index 550189402..858f9dd6c 100644 --- a/dynamo/shiny/utils.py +++ b/dynamo/shiny/utils.py @@ -6,4 +6,4 @@ def filter_fig(fig): for ax in fig.get_axes(): if ax.get_subplotspec() is None: ax.remove() - return fig \ No newline at end of file + return fig diff --git a/dynamo/simulation/ODE.py b/dynamo/simulation/ODE.py index 178604fc5..7bb1c8388 100755 --- a/dynamo/simulation/ODE.py +++ b/dynamo/simulation/ODE.py @@ -75,9 +75,7 @@ def hill_act_grad(x: float, A: float, K: float, n: float, g: float = 0) -> float return A * n * Kd * x ** (n - 1) / (Kd + x**n) ** 2 - g -def toggle( - ab: Union[np.ndarray, Tuple[float, float]], beta: float = 5, gamma: float = 1, n: int = 2 -) -> np.ndarray: +def toggle(ab: Union[np.ndarray, Tuple[float, float]], beta: float = 5, gamma: float = 1, n: int = 2) -> np.ndarray: """Calculates the right-hand side (RHS) of the differential equations for the toggle switch system. Args: @@ -234,7 +232,7 @@ def jacobian_bifur2genes( K: List[float] = [0.5, 0.5], m: List[float] = [4, 4], n: List[float] = [4, 4], - gamma: List[float] = [1, 1] + gamma: List[float] = [1, 1], ) -> np.ndarray: """The Jacobian of the toggle switch ODE model. @@ -497,9 +495,7 @@ def neurongenesis( dx[:, 5] = a * (x[:, 0] ** n) / (1 + x[:, 0] ** n + x[:, 1] ** n) - k * x[:, 5] dx[:, 6] = a_e * (eta**n * x[:, 5] ** n) / (1 + eta**n * x[:, 5] ** n + x[:, 7] ** n) - k * x[:, 6] dx[:, 7] = a_e * (eta**n * x[:, 5] ** n) / (1 + x[:, 6] ** n + eta**n * x[:, 5] ** n) - k * x[:, 7] - dx[:, 8] = ( - a * (eta**n * x[:, 5] ** n * x[:, 6] ** n) / (1 + eta**n * x[:, 5] ** n * x[:, 6] ** n) - k * x[:, 8] - ) + dx[:, 8] = a * (eta**n * x[:, 5] ** n * x[:, 6] ** n) / (1 + eta**n * x[:, 5] ** n * x[:, 6] ** n) - k * x[:, 8] dx[:, 9] = a * (x[:, 7] ** n) / (1 + x[:, 7] ** n) - k * x[:, 9] dx[:, 10] = a_e * (x[:, 8] ** n) / (1 + x[:, 8] ** n) - k * x[:, 10] dx[:, 11] = a * (eta_m**n * x[:, 7] ** n) / (1 + eta_m**n * x[:, 7] ** n) - k * x[:, 11] diff --git a/dynamo/simulation/bif_os_inclusive_sim.py b/dynamo/simulation/bif_os_inclusive_sim.py index 394c07a70..84933c9ff 100755 --- a/dynamo/simulation/bif_os_inclusive_sim.py +++ b/dynamo/simulation/bif_os_inclusive_sim.py @@ -8,6 +8,7 @@ # Differentiation model class sim_diff: """The differentiation model.""" + def __init__( self, a1: float, @@ -170,6 +171,7 @@ def f_stoich(self) -> np.ndarray: # Oscillator class sim_osc: """The oscillator model.""" + def __init__( self, a1: float, diff --git a/dynamo/simulation/simulate_anndata.py b/dynamo/simulation/simulate_anndata.py index d3c856efc..0cd79af21 100644 --- a/dynamo/simulation/simulate_anndata.py +++ b/dynamo/simulation/simulate_anndata.py @@ -73,6 +73,7 @@ class AnnDataSimulator: """A base anndata simulator class.""" + def __init__( self, reactions: GillespieReactions, @@ -300,6 +301,7 @@ def generate_anndata(self, remove_empty_cells: bool = False) -> anndata.AnnData: class CellularModelSimulator(AnnDataSimulator): """An anndata simulator class handling models with synthesis, splicing (optional), and first-order degrdation reactions.""" + def __init__( self, gene_names: List, @@ -481,6 +483,7 @@ def generate_anndata(self, remove_empty_cells: bool = False) -> anndata.AnnData: class KinLabelingSimulator: """A simulator for kinetic labeling experiments.""" + def __init__( self, simulator: CellularModelSimulator, @@ -640,6 +643,7 @@ def write_to_anndata(self, adata: anndata) -> anndata: class BifurcationTwoGenes(CellularModelSimulator): """Two gene bifurcation model anndata simulator.""" + def __init__( self, param_dict: Dict, @@ -718,6 +722,7 @@ def register_reactions(self, reactions: GillespieReactions) -> None: Returns: The reaction object after registration. """ + def rate_syn(x, y, gene): activation = hill_act_func( x, self.param_dict["a"][gene], self.param_dict["S"][gene], self.param_dict["m"][gene] @@ -755,6 +760,7 @@ def rate_syn(x, y, gene): class OscillationTwoGenes(CellularModelSimulator): """Two gene oscillation model anndata simulator. This is essentially a predator-prey model, where gene 1 (predator) inhibits gene 2 (prey) and gene 2 activates gene 1.""" + def __init__( self, param_dict: Dict, @@ -828,6 +834,7 @@ def register_reactions(self, reactions: GillespieReactions) -> None: Returns: The reaction object after registration. """ + def rate_syn_1(x, y, gene): activation = hill_act_func( x, self.param_dict["a"][gene], self.param_dict["S"][gene], self.param_dict["m"][gene] @@ -873,6 +880,7 @@ def rate_syn_2(x, y, gene): class Neurongenesis(CellularModelSimulator): """Neurongenesis model anndata simulator from Xiaojie Qiu, et. al, 2012. anndata simulator.""" + def __init__( self, param_dict: Dict, @@ -952,6 +960,7 @@ def register_reactions(self, reactions: GillespieReactions) -> None: Returns: The reaction object after registration. """ + def rate_pax6(x, y, z, gene): a = self.param_dict["a"][gene] K = self.param_dict["K"][gene] diff --git a/dynamo/simulation/utils.py b/dynamo/simulation/utils.py index fc6465346..328885c4a 100644 --- a/dynamo/simulation/utils.py +++ b/dynamo/simulation/utils.py @@ -426,6 +426,7 @@ def simulate_multigene( class CellularSpecies: """A class to register gene and species for easier implemention of simulations.""" + def __init__(self, gene_names: list = []) -> None: """Initialize the CellularSpecies class. @@ -588,6 +589,7 @@ def copy(self): class Reaction: """A class to register reactions for easier implementation of simulations.""" + def __init__( self, substrates: list, @@ -621,6 +623,7 @@ def __init__( class GillespieReactions: """A class to register reactions for easier implementation of Gillespie simulations.""" + def __init__(self, species: CellularSpecies) -> None: """Initialize the GillespieReactions class. diff --git a/dynamo/tools/DDRTree_graph.py b/dynamo/tools/DDRTree_graph.py index a932edcd1..ba4602c6f 100755 --- a/dynamo/tools/DDRTree_graph.py +++ b/dynamo/tools/DDRTree_graph.py @@ -3,11 +3,11 @@ import matplotlib.pyplot as plt import numpy as np from anndata import AnnData -from scipy.sparse import issparse, csr_matrix +from scipy.sparse import csr_matrix, issparse from scipy.sparse.csgraph import minimum_spanning_tree -from .DDRTree import cal_ncenter, DDRTree from ..dynamo_logger import main_info, main_info_insert_adata_uns +from .DDRTree import DDRTree, cal_ncenter def construct_velocity_tree(adata: AnnData, transition_matrix_key: str = "pearson"): @@ -25,12 +25,15 @@ def construct_velocity_tree(adata: AnnData, transition_matrix_key: str = "pearso A directed velocity tree represented as a NumPy array. """ if transition_matrix_key + "_transition_matrix" not in adata.obsp.keys(): - raise KeyError("Transition matrix not found in anndata. Please call cell_velocities() before constructing " - "velocity tree") + raise KeyError( + "Transition matrix not found in anndata. Please call cell_velocities() before constructing " "velocity tree" + ) if "cell_order" not in adata.uns.keys(): - raise KeyError("Cell order information not found in anndata. Please call order_cells() before constructing " - "velocity tree.") + raise KeyError( + "Cell order information not found in anndata. Please call order_cells() before constructing " + "velocity tree." + ) main_info("Constructing velocity tree...") @@ -168,16 +171,20 @@ def _compute_center_transition_matrix(transition_matrix: Union[csr_matrix, np.nd continue indices_a = clusters[a] indices_b = clusters[b] - q = np.sum( - R[indices_a, a][:, np.newaxis] * - R[indices_b, b].T[np.newaxis, :] * - transition_matrix[indices_a[:, None], indices_b] - ) if (indices_a.shape[0] > 0 and indices_b.shape[0] > 0) else 0 + q = ( + np.sum( + R[indices_a, a][:, np.newaxis] + * R[indices_b, b].T[np.newaxis, :] + * transition_matrix[indices_a[:, None], indices_b] + ) + if (indices_a.shape[0] > 0 and indices_b.shape[0] > 0) + else 0 + ) totals[a] += q transition[a, b] = q totals = totals.reshape(-1, 1) - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): res = transition / totals res[np.isinf(res)] = 0 res = np.nan_to_num(res) @@ -195,7 +202,7 @@ def _calculate_segment_probability(transition_matrix: np.ndarray, segments: np.n The probability for each segment. """ - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): log_transition_matrix = np.log1p(transition_matrix) log_transition_matrix[np.isinf(log_transition_matrix)] = 0 log_transition_matrix = np.nan_to_num(log_transition_matrix) @@ -217,7 +224,7 @@ def _get_edges(orders: Union[np.ndarray, List], parents: Optional[Union[np.ndarr if parents: segments = [(p, o) for p, o in zip(parents, orders) if p != -1] else: - segments = [(orders[i-1], orders[i]) for i in range(1, len(orders))] + segments = [(orders[i - 1], orders[i]) for i in range(1, len(orders))] return segments @@ -271,7 +278,8 @@ def _get_all_segments(orders: Union[np.ndarray, List], parents: Union[np.ndarray element_counts = Counter(parents) bifurcation_nodes = [ - node for node, count in element_counts.items() + node + for node, count in element_counts.items() if count > 1 and node != -1 and not (count == 2 and parents_dict == -1) ] root_nodes = [node for node in orders if parents_dict[node] == -1] diff --git a/dynamo/tools/Markov.py b/dynamo/tools/Markov.py index 5f901f23d..c54b34674 100755 --- a/dynamo/tools/Markov.py +++ b/dynamo/tools/Markov.py @@ -84,7 +84,7 @@ def grid_velocity_filter( min_mass: Optional[float] = None, autoscale: bool = False, adjust_for_stream: bool = True, - V_threshold: Optional[float]=None, + V_threshold: Optional[float] = None, ) -> Tuple: """Filter the grid velocities, adjusting for streamlines if needed. @@ -217,13 +217,14 @@ def velocity_on_grid( class MarkovChain: """Base class for all Markov Chain implementation.""" + def __init__( self, P: Optional[np.ndarray] = None, eignum: Optional[int] = None, check_norm: bool = True, sumto: int = 1, - tol: float = 1e-3 + tol: float = 1e-3, ): """Initialize the MarkovChain instance. @@ -322,12 +323,7 @@ def make_p0(self, init_states: np.ndarray) -> np.ndarray: return p0 def is_normalized( - self, - P: Optional[np.ndarray] = None, - tol: float = 1e-3, - sumto: int = 1, - axis: int = 0, - ignore_nan: bool = True + self, P: Optional[np.ndarray] = None, tol: float = 1e-3, sumto: int = 1, axis: int = 0, ignore_nan: bool = True ) -> bool: """check if the matrix is properly normalized up to `tol`. @@ -358,11 +354,12 @@ def __reset__(self) -> None: class KernelMarkovChain(MarkovChain): """KernelMarkovChain class represents a Markov chain with kernel-based transition probabilities.""" + def __init__( self, P: Optional[np.ndarray] = None, Idx: Optional[np.ndarray] = None, - n_recurse_neighbors: Optional[int] = None + n_recurse_neighbors: Optional[int] = None, ): """Initialize the KernelMarkovChain instance. @@ -418,7 +415,7 @@ def fit( if neighbor_idx is None: neighbor_idx, _ = k_nearest_neighbors( X, - k=k-1, + k=k - 1, exclude_self=False, pynn_rand_state=19491001, ) @@ -621,6 +618,7 @@ def compute_theta(self, p_st: Optional[np.ndarray] = None) -> sp.csr_matrix: class DiscreteTimeMarkovChain(MarkovChain): """DiscreteTimeMarkovChain class represents a discrete-time Markov chain.""" + def __init__(self, P: Optional[np.ndarray] = None, eignum: Optional[int] = None, sumto: int = 1, **kwargs): """Initialize the DiscreteTimeMarkovChain instance. @@ -866,6 +864,7 @@ def simulate_random_walk(self, init_idx: int, num_steps: int) -> np.ndarray: class ContinuousTimeMarkovChain(MarkovChain): """ContinuousTimeMarkovChain class represents a continuous-time Markov chain.""" + def __init__(self, P: Optional[np.ndarray] = None, eignum: Optional[int] = None, **kwargs): """Initialize the ContinuousTimeMarkovChain instance. @@ -1101,7 +1100,9 @@ def compute_mean_first_passage_time(self, p0: np.ndarray, target: int, sinks: np mfpt = -(k @ (K_inv @ K_inv @ p0_)) / (k @ (K_inv @ p0_)) return mfpt - def compute_hitting_time(self, p_st: Optional[np.ndarray] = None, return_Z: bool = False) -> Union[Tuple, np.ndarray]: + def compute_hitting_time( + self, p_st: Optional[np.ndarray] = None, return_Z: bool = False + ) -> Union[Tuple, np.ndarray]: """Compute the hitting time of the continuous-time Markov chain. Args: @@ -1123,7 +1124,9 @@ def compute_hitting_time(self, p_st: Optional[np.ndarray] = None, return_Z: bool else: return H - def diffusion_map_embedding(self, n_dims: int = 2, t: Union[int, float] = 1, n_pca_dims: Optional[int] = None) -> np.ndarray: + def diffusion_map_embedding( + self, n_dims: int = 2, t: Union[int, float] = 1, n_pca_dims: Optional[int] = None + ) -> np.ndarray: """Perform diffusion map embedding for the continuous-time Markov chain. Args: @@ -1319,7 +1322,9 @@ def compute_drift_kernel(x: np.ndarray, v: np.ndarray, X: np.ndarray, inv_s: Uni # @jit(nopython=True) -def compute_drift_local_kernel(x: np.ndarray, v: np.ndarray, X: np.ndarray, inv_s: Union[np.ndarray, float]) -> np.ndarray: +def compute_drift_local_kernel( + x: np.ndarray, v: np.ndarray, X: np.ndarray, inv_s: Union[np.ndarray, float] +) -> np.ndarray: """Compute a local kernel representing the drift based on input data and parameters. Args: @@ -1337,7 +1342,7 @@ def compute_drift_local_kernel(x: np.ndarray, v: np.ndarray, X: np.ndarray, inv_ D = X - x dists = np.zeros(n) vds = np.zeros(n) - for (i, d) in enumerate(D): + for i, d in enumerate(D): dists[i] = np.linalg.norm(d) if dists[i] > 0: vds[i] = v.dot(d) / dists[i] @@ -1446,7 +1451,7 @@ def graphize_velocity( nbrs_idx: Optional[list] = None, k: int = 30, normalize_v: bool = False, - E_func: Optional[Union[Callable, str]] = None + E_func: Optional[Union[Callable, str]] = None, ) -> Tuple: """The function generates a graph based on the velocity data. The flow from i- to j-th node is returned as the edge matrix E[i, j], and E[i, j] = -E[j, i]. diff --git a/dynamo/tools/__init__.py b/dynamo/tools/__init__.py index 5a33085a5..ff3f6f7c7 100755 --- a/dynamo/tools/__init__.py +++ b/dynamo/tools/__init__.py @@ -33,11 +33,13 @@ mnn, neighbors, ) +from .DDRTree import DDRTree, cal_ncenter # Pseudotime related from .DDRTree_graph import construct_velocity_tree, directed_pg -from .DDRTree import DDRTree, cal_ncenter -from .pseudotime import order_cells + +# deprecated functions +from .deprecated import construct_velocity_tree_py # dimension reduction related from .dimension_reduction import reduceDimension # , run_umap @@ -77,6 +79,7 @@ # vector field related from .metric_velocity import cell_wise_confidence, gene_wise_confidence from .moments import calc_1nd_moment, calc_2nd_moment, moments +from .pseudotime import order_cells from .pseudotime_velocity import pseudotime_velocity from .psl import psl @@ -90,7 +93,14 @@ ) # Sampling methods -from .sampling import TRNET, lhsclassic, sample, sample_by_kmeans, sample_by_velocity, trn +from .sampling import ( + TRNET, + lhsclassic, + sample, + sample_by_kmeans, + sample_by_velocity, + trn, +) from .utils import ( AnnDataPredicate, cell_norm, @@ -111,8 +121,3 @@ scv_dyn_convertor, vlm_to_adata, ) - -# deprecated functions -from .deprecated import ( - construct_velocity_tree_py, -) diff --git a/dynamo/tools/cell_velocities.py b/dynamo/tools/cell_velocities.py index 0a3f82acb..38482dbeb 100755 --- a/dynamo/tools/cell_velocities.py +++ b/dynamo/tools/cell_velocities.py @@ -16,7 +16,12 @@ from ..configuration import DKM from ..dynamo_logger import LoggerManager, main_info, main_warning from ..utils import areinstance, expr_to_pca -from .connectivity import generate_neighbor_keys, adj_to_knn, check_and_recompute_neighbors, construct_mapper_umap +from .connectivity import ( + adj_to_knn, + check_and_recompute_neighbors, + construct_mapper_umap, + generate_neighbor_keys, +) from .dimension_reduction import reduceDimension from .graph_calculus import calc_gaussian_weight, fp_operator, graphize_velocity from .Markov import ContinuousTimeMarkovChain, KernelMarkovChain, velocity_on_grid @@ -451,7 +456,13 @@ def cell_velocities( if calc_rnd_vel: permute_rows_nsign(V) - (T_rnd, delta_X_rnd, X_grid_rnd, V_grid_rnd, D_rnd,) = kernels_from_velocyto_scvelo( + ( + T_rnd, + delta_X_rnd, + X_grid_rnd, + V_grid_rnd, + D_rnd, + ) = kernels_from_velocyto_scvelo( X, X_embedding, V, diff --git a/dynamo/tools/clustering.py b/dynamo/tools/clustering.py index 8136e0671..b0b10894d 100644 --- a/dynamo/tools/clustering.py +++ b/dynamo/tools/clustering.py @@ -13,8 +13,8 @@ from ..configuration import DKM from ..dynamo_logger import main_info from ..preprocessing.normalization import calc_sz_factor, normalize -from ..preprocessing.QC import filter_genes_by_outliers as filter_genes from ..preprocessing.pca import pca +from ..preprocessing.QC import filter_genes_by_outliers as filter_genes from ..preprocessing.transform import log1p from ..utils import LoggerManager, copy_adata from .connectivity import generate_neighbor_keys, neighbors diff --git a/dynamo/tools/deprecated.py b/dynamo/tools/deprecated.py index c278b1fc4..643858b83 100644 --- a/dynamo/tools/deprecated.py +++ b/dynamo/tools/deprecated.py @@ -21,6 +21,8 @@ from tqdm import tqdm from ..dynamo_logger import main_info, main_warning +from ..estimation.csc.velocity import Velocity, ss_estimation +from ..estimation.tsc.utils_moments import moments from .DDRTree import DDRTree from .moments import calc_1nd_moment, strat_mom from .utils import ( @@ -32,8 +34,6 @@ set_param_ss, set_velocity, ) -from ..estimation.tsc.utils_moments import moments -from ..estimation.csc.velocity import ss_estimation, Velocity def deprecated(func): @@ -1837,7 +1837,10 @@ def moment_model(adata, subset_adata, _group, cur_grp, log_unnormalized, tkey): else: if log_unnormalized and "X_total" not in subset_adata.layers.keys(): if issparse(subset_adata.layers["total"]): - (subset_adata.layers["new"].data, subset_adata.layers["total"].data,) = ( + ( + subset_adata.layers["new"].data, + subset_adata.layers["total"].data, + ) = ( np.log1p(subset_adata.layers["new"].data), np.log1p(subset_adata.layers["total"].data), ) @@ -1866,7 +1869,7 @@ def moment_model(adata, subset_adata, _group, cur_grp, log_unnormalized, tkey): return adata, Est, t_ind -#--------------------------------------------------------------------------------------------------- +# --------------------------------------------------------------------------------------------------- # deprecated clustering.py def infomap( adata: AnnData, @@ -1880,7 +1883,7 @@ def infomap( selected_cell_subset: Union[List[int], List[str], None] = None, directed: bool = False, copy: bool = False, - **kwargs + **kwargs, ) -> AnnData: """Apply infomap community detection algorithm to cluster adata. @@ -1909,4 +1912,4 @@ def infomap( An updated AnnData object if `copy` is set to be true. """ - raise NotImplementedError("infomap algorithm has been deprecated.") \ No newline at end of file + raise NotImplementedError("infomap algorithm has been deprecated.") diff --git a/dynamo/tools/dimension_reduction.py b/dynamo/tools/dimension_reduction.py index a586d85e4..449c4b7d2 100755 --- a/dynamo/tools/dimension_reduction.py +++ b/dynamo/tools/dimension_reduction.py @@ -103,7 +103,9 @@ def reduceDimension( conn_key, dist_key, neighbor_key = generate_neighbor_keys(neighbor_result_prefix) if enforce or not has_basis: - logger.info(f"[{reduction_method.upper()}] using {basis} with n_pca_components = {n_pca_components}", indent_level=1) + logger.info( + f"[{reduction_method.upper()}] using {basis} with n_pca_components = {n_pca_components}", indent_level=1 + ) adata = run_reduce_dim( adata, X_data, diff --git a/dynamo/tools/dynamics.py b/dynamo/tools/dynamics.py index c4c13611a..39d505a1b 100755 --- a/dynamo/tools/dynamics.py +++ b/dynamo/tools/dynamics.py @@ -27,8 +27,8 @@ from ..estimation.csc.utils_velocity import solve_alpha_2p_mat from ..estimation.csc.velocity import Velocity, fit_linreg, ss_estimation from ..estimation.tsc.estimation_kinetic import * -from ..estimation.tsc.twostep import fit_slope_stochastic, lin_reg_gamma_synthesis from ..estimation.tsc.ODEs import * +from ..estimation.tsc.twostep import fit_slope_stochastic, lin_reg_gamma_synthesis from .moments import ( moments, prepare_data_deterministic, @@ -294,7 +294,13 @@ def dynamics( raise ValueError(f"\nPlease run `dyn.pp.receipe_monocle(adata)` before running this function!") if tkey is None: tkey = adata.uns["pp"]["tkey"] - (experiment_type, has_splicing, has_labeling, splicing_labeling, has_protein,) = ( + ( + experiment_type, + has_splicing, + has_labeling, + splicing_labeling, + has_protein, + ) = ( adata.uns["pp"]["experiment_type"], adata.uns["pp"]["has_splicing"], adata.uns["pp"]["has_labeling"], @@ -727,7 +733,15 @@ def dynamics( est_method = "direct" data_type = "smoothed" if use_smoothed else "sfs" - (params, half_life, cost, logLL, param_ranges, cur_X_data, cur_X_fit_data,) = kinetic_model( + ( + params, + half_life, + cost, + logLL, + param_ranges, + cur_X_data, + cur_X_fit_data, + ) = kinetic_model( subset_adata, tkey, model, diff --git a/dynamo/tools/graph_calculus.py b/dynamo/tools/graph_calculus.py index b916248b0..3bd35bc8e 100644 --- a/dynamo/tools/graph_calculus.py +++ b/dynamo/tools/graph_calculus.py @@ -1,4 +1,5 @@ """This file implements the graph calculus functions using matrix as input.""" + from typing import Callable, List, Optional, Tuple, Union try: @@ -12,9 +13,9 @@ from scipy.optimize import lsq_linear, minimize from sklearn.neighbors import NearestNeighbors -from .connectivity import k_nearest_neighbors from ..dynamo_logger import main_info, main_warning from ..tools.utils import projection_with_transition_matrix +from .connectivity import k_nearest_neighbors from .utils import ( elem_prod, flatten, @@ -181,6 +182,7 @@ def cosine_similarity(mat_a, mat_b, mat_a_norm): else: sim = 0 return sim + def func(w, v, D, kernel, mat, mat_norm): """Wrap up main operations in the object function to minimize.""" v_ = w @ D @@ -202,9 +204,7 @@ def func(w, v, D, kernel, mat, mat_norm): elif loss_func == "log": rec = np.log(rec) else: - raise NotImplementedError( - f"The function {loss_func} is not supported. Choose either `linear` or `log`." - ) + raise NotImplementedError(f"The function {loss_func} is not supported. Choose either `linear` or `log`.") # regularization reg = 0 if r == 0 else w.dot(w) @@ -217,7 +217,7 @@ def fjac(w, v, D, kernel, mat, mat_norm): v_ = w @ D if kernel == "U": v_norm = np.linalg.norm(v_) - mat_ = mat/mat_norm + mat_ = mat / mat_norm # reconstruction error jac_con = 2 * a * D @ (v_ - v) @@ -230,10 +230,10 @@ def fjac(w, v, D, kernel, mat, mat_norm): # cosine similarity if kernel == "C": w_norm = np.linalg.norm(w) - if w_norm == 0 or b == 0 or c_norm==0: + if w_norm == 0 or b == 0 or c_norm == 0: jac_sim = 0 else: - jac_sim = b * (mat / (w_norm * mat_norm) - w.dot(mat) / (w_norm ** 3 * mat_norm) * w) + jac_sim = b * (mat / (w_norm * mat_norm) - w.dot(mat) / (w_norm**3 * mat_norm) * w) elif kernel == "U": if v_norm == 0 or b == 0: jac_sim = 0 @@ -293,9 +293,7 @@ def fjac_u(w): res = minimize(func_u, x0=D @ v, jac=fjac_u, bounds=bounds) E[i][idx] = res["x"] else: - raise NotImplementedError( - f"Optimization method is not supported. Please provide one of U or C." - ) + raise NotImplementedError(f"Optimization method is not supported. Please provide one of U or C.") return E @@ -487,7 +485,7 @@ def fp_operator( if E is not None: L = calc_laplacian(E, E=E, convention="diffusion", weight_mode="naive") else: - L = calc_laplacian(F, E=E, convention="diffusion", weight_mode="naive") + L = calc_laplacian(F, E=E, convention="diffusion", weight_mode="naive") else: L = calc_laplacian(W, E=E, convention="diffusion", weight_mode=weight_mode) @@ -552,7 +550,10 @@ def potential( def divergence( - E: np.ndarray, W: Optional[np.ndarray] = None, method: Literal["direct", "operator"] = "operator", weighted: bool = False + E: np.ndarray, + W: Optional[np.ndarray] = None, + method: Literal["direct", "operator"] = "operator", + weighted: bool = False, ) -> np.ndarray: """Calculate the divergence of a weighted graph. @@ -582,9 +583,11 @@ def divergence( W = abs(E.sign()) if sp.issparse(E) else np.abs(np.sign(E)) # W = np.abs(np.sign(E)) if W is None else W if weighted: - div = (divop(W) @ elem_prod(E, np.sqrt(W))[W.nonzero()].A1 - if sp.issparse(E) - else divop(W) @ elem_prod(E, np.sqrt(W))[W.nonzero()]) + div = ( + divop(W) @ elem_prod(E, np.sqrt(W))[W.nonzero()].A1 + if sp.issparse(E) + else divop(W) @ elem_prod(E, np.sqrt(W))[W.nonzero()] + ) else: div = divop(W) @ E[W.nonzero()].A1 if sp.issparse(E) else divop(W) @ E[W.nonzero()] else: diff --git a/dynamo/tools/graph_operators.py b/dynamo/tools/graph_operators.py index fa10fd0b3..bbe04cc9e 100644 --- a/dynamo/tools/graph_operators.py +++ b/dynamo/tools/graph_operators.py @@ -5,9 +5,9 @@ Code adapted from https://github.com/kazumits/ddhodge. """ -from typing import List, Optional, Union from itertools import combinations +from typing import List, Optional, Union import numpy as np from igraph import Graph diff --git a/dynamo/tools/growth.py b/dynamo/tools/growth.py index 7403eb484..a3ec27654 100644 --- a/dynamo/tools/growth.py +++ b/dynamo/tools/growth.py @@ -15,6 +15,7 @@ from .connectivity import k_nearest_neighbors + def score_cells( adata: AnnData, genes: Optional[List[str]] = None, @@ -84,9 +85,11 @@ def score_cells( genes = ( list(adata.var_names.intersection(genes)) if adata.var_names[0].isupper() - else list(adata.var_names.intersection([i.capitalize() for i in genes])) - if adata.var_names[0][0].isupper() and adata.var_names[0][1:].islower() - else list(adata.var_names.intersection([i.lower() for i in genes])) + else ( + list(adata.var_names.intersection([i.capitalize() for i in genes])) + if adata.var_names[0][0].isupper() and adata.var_names[0][1:].islower() + else list(adata.var_names.intersection([i.lower() for i in genes])) + ) ) if len(genes) < 1: diff --git a/dynamo/tools/markers.py b/dynamo/tools/markers.py index 13d005da3..df7c9c11b 100755 --- a/dynamo/tools/markers.py +++ b/dynamo/tools/markers.py @@ -30,7 +30,7 @@ main_warning, ) from ..preprocessing.transform import _Freeman_Tukey -from ..tools.connectivity import generate_neighbor_keys, check_and_recompute_neighbors +from ..tools.connectivity import check_and_recompute_neighbors, generate_neighbor_keys from .utils import fdr, fetch_X_data @@ -676,21 +676,29 @@ def glm_degs( X_data.data = ( 2**X_data.data - 1 if adata.uns["pp"][norm_method_key] == "log2" - else np.exp(X_data.data) - 1 - if adata.uns["pp"][norm_method_key] == "log" - else _Freeman_Tukey(X_data.data + 1, inverse=True) - 1 - if adata.uns["pp"][norm_method_key] == "Freeman_Tukey" - else X_data.data + else ( + np.exp(X_data.data) - 1 + if adata.uns["pp"][norm_method_key] == "log" + else ( + _Freeman_Tukey(X_data.data + 1, inverse=True) - 1 + if adata.uns["pp"][norm_method_key] == "Freeman_Tukey" + else X_data.data + ) + ) ) else: X_data = ( 2**X_data - 1 if adata.uns["pp"][norm_method_key] == "log2" - else np.exp(X_data) - 1 - if adata.uns["pp"][norm_method_key] == "log" - else _Freeman_Tukey(X_data, inverse=True) - if adata.uns["pp"][norm_method_key] == "Freeman_Tukey" - else X_data + else ( + np.exp(X_data) - 1 + if adata.uns["pp"][norm_method_key] == "log" + else ( + _Freeman_Tukey(X_data, inverse=True) + if adata.uns["pp"][norm_method_key] == "Freeman_Tukey" + else X_data + ) + ) ) factors = get_all_variables(fullModelFormulaStr) @@ -721,7 +729,10 @@ def diff_test_helper( data: pd.DataFrame, fullModelFormulaStr: str = "~cr(time, df=3)", reducedModelFormulaStr: str = "~1", -) -> Union[Tuple[Literal["fail"], Literal["NB2"], Literal[1]], Tuple[Literal["ok"], Literal["NB2"], np.ndarray],]: +) -> Union[ + Tuple[Literal["fail"], Literal["NB2"], Literal[1]], + Tuple[Literal["ok"], Literal["NB2"], np.ndarray], +]: """A helper function to generate required data fields for differential gene expression test. Args: diff --git a/dynamo/tools/metric_velocity.py b/dynamo/tools/metric_velocity.py index 4c812e252..cbd8a54c2 100755 --- a/dynamo/tools/metric_velocity.py +++ b/dynamo/tools/metric_velocity.py @@ -431,4 +431,3 @@ def consensus(x: np.ndarray, y: np.ndarray) -> np.ndarray: ) return consensus - diff --git a/dynamo/tools/moments.py b/dynamo/tools/moments.py index 83079fd6c..de18fb51f 100755 --- a/dynamo/tools/moments.py +++ b/dynamo/tools/moments.py @@ -126,7 +126,12 @@ def moments( with warnings.catch_warnings(): warnings.simplefilter("ignore") if group is None: - (kNN, knn_indices, knn_dists, _,) = umap_conn_indices_dist_embedding( + ( + kNN, + knn_indices, + knn_dists, + _, + ) = umap_conn_indices_dist_embedding( X, n_neighbors=np.min((n_neighbors, adata.n_obs - 1)), return_mapper=False, @@ -146,7 +151,12 @@ def moments( for cur_grp in uniq_grp: cur_cells = cells_group == cur_grp cur_X = X[cur_cells, :] - (cur_kNN, cur_knn_indices, cur_knn_dists, _,) = umap_conn_indices_dist_embedding( + ( + cur_kNN, + cur_knn_indices, + cur_knn_dists, + _, + ) = umap_conn_indices_dist_embedding( cur_X, n_neighbors=np.min((n_neighbors, sum(cur_cells) - 1)), return_mapper=False, diff --git a/dynamo/tools/pseudotime.py b/dynamo/tools/pseudotime.py index 355b4a4c3..3ca3e2bb5 100755 --- a/dynamo/tools/pseudotime.py +++ b/dynamo/tools/pseudotime.py @@ -3,14 +3,13 @@ import anndata import numpy as np import pandas as pd -from scipy.spatial import distance from scipy.sparse.csgraph import minimum_spanning_tree +from scipy.spatial import distance +from ..dynamo_logger import main_info, main_info_insert_adata_obs from .DDRTree import DDRTree from .utils import log1p_ -from ..dynamo_logger import main_info, main_info_insert_adata_obs - def order_cells( adata: anndata.AnnData, @@ -98,12 +97,9 @@ def order_cells( root_cell = select_root_cell(adata, Z=Z, root_state=root_state, init_cells=init_cells, reverse=reverse) cc_ordering = get_order_from_DDRTree(dp=dp, mst=mst, root_cell=root_cell) - ( - cellPairwiseDistances, - pr_graph_cell_proj_dist, - pr_graph_cell_proj_closest_vertex, - pr_graph_cell_proj_tree - ) = project2MST(mst, Z, Y, project_point_to_line_segment) + (cellPairwiseDistances, pr_graph_cell_proj_dist, pr_graph_cell_proj_closest_vertex, pr_graph_cell_proj_tree) = ( + project2MST(mst, Z, Y, project_point_to_line_segment) + ) adata.uns["cell_order"]["root_cell"] = root_cell adata.uns["cell_order"]["centers_order"] = cc_ordering["orders"].values @@ -121,11 +117,15 @@ def order_cells( root_cell_candidates = np.intersect1d(cells_mapped_to_graph_root, tip_leaves) if len(root_cell_candidates) == 0: - root_cell = select_root_cell(adata, Z=Z, root_state=root_state, init_cells=init_cells, reverse=reverse, map_to_tree=False) + root_cell = select_root_cell( + adata, Z=Z, root_state=root_state, init_cells=init_cells, reverse=reverse, map_to_tree=False + ) else: root_cell = root_cell_candidates[0] - cc_ordering_new_pseudotime = get_order_from_DDRTree(dp=cellPairwiseDistances, mst=pr_graph_cell_proj_tree, root_cell=root_cell) # re-calculate the pseudotime again + cc_ordering_new_pseudotime = get_order_from_DDRTree( + dp=cellPairwiseDistances, mst=pr_graph_cell_proj_tree, root_cell=root_cell + ) # re-calculate the pseudotime again adata.uns["cell_order"]["root_cell"] = root_cell adata.obs["Pseudotime"] = cc_ordering_new_pseudotime["pseudo_time"].values @@ -157,12 +157,7 @@ def get_order_from_DDRTree(dp: np.ndarray, mst: np.ndarray, root_cell: int) -> p dp_mst = ig.Graph.Weighted_Adjacency(matrix=mst) curr_state = 0 pseudotimes = [0 for _ in range(dp.shape[1])] - ordering_dict = { - 'cell_index': [], - 'cell_pseudo_state': [], - 'pseudo_time': [], - 'parent': [] - } + ordering_dict = {"cell_index": [], "cell_pseudo_state": [], "pseudo_time": [], "parent": []} orders, pres = dp_mst.dfs(vid=root_cell, mode="all") @@ -189,8 +184,8 @@ def get_order_from_DDRTree(dp: np.ndarray, mst: np.ndarray, root_cell: int) -> p ordering_df = pd.DataFrame.from_dict(ordering_dict) ordering_df.reset_index(inplace=True) - ordering_df = ordering_df.rename(columns={'index': 'orders'}) - ordering_df.set_index('cell_index', inplace=True) + ordering_df = ordering_df.rename(columns={"index": "orders"}) + ordering_df.set_index("cell_index", inplace=True) ordering_df = ordering_df.sort_index() return ordering_df @@ -361,11 +356,11 @@ def select_root_cell( root_cell = root_cell_candidates[index_of_closest_sample] if map_to_tree: - cell_proj_closest_vertex = np.argmax(adata.uns['cell_order']['R'], axis=1) + cell_proj_closest_vertex = np.argmax(adata.uns["cell_order"]["R"], axis=1) root_cell = cell_proj_closest_vertex[root_cell] elif root_state is not None: - if 'cell_pseudo_state' not in adata.obs.keys(): + if "cell_pseudo_state" not in adata.obs.keys(): raise ValueError("State has not yet been set. Please call order_cells() without specifying root_state.") root_cell_candidates = np.where(adata.obs["cell_pseudo_state"] == root_state)[0] @@ -375,29 +370,31 @@ def select_root_cell( reduced_dim_subset = Z[:, root_cell_candidates].T dp = distance.cdist(reduced_dim_subset, reduced_dim_subset, metric="euclidean") gp = ig.Graph.Weighted_Adjacency(dp, mode="undirected") - dp_mst = gp.spanning_tree(weights=gp.es['weight']) + dp_mst = gp.spanning_tree(weights=gp.es["weight"]) diameter = dp_mst.get_diameter(directed=False) if len(diameter) == 0: raise ValueError("No valid root cells for State =", root_state) root_cell_candidates = root_cell_candidates[diameter] - if adata.uns['cell_order']['root_cell'] is not None and \ - adata.obs["cell_pseudo_state"][adata.uns['cell_order']['root_cell']] == root_state: - root_cell = root_cell_candidates[np.argmin(adata[root_cell_candidates].obs['Pseudotime'].values)] + if ( + adata.uns["cell_order"]["root_cell"] is not None + and adata.obs["cell_pseudo_state"][adata.uns["cell_order"]["root_cell"]] == root_state + ): + root_cell = root_cell_candidates[np.argmin(adata[root_cell_candidates].obs["Pseudotime"].values)] else: - root_cell = root_cell_candidates[np.argmax(adata[root_cell_candidates].obs['Pseudotime'].values)] + root_cell = root_cell_candidates[np.argmax(adata[root_cell_candidates].obs["Pseudotime"].values)] if isinstance(root_cell, list): root_cell = root_cell[0] if map_to_tree: - root_cell = adata.uns['cell_order']['pr_graph_cell_proj_closest_vertex'][root_cell] + root_cell = adata.uns["cell_order"]["pr_graph_cell_proj_closest_vertex"][root_cell] else: - if 'minSpanningTree' not in adata.uns['cell_order'].keys(): + if "minSpanningTree" not in adata.uns["cell_order"].keys(): raise ValueError("No spanning tree found for adata object.") - graph = ig.Graph.Weighted_Adjacency(adata.uns['cell_order']['minSpanningTree'], mode="undirected") + graph = ig.Graph.Weighted_Adjacency(adata.uns["cell_order"]["minSpanningTree"], mode="undirected") diameter = graph.get_diameter(directed=False) if reverse: root_cell = diameter[-1] diff --git a/dynamo/tools/sampling.py b/dynamo/tools/sampling.py index cb75161c9..f98e57571 100644 --- a/dynamo/tools/sampling.py +++ b/dynamo/tools/sampling.py @@ -9,8 +9,8 @@ from scipy.cluster.vq import kmeans2 from sklearn.neighbors import NearestNeighbors -from .connectivity import k_nearest_neighbors from ..dynamo_logger import LoggerManager +from .connectivity import k_nearest_neighbors from .utils import nearest_neighbors, timeit diff --git a/dynamo/tools/utils.py b/dynamo/tools/utils.py index 5a2c33a37..6d717bdc9 100755 --- a/dynamo/tools/utils.py +++ b/dynamo/tools/utils.py @@ -1044,25 +1044,37 @@ def inverse_norm(adata: AnnData, layer_x: Union[np.ndarray, sp.csr_matrix]) -> n layer_x.data = ( np.expm1(layer_x.data) if adata.uns["pp"]["layers_norm_method"] == "log1p" - else 2**layer_x.data - 1 - if adata.uns["pp"]["layers_norm_method"] == "log2" - else np.exp(layer_x.data) - 1 - if adata.uns["pp"]["layers_norm_method"] == "log" - else _Freeman_Tukey(layer_x.data + 1, inverse=True) - 1 - if adata.uns["pp"]["layers_norm_method"] == "Freeman_Tukey" - else layer_x.data + else ( + 2**layer_x.data - 1 + if adata.uns["pp"]["layers_norm_method"] == "log2" + else ( + np.exp(layer_x.data) - 1 + if adata.uns["pp"]["layers_norm_method"] == "log" + else ( + _Freeman_Tukey(layer_x.data + 1, inverse=True) - 1 + if adata.uns["pp"]["layers_norm_method"] == "Freeman_Tukey" + else layer_x.data + ) + ) + ) ) else: layer_x = ( np.expm1(layer_x) if adata.uns["pp"]["layers_norm_method"] == "log1p" - else 2**layer_x - 1 - if adata.uns["pp"]["layers_norm_method"] == "log2" - else np.exp(layer_x) - 1 - if adata.uns["pp"]["layers_norm_method"] == "log" - else _Freeman_Tukey(layer_x, inverse=True) - if adata.uns["pp"]["layers_norm_method"] == "Freeman_Tukey" - else layer_x + else ( + 2**layer_x - 1 + if adata.uns["pp"]["layers_norm_method"] == "log2" + else ( + np.exp(layer_x) - 1 + if adata.uns["pp"]["layers_norm_method"] == "log" + else ( + _Freeman_Tukey(layer_x, inverse=True) + if adata.uns["pp"]["layers_norm_method"] == "Freeman_Tukey" + else layer_x + ) + ) + ) ) return layer_x @@ -1294,7 +1306,16 @@ def get_data_for_kin_params_estimation( ) NTR_vel = True - U, Ul, S, Sl, P, US, U2, S2, = ( + ( + U, + Ul, + S, + Sl, + P, + US, + U2, + S2, + ) = ( None, None, None, @@ -1735,9 +1756,7 @@ def set_param_kinetic( if isarray(alpha) and alpha.ndim > 1: params_df.loc[valid_ind, kin_param_pre + "alpha"] = ( - np.asarray(alpha.mean(1)) - if sp.issparse(alpha) - else alpha.mean(1) + np.asarray(alpha.mean(1)) if sp.issparse(alpha) else alpha.mean(1) ) cur_cells_ind, valid_ind_ = ( np.where(cur_cells_bools)[0][:, np.newaxis], @@ -2574,7 +2593,9 @@ def set_transition_genes( ) if is_group_alpha.sum() > 0: vel_params_df["alpha"] = vel_params_df.loc[:, is_group_alpha].mean(1, skipna=True) - vel_params_df["alpha_r2"] = vel_params_df.loc[:, np.hstack((is_group_alpha_r2, False))].mean(1, skipna=True) + vel_params_df["alpha_r2"] = vel_params_df.loc[:, np.hstack((is_group_alpha_r2, False))].mean( + 1, skipna=True + ) else: raise Exception("there is no alpha/alpha_r2 parameter estimated for your adata object") @@ -2595,7 +2616,9 @@ def set_transition_genes( ) if is_group_gamma.sum() > 0: vel_params_df["gamma"] = vel_params_df.loc[:, is_group_gamma].mean(1, skipna=True) - vel_params_df["gamma_r2"] = vel_params_df.loc[:, np.hstack((is_group_gamma_r2, False))].mean(1, skipna=True) + vel_params_df["gamma_r2"] = vel_params_df.loc[:, np.hstack((is_group_gamma_r2, False))].mean( + 1, skipna=True + ) else: raise Exception("there is no gamma/gamma_r2 parameter estimated for your adata object") @@ -2618,7 +2641,9 @@ def set_transition_genes( ) if is_group_delta.sum() > 0: vel_params_df["delta"] = vel_params_df.loc[:, is_group_delta].mean(1, skipna=True) - vel_params_df["delta_r2"] = vel_params_df.loc[:, np.hstack((is_group_delta_r2, False))].mean(1, skipna=True) + vel_params_df["delta_r2"] = vel_params_df.loc[:, np.hstack((is_group_delta_r2, False))].mean( + 1, skipna=True + ) else: raise Exception("there is no delta/delta_r2 parameter estimated for your adata object") @@ -2639,7 +2664,9 @@ def set_transition_genes( ) if is_group_gamma.sum() > 0: vel_params_df["gamma"] = vel_params_df.loc[:, is_group_gamma].mean(1, skipna=True) - vel_params_df["gamma_r2"] = vel_params_df.loc[:, np.hstack((is_group_gamma_r2, False))].mean(1, skipna=True) + vel_params_df["gamma_r2"] = vel_params_df.loc[:, np.hstack((is_group_gamma_r2, False))].mean( + 1, skipna=True + ) else: raise Exception("there is no gamma/gamma_r2 parameter estimated for your adata object") @@ -2962,7 +2989,7 @@ def integrate_vf( args: Tuple, integration_direction: Literal["forward", "backward", "both"], f: Callable, - interpolation_num: Optional[int]=None, + interpolation_num: Optional[int] = None, average: bool = True, ): """Integrating along vector field function. @@ -3229,6 +3256,7 @@ def compute_smallest_distance( # --------------------------------------------------------------------------------------------------- # multiple core related + # Pass kwargs to starmap while using Pool # https://stackoverflow.com/questions/45718523/pass-kwargs-to-starmap-while-using-pool-in-python def starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): @@ -3381,6 +3409,7 @@ def projection_with_transition_matrix( return delta_X + def density_corrected_transition_matrix(T: Union[npt.ArrayLike, sp.csr_matrix]) -> sp.csr_matrix: """Compute the density corrected transition matrix. diff --git a/dynamo/tools/velocyto_scvelo.py b/dynamo/tools/velocyto_scvelo.py index 37e435097..12ca8b838 100755 --- a/dynamo/tools/velocyto_scvelo.py +++ b/dynamo/tools/velocyto_scvelo.py @@ -5,6 +5,7 @@ Convert adata to loom object or vice versa. Convert Dynamo AnnData object to scvelo AnnData object or vice versa. """ + # from .moments import * from typing import List, Optional @@ -17,11 +18,11 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from .utils import get_vel_params -from ..dynamo_logger import main_info from scipy.sparse import csr_matrix from ..configuration import DKM +from ..dynamo_logger import main_info +from .utils import get_vel_params def vlm_to_adata( @@ -254,16 +255,21 @@ def scv_dyn_convertor(adata: anndata, mode: Literal["to_dyn", "to_scv"] = "to_dy Returns: The adata object after conversion. """ - main_info("Dynamo and scvelo have different preprocessing procedures and velocity estimation methods. " - "The conversion of adata may not be optimal for every use case, requiring potential manual adjustments.") + main_info( + "Dynamo and scvelo have different preprocessing procedures and velocity estimation methods. " + "The conversion of adata may not be optimal for every use case, requiring potential manual adjustments." + ) if mode == "to_dyn": main_info("Start converting Scvelo adata into Dynamo adata...") - main_info("Scvelo data wil be converted into Dynamo adata with the conventional assumption and the" - "stochastic model. If this is not what you want, please change them manually.") + main_info( + "Scvelo data wil be converted into Dynamo adata with the conventional assumption and the" + "stochastic model. If this is not what you want, please change them manually." + ) if "highly_variable_genes" in adata.var.columns: adata.var["pass_basic_filter"] = adata.var.pop("highly_variable_genes") - adata.var["pass_basic_filter"] = [True if item == 'True' else False for item in - adata.var["pass_basic_filter"]] + adata.var["pass_basic_filter"] = [ + True if item == "True" else False for item in adata.var["pass_basic_filter"] + ] if "spliced" in adata.layers.keys(): adata.layers["X_spliced"] = adata.layers.pop("spliced") if "unspliced" in adata.layers.keys(): @@ -339,8 +345,9 @@ def scv_dyn_convertor(adata: anndata, mode: Literal["to_dyn", "to_scv"] = "to_dy main_info("Start converting Dynamo adata into Scvelo adata...") if "pass_basic_filter" in adata.var.columns: adata.var["highly_variable_genes"] = adata.var.pop("pass_basic_filter") - adata.var["highly_variable_genes"] = ["True" if item else "False" for item in - adata.var["highly_variable_genes"]] + adata.var["highly_variable_genes"] = [ + "True" if item else "False" for item in adata.var["highly_variable_genes"] + ] if "X_spliced" in adata.layers.keys(): adata.layers["spliced"] = adata.layers.pop("X_spliced") if "X_unspliced" in adata.layers.keys(): diff --git a/dynamo/utils.py b/dynamo/utils.py index a37aa4f1a..267aaed4f 100644 --- a/dynamo/utils.py +++ b/dynamo/utils.py @@ -1,5 +1,6 @@ """General utility functions """ + from typing import Any, Callable, Dict, List, Optional, Tuple, Union import anndata diff --git a/dynamo/vectorfield/Ao.py b/dynamo/vectorfield/Ao.py index c5f8423a3..059a3c0bb 100755 --- a/dynamo/vectorfield/Ao.py +++ b/dynamo/vectorfield/Ao.py @@ -1,8 +1,8 @@ from typing import Callable, List, Optional, Tuple, Union import numpy as np -from scipy.optimize import least_squares import tqdm +from scipy.optimize import least_squares from ..tools.utils import condensed_idx_to_squareform_idx, squareform, timeit @@ -123,6 +123,7 @@ def Ao_pot_map( return X, U, P, vecMat, S, A + def Ao_pot_map_jac(fjac, X, D=None, **kwargs): nobs, ndim = X.shape if D is None: diff --git a/dynamo/vectorfield/Bhattacharya.py b/dynamo/vectorfield/Bhattacharya.py index 59b3f9044..29ee05165 100755 --- a/dynamo/vectorfield/Bhattacharya.py +++ b/dynamo/vectorfield/Bhattacharya.py @@ -216,11 +216,13 @@ def path_integral( # check if start points of current and previous paths are "adjacent" - if so, assign separatrix if startPt_dist_sqr < (2 * (xyGridSpacing**2)): - curr_sepx = np.array([ - path_tag[path_counter - 1][0], - path_tag[path_counter][0], - path_counter - 1, - ]) # create array + curr_sepx = np.array( + [ + path_tag[path_counter - 1][0], + path_tag[path_counter][0], + path_counter - 1, + ] + ) # create array sepx_old_new_pathNum = ( np.vstack((sepx_old_new_pathNum, curr_sepx)) if sepx_old_new_pathNum is not None @@ -248,11 +250,13 @@ def path_integral( if prev_attr_new == 1: # check if start points of current and previous paths are "adjacent" - if so, assign separatrix if startPt_dist_sqr < (2 * (xyGridSpacing**2)): - curr_sepx = np.array([ - path_tag[path_counter - 1][0], - path_tag[path_counter][0], - (path_counter - 1), - ]) # create array + curr_sepx = np.array( + [ + path_tag[path_counter - 1][0], + path_tag[path_counter][0], + (path_counter - 1), + ] + ) # create array sepx_old_new_pathNum = ( np.vstack((sepx_old_new_pathNum, curr_sepx)) if sepx_old_new_pathNum is not None diff --git a/dynamo/vectorfield/Tang.py b/dynamo/vectorfield/Tang.py index ba4f34f1d..68031d13e 100644 --- a/dynamo/vectorfield/Tang.py +++ b/dynamo/vectorfield/Tang.py @@ -3,6 +3,7 @@ import numpy as np import scipy as sp import scipy.optimize + # import autograd.numpy as autonp # from autograd import grad, jacobian # calculate gradient and jacobian diff --git a/dynamo/vectorfield/VectorField.py b/dynamo/vectorfield/VectorField.py index 82361ace2..2438b6d46 100644 --- a/dynamo/vectorfield/VectorField.py +++ b/dynamo/vectorfield/VectorField.py @@ -13,9 +13,7 @@ from ..utils import copy_adata from .scVectorField import BaseVectorField, SvcVectorField from .topography import topography -from .utils import ( - angle, -) +from .utils import angle def VectorField( @@ -139,6 +137,7 @@ def VectorField( elif method.lower() == "dynode": try: from dynode.vectorfield import Dynode # networkModels, + from .scVectorField import dynode_vectorfield except ImportError: raise ImportError("You need to install the package `dynode`." "install dynode via `pip install dynode`") @@ -474,9 +473,11 @@ def _resume_training( "Y": Y, "V": Dynode_obj.predict_velocity(Dynode_obj.Velocity["sampler"].X_raw), "grid_V": Dynode_obj.predict_velocity(Dynode_obj.Velocity["sampler"].Grid), - "valid_ind": Dynode_obj.Velocity["sampler"].valid_ind - if hasattr(Dynode_obj.Velocity["sampler"], "valid_ind") - else np.arange(X.shape[0]), + "valid_ind": ( + Dynode_obj.Velocity["sampler"].valid_ind + if hasattr(Dynode_obj.Velocity["sampler"], "valid_ind") + else np.arange(X.shape[0]) + ), "parameters": Dynode_obj.Velocity, "dynode_object": VecFld, } diff --git a/dynamo/vectorfield/__init__.py b/dynamo/vectorfield/__init__.py index c5a4a679a..1824ca743 100644 --- a/dynamo/vectorfield/__init__.py +++ b/dynamo/vectorfield/__init__.py @@ -37,18 +37,12 @@ SvcVectorField, graphize_vecfld, ) -from .Tang import action, gen_fixed_points, IntGrad # stochastic process related from .stochastic_process import diffusionMatrix -from .topography import ( - FixedPoints, - Topography2D, - assign_fixedpoints, - topography, -) +from .Tang import IntGrad, action, gen_fixed_points +from .topography import FixedPoints, Topography2D, assign_fixedpoints, topography from .utils import get_jacobian, parse_int_df, vector_field_function -from .VectorField import VectorField from .vector_calculus import ( acceleration, curl, @@ -62,6 +56,7 @@ torsion, velocities, ) +from .VectorField import VectorField # vfGraph operation related: from .vfGraph_deprecated import vfGraph diff --git a/dynamo/vectorfield/cell_vectors.py b/dynamo/vectorfield/cell_vectors.py index c497a9805..70e76c4f1 100644 --- a/dynamo/vectorfield/cell_vectors.py +++ b/dynamo/vectorfield/cell_vectors.py @@ -3,8 +3,8 @@ from anndata import AnnData from ..tools.cell_velocities import cell_velocities -from .VectorField import VectorField from .vector_calculus import acceleration, curvature +from .VectorField import VectorField def cell_accelerations( diff --git a/dynamo/vectorfield/clustering.py b/dynamo/vectorfield/clustering.py index 4245a3f30..43efe2c91 100644 --- a/dynamo/vectorfield/clustering.py +++ b/dynamo/vectorfield/clustering.py @@ -413,8 +413,7 @@ def streamline_clusters( feature_adata.obs[key] = adata.obs.obs[key].astype("category") else: raise ValueError( - "only louvain, leiden, hdbscan and kmeans clustering supported but your requested " - f"method is {method}" + "only louvain, leiden, hdbscan and kmeans clustering supported but your requested " f"method is {method}" ) if assign_fixedpoints or reversed_fixedpoints: diff --git a/dynamo/vectorfield/scPotential.py b/dynamo/vectorfield/scPotential.py index fd6e0d2ea..d9cee1bee 100755 --- a/dynamo/vectorfield/scPotential.py +++ b/dynamo/vectorfield/scPotential.py @@ -1,10 +1,10 @@ +from typing import Callable, List, Optional, Tuple, Union from warnings import warn -from anndata._core.anndata import AnnData import numpy as np import scipy as sp import scipy.optimize -from typing import Callable, List, Optional, Tuple, Union +from anndata._core.anndata import AnnData from ..tools.sampling import lhsclassic from .Ao import Ao_pot_map, construct_Ao_potential_grid @@ -120,8 +120,8 @@ def fit( method: str = "Ao", xyGridSpacing: int = 2, dt: float = 1e-2, - tol: float= 1e-2, - numTimeSteps: int =1400, + tol: float = 1e-2, + numTimeSteps: int = 1400, ) -> AnnData: """Function to map out the pseudo-potential landscape. @@ -183,7 +183,19 @@ def fit( "A": A, } elif method == "Bhattacharya": - (_, _, _, _, numPaths, numTimeSteps, pot_path, path_tag, attractors_pot, x_path, y_path,) = path_integral( + ( + _, + _, + _, + _, + numPaths, + numTimeSteps, + pot_path, + path_tag, + attractors_pot, + x_path, + y_path, + ) = path_integral( self.VecFld["Function"], x_lim=x_lim, y_lim=y_lim, @@ -216,7 +228,15 @@ def fit( self.VecFld["Function"], self.VecFld["DiffusionMatrix"], ) - (boundary, n_points, fixed_point_only, find_fixed_points, refpoint, stable, saddle,) = ( + ( + boundary, + n_points, + fixed_point_only, + find_fixed_points, + refpoint, + stable, + saddle, + ) = ( self.parameters["boundary"], self.parameters["n_points"], self.parameters["fixed_point_only"], @@ -238,23 +258,23 @@ def fit( saddle=self.parameters["saddle"], ) - adata.uns['grid_Pot_' + basis] = {'Xgrid': X, "Ygrid": Y, 'Zgrid': retmat} + adata.uns["grid_Pot_" + basis] = {"Xgrid": X, "Ygrid": Y, "Zgrid": retmat} return adata # return retmat, LAP def search_fixed_points( - func: Callable, - domain: np.ndarray, - x0: np.ndarray, - x0_method: str = "lhs", - reverse: bool = False, - return_x0: bool = False, - fval_tol: float = 1e-8, - remove_outliers: bool = True, - ignore_fsolve_err: bool = False, - **fsolve_kwargs + func: Callable, + domain: np.ndarray, + x0: np.ndarray, + x0_method: str = "lhs", + reverse: bool = False, + return_x0: bool = False, + fval_tol: float = 1e-8, + remove_outliers: bool = True, + ignore_fsolve_err: bool = False, + **fsolve_kwargs ) -> Union[FixedPoints, Tuple[FixedPoints, np.ndarray]]: """Search the fixed points of (learned) vector field function in a given domain. @@ -332,10 +352,10 @@ def search_fixed_points( def gen_gradient( - dim: int, - N: int, - Function: Callable, - DiffusionMatrix: Callable, + dim: int, + N: int, + Function: Callable, + DiffusionMatrix: Callable, ) -> Tuple[np.ndarray, np.ndarray]: """Calculate the gradient of the (learned) vector field function for the least action path (LAP) symbolically diff --git a/dynamo/vectorfield/scVectorField.py b/dynamo/vectorfield/scVectorField.py index 9adcd6cea..1e42baee4 100755 --- a/dynamo/vectorfield/scVectorField.py +++ b/dynamo/vectorfield/scVectorField.py @@ -134,6 +134,7 @@ def bandwidth_selector(X: np.ndarray) -> float: d = np.mean(distances[:, 1:]) / 1.5 return np.sqrt(2) * d + def denorm( VecFld: Dict[str, Union[np.ndarray, None]], X_old: np.ndarray, @@ -1011,7 +1012,12 @@ def train(self, **kwargs) -> VecFldDict: if self.normalize: X_norm, V_norm, T_norm, norm_dict = norm(self.data["X"], self.data["V"], self.data["Grid"]) - (self.data["X"], self.data["V"], self.data["Grid"], self.norm_dict,) = ( + ( + self.data["X"], + self.data["V"], + self.data["Grid"], + self.norm_dict, + ) = ( X_norm, V_norm, T_norm, diff --git a/dynamo/vectorfield/stochastic_process.py b/dynamo/vectorfield/stochastic_process.py index 291dccc60..75f6c63ae 100644 --- a/dynamo/vectorfield/stochastic_process.py +++ b/dynamo/vectorfield/stochastic_process.py @@ -5,7 +5,11 @@ from sklearn.neighbors import NearestNeighbors from tqdm import tqdm -from ..tools.connectivity import generate_neighbor_keys, check_and_recompute_neighbors, k_nearest_neighbors +from ..tools.connectivity import ( + check_and_recompute_neighbors, + generate_neighbor_keys, + k_nearest_neighbors, +) from ..tools.utils import log1p_ from .utils import VecFldDict, vecfld_from_adata, vector_field_function diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index 31b84cfaa..c9c6caafc 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -609,6 +609,7 @@ class Topography3D(Topography2D): the vector at each point, or by separate functions for the x and y components of the vector. Nullclines calculation are not supported for 3D vector space because of the computational complexity. """ + def __init__( self, func: Callable, @@ -684,7 +685,6 @@ def find_fixed_points_by_sampling( raise ValueError(f"No fixed points found. Try to increase the number of samples n.") self.Xss.add_fixed_points(X, J, tol_redundant) - def output_to_dict(self, dict_vf) -> Dict: """Output the vector field as a dictionary. diff --git a/dynamo/vectorfield/vector_calculus.py b/dynamo/vectorfield/vector_calculus.py index 57dcf3d3f..76858514b 100644 --- a/dynamo/vectorfield/vector_calculus.py +++ b/dynamo/vectorfield/vector_calculus.py @@ -3,6 +3,7 @@ # from anndata._core.views import ArrayView # import scipy.sparse as sp from typing import Dict, List, Optional, Union + try: from typing import Literal except ImportError: @@ -61,12 +62,12 @@ def get_vf_class(adata: AnnData, basis: str = "pca") -> SvcVectorField: """Get the corresponding vector field class according to different methods. - Args: - adata: AnnData object that contains the reconstructed vector field in the `uns` attribute. - basis: The embedding data in which the vector field was reconstructed. + Args: + adata: AnnData object that contains the reconstructed vector field in the `uns` attribute. + basis: The embedding data in which the vector field was reconstructed. - Returns: - SvcVectorField object that is extracted from the `uns` attribute of adata. + Returns: + SvcVectorField object that is extracted from the `uns` attribute of adata. """ vf_dict = get_vf_dict(adata, basis=basis) if "method" not in vf_dict.keys(): @@ -197,7 +198,7 @@ def jacobian( regulators: Optional[List] = None, effectors: Optional[List] = None, cell_idx: Optional[List] = None, - sampling: Optional[Literal['random', 'velocity', 'trn']] = None, + sampling: Optional[Literal["random", "velocity", "trn"]] = None, sample_ncells: int = 1000, basis: str = "pca", Qkey: str = "PCs", @@ -340,7 +341,7 @@ def hessian( coregulators: List, effector: Optional[List] = None, cell_idx: Optional[List] = None, - sampling: Optional[Literal['random', 'velocity', 'trn']] = None, + sampling: Optional[Literal["random", "velocity", "trn"]] = None, sample_ncells: int = 1000, basis: str = "pca", Qkey: str = "PCs", @@ -577,7 +578,7 @@ def sensitivity( regulators: Optional[List] = None, effectors: Optional[List] = None, cell_idx: Optional[List] = None, - sampling: Optional[Literal['random', 'velocity', 'trn']] = None, + sampling: Optional[Literal["random", "velocity", "trn"]] = None, sample_ncells: int = 1000, basis: str = "pca", Qkey: str = "PCs", @@ -940,7 +941,7 @@ def curl( def divergence( adata: AnnData, cell_idx: Optional[List] = None, - sampling: Optional[Literal['random', 'velocity', 'trn']] = None, + sampling: Optional[Literal["random", "velocity", "trn"]] = None, sample_ncells: int = 1000, basis: str = "pca", vector_field_class=None, diff --git a/setup.py b/setup.py index c3262b290..b2012b393 100755 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ +import os from pathlib import Path from setuptools import find_packages, setup -import os + # from version import __version__ diff --git a/tests/test_data_io.py b/tests/test_data_io.py index 964ff0026..d7c553ab0 100644 --- a/tests/test_data_io.py +++ b/tests/test_data_io.py @@ -1,10 +1,11 @@ import os import numpy as np +import pytest import dynamo import dynamo as dyn -import pytest + # import utils diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c77389d1f..082e83e4d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,6 +1,8 @@ -import dynamo as dyn import pytest +import dynamo as dyn + + def test_dynamcis(adata): raw_adata = dyn.sample_data.zebrafish() adata = raw_adata[:300, :1000].copy() @@ -96,4 +98,4 @@ def test_run_rpe1_tutorial(): ) ax.set_aspect(0.8) - instance = dyn.mv.StreamFuncAnim(adata=rpe1_kinetics, basis="RFP_GFP", color="Cell_cycle_relativePos", ax=ax) \ No newline at end of file + instance = dyn.mv.StreamFuncAnim(adata=rpe1_kinetics, basis="RFP_GFP", color="Cell_cycle_relativePos", ax=ax) diff --git a/tests/test_pl.py b/tests/test_pl.py index 10b519aee..816873792 100644 --- a/tests/test_pl.py +++ b/tests/test_pl.py @@ -143,14 +143,21 @@ def test_lap_plots(): adata = dyn.sample_data.hematopoiesis() - progenitor = adata.obs_names[adata.obs.cell_type.isin(['HSC'])] + progenitor = adata.obs_names[adata.obs.cell_type.isin(["HSC"])] - dyn.pd.fate(adata, basis='umap', init_cells=progenitor, interpolation_num=25, direction='forward', - inverse_transform=False, average=False) - ax = dyn.pl.fate_bias(adata, group="cell_type", basis="umap", save_show_or_return='return') + dyn.pd.fate( + adata, + basis="umap", + init_cells=progenitor, + interpolation_num=25, + direction="forward", + inverse_transform=False, + average=False, + ) + ax = dyn.pl.fate_bias(adata, group="cell_type", basis="umap", save_show_or_return="return") assert isinstance(ax, sns.matrix.ClusterGrid) - ax = dyn.pl.fate(adata, basis="umap", save_show_or_return='return') + ax = dyn.pl.fate(adata, basis="umap", save_show_or_return="return") assert isinstance(ax, plt.Axes) # genes = adata.var_names[adata.var.use_for_dynamics] @@ -217,7 +224,12 @@ def test_heatmaps(): assert isinstance(ax, pd.DataFrame) ax = dyn.pl.comb_logic( - adata, pairs_mat=np.array(pair_matrix), xkey="M_n", ykey="M_t", zkey="velocity_alpha_minus_gamma_s", return_data=True + adata, + pairs_mat=np.array(pair_matrix), + xkey="M_n", + ykey="M_t", + zkey="velocity_alpha_minus_gamma_s", + return_data=True, ) assert isinstance(ax, pd.DataFrame) @@ -255,10 +267,17 @@ def test_time_series_plot(adata): adata = adata.copy() adata.uns["umap_fit"]["umap_kwargs"]["max_iter"] = None - progenitor = adata.obs_names[adata.obs.Cell_type.isin(['Proliferating Progenitor', 'Pigment Progenitor'])] + progenitor = adata.obs_names[adata.obs.Cell_type.isin(["Proliferating Progenitor", "Pigment Progenitor"])] - dyn.pd.fate(adata, basis='umap', init_cells=progenitor, interpolation_num=25, direction='forward', - inverse_transform=True, average=False) + dyn.pd.fate( + adata, + basis="umap", + init_cells=progenitor, + interpolation_num=25, + direction="forward", + inverse_transform=True, + average=False, + ) ax = dyn.pl.kinetic_curves(adata, basis="umap", genes=adata.var_names[:4], save_show_or_return="return") assert isinstance(ax, sns.axisgrid.FacetGrid) @@ -266,10 +285,11 @@ def test_time_series_plot(adata): assert isinstance(ax, sns.matrix.ClusterGrid) dyn.tl.order_cells(adata, basis="umap") - progenitor = adata.obs_names[adata.obs.Cell_type.isin(['Proliferating Progenitor', 'Pigment Progenitor'])] + progenitor = adata.obs_names[adata.obs.Cell_type.isin(["Proliferating Progenitor", "Pigment Progenitor"])] - dyn.vf.jacobian(adata, basis="umap", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], - cell_idx=progenitor) + dyn.vf.jacobian( + adata, basis="umap", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=progenitor + ) ax = dyn.pl.jacobian_kinetics( adata, basis="umap", diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 4a5a10789..445922d8e 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -2,23 +2,23 @@ def test_fate(adata): - progenitor = adata.obs_names[adata.obs.Cell_type.isin(['Proliferating Progenitor', 'Pigment Progenitor'])] - dyn.pd.fate(adata, basis='umap', init_cells=progenitor, direction='backward') + progenitor = adata.obs_names[adata.obs.Cell_type.isin(["Proliferating Progenitor", "Pigment Progenitor"])] + dyn.pd.fate(adata, basis="umap", init_cells=progenitor, direction="backward") assert "fate_umap" in adata.uns.keys() assert adata.uns["fate_umap"]["prediction"][0].shape == (2, 250) - dyn.pd.fate(adata, basis='umap', init_cells=progenitor, direction='both') + dyn.pd.fate(adata, basis="umap", init_cells=progenitor, direction="both") assert "fate_umap" in adata.uns.keys() assert adata.uns["fate_umap"]["prediction"][0].shape == (2, 500) - dyn.pd.fate(adata, basis='umap', init_cells=progenitor, direction='forward') + dyn.pd.fate(adata, basis="umap", init_cells=progenitor, direction="forward") assert "fate_umap" in adata.uns.keys() assert adata.uns["fate_umap"]["prediction"][0].shape == (2, 250) bias = dyn.pd.fate_bias(adata, group="Cell_type") assert len(bias) == len(adata.uns["fate_umap"]["prediction"]) - dyn.pd.andecestor(adata, init_cells=adata.obs_names[adata.obs.Cell_type.isin(['Iridophore'])], direction='backward') + dyn.pd.andecestor(adata, init_cells=adata.obs_names[adata.obs.Cell_type.isin(["Iridophore"])], direction="backward") assert "ancestor" in adata.obs.keys() dyn.pd.andecestor(adata, init_cells=progenitor) @@ -29,10 +29,10 @@ def test_fate(adata): def test_perturbation(adata): - dyn.pd.perturbation(adata, basis='umap', genes=adata.var_names[0], expression=-10) + dyn.pd.perturbation(adata, basis="umap", genes=adata.var_names[0], expression=-10) assert "X_umap_perturbation" in adata.obsm.keys() - vf_ko = dyn.pd.KO(adata, basis='pca', KO_genes=adata.var_names[0]) + vf_ko = dyn.pd.KO(adata, basis="pca", KO_genes=adata.var_names[0]) assert vf_ko.K.shape[0] == adata.n_vars dyn.pd.rank_perturbation_genes(adata) @@ -57,18 +57,18 @@ def test_state_graph(adata): # cell_type_to_excluded=["Unknown"], # ) - dyn.pd.state_graph(adata, group='Cell_type') + dyn.pd.state_graph(adata, group="Cell_type") assert "Cell_type_graph" in adata.uns.keys() - ax = dyn.pl.state_graph(adata, group='Cell_type', save_show_or_return='return') + ax = dyn.pl.state_graph(adata, group="Cell_type", save_show_or_return="return") assert isinstance(ax, tuple) res = dyn.pd.tree_model( adata, - group='Cell_type', - basis='umap', - progenitor='Proliferating Progenitor', - terminators=['Iridophore'], + group="Cell_type", + basis="umap", + progenitor="Proliferating Progenitor", + terminators=["Iridophore"], ) assert len(res) == len(adata.obs["Cell_type"].unique()) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index ae9474a85..54a5610f1 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -14,7 +14,13 @@ from dynamo.preprocessing.cell_cycle import get_cell_phase from dynamo.preprocessing.deprecated import _calc_mean_var_dispersion_sparse_legacy from dynamo.preprocessing.normalization import calc_sz_factor, normalize -from dynamo.preprocessing.transform import log, log1p, log2, Freeman_Tukey, is_log1p_transformed_adata +from dynamo.preprocessing.transform import ( + Freeman_Tukey, + is_log1p_transformed_adata, + log, + log1p, + log2, +) from dynamo.preprocessing.utils import ( convert_layers2csr, is_float_integer_arr, @@ -112,7 +118,13 @@ def test_recipe_monocle_feature_selection_layer_simple0(): + rpe1_kinetics.layers["ul"], ) - del rpe1, rpe1_kinetics.layers["uu"], rpe1_kinetics.layers["ul"], rpe1_kinetics.layers["su"], rpe1_kinetics.layers["sl"] + del ( + rpe1, + rpe1_kinetics.layers["uu"], + rpe1_kinetics.layers["ul"], + rpe1_kinetics.layers["su"], + rpe1_kinetics.layers["sl"], + ) rpe1_kinetics = rpe1_kinetics[:100, :100] dyn.pl.basic_stats(rpe1_kinetics, save_show_or_return="return") rpe1_genes = ["UNG", "PCNA", "PLK1", "HPRT1"] @@ -148,7 +160,6 @@ def test_calc_dispersion_sparse(): # assert np.all(np.isclose(sc_var, expected_var)) - def test_Preprocessor_monocle_recipe(): raw_zebra_adata = dyn.sample_data.zebrafish() adata = raw_zebra_adata[:1000, :1000].copy() @@ -207,17 +218,17 @@ def test_layers2csr_matrix(): data = np.array([[1, 2], [3, 4]]) adata = anndata.AnnData( X=data, - obs={'obs1': ['cell1', 'cell2']}, - var={'var1': ['gene1', 'gene2']}, + obs={"obs1": ["cell1", "cell2"]}, + var={"var1": ["gene1", "gene2"]}, ) layer = csr_matrix([[1, 2], [3, 4]]).transpose() # Transpose the matrix - adata.layers['layer1'] = layer + adata.layers["layer1"] = layer result = dyn.preprocessing.utils.convert_layers2csr(adata) - assert issparse(result.layers['layer1']) - assert result.layers['layer1'].shape == layer.shape - assert (result.layers['layer1'].toarray() == layer.toarray()).all() + assert issparse(result.layers["layer1"]) + assert result.layers["layer1"].shape == layer.shape + assert (result.layers["layer1"].toarray() == layer.toarray()).all() def test_compute_gene_exp_fraction(): @@ -307,10 +318,10 @@ def test_filter_genes_by_clusters_(): # Add cluster information clusters = np.random.randint(low=0, high=3, size=n_cells) - adata.obs['clusters'] = clusters + adata.obs["clusters"] = clusters # Filter genes by cluster - clu_avg_selected = dyn.pp.filter_genes_by_clusters(adata, 'clusters') + clu_avg_selected = dyn.pp.filter_genes_by_clusters(adata, "clusters") # Check that the output is a numpy array assert type(clu_avg_selected) == np.ndarray @@ -319,7 +330,7 @@ def test_filter_genes_by_clusters_(): assert clu_avg_selected.shape == (n_genes,) # Check that all genes with U and S average > min_avg_U and min_avg_S respectively are selected - U, S = adata.layers['unspliced'], adata.layers['spliced'] + U, S = adata.layers["unspliced"], adata.layers["spliced"] U_avgs = np.array([np.mean(U[clusters == i], axis=0) for i in range(3)]) S_avgs = np.array([np.mean(S[clusters == i], axis=0) for i in range(3)]) expected_clu_avg_selected = np.any((U_avgs.max(1) > 0.02) & (S_avgs.max(1) > 0.08), axis=0) @@ -348,12 +359,9 @@ def test_filter_genes_by_outliers(): # check that the original object is unchanged assert np.all(adata.var_names.values == ["gene1", "gene2", "gene3", "gene4"]) - dyn.pp.filter_genes_by_outliers(adata, - min_avg_exp_s=0.5, - min_cell_s=2, - max_avg_exp=2.5, - min_count_s=2, - inplace=True) + dyn.pp.filter_genes_by_outliers( + adata, min_avg_exp_s=0.5, min_cell_s=2, max_avg_exp=2.5, min_count_s=2, inplace=True + ) # check that the adata has been updated assert adata.shape == (6, 3) @@ -362,14 +370,12 @@ def test_filter_genes_by_outliers(): def test_filter_cells_by_outliers(): # Create a test AnnData object with some example data - adata = anndata.AnnData( - X=np.array([[1, 0, 3], [4 ,0 ,0], [7, 8, 9], [10, 11, 12]])) + adata = anndata.AnnData(X=np.array([[1, 0, 3], [4, 0, 0], [7, 8, 9], [10, 11, 12]])) adata.var_names = ["gene1", "gene2", "gene3"] adata.obs_names = ["cell1", "cell2", "cell3", "cell4"] # Test the function with custom range values - dyn.pp.filter_cells_by_outliers( - adata, min_expr_genes_s=2, max_expr_genes_s=6) + dyn.pp.filter_cells_by_outliers(adata, min_expr_genes_s=2, max_expr_genes_s=6) assert np.array_equal( adata.obs_names.values, @@ -385,8 +391,7 @@ def test_filter_cells_by_outliers(): def test_filter_genes_by_patterns(): - adata = anndata.AnnData( - X=np.array([[1, 0, 3], [4, 0, 0], [7, 8, 9], [10, 11, 12]])) + adata = anndata.AnnData(X=np.array([[1, 0, 3], [4, 0, 0], [7, 8, 9], [10, 11, 12]])) adata.var_names = ["MT-1", "RPS", "GATA1"] adata.obs_names = ["cell1", "cell2", "cell3", "cell4"] diff --git a/tests/test_tl.py b/tests/test_tl.py index 0f8022d99..ea3df3091 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -6,9 +6,9 @@ import dynamo as dyn from dynamo.tools.connectivity import ( - generate_neighbor_keys, check_and_recompute_neighbors, check_neighbors_completeness, + generate_neighbor_keys, ) @@ -22,8 +22,7 @@ def test_calc_1nd_moment(): result, normalized_W = dyn.tl.calc_1nd_moment(X, W, normalize_W=True) expected_result = np.array([[3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]) assert np.array_equal(result, expected_result) - assert np.array_equal(normalized_W, - np.array([[0.0, 1, 0.0], [0.5, 0.0, 0.5], [0.0, 1, 0.0]])) + assert np.array_equal(normalized_W, np.array([[0.0, 1, 0.0], [0.5, 0.0, 0.5], [0.0, 1, 0.0]])) def test_calc_2nd_moment(): @@ -35,7 +34,7 @@ def test_calc_2nd_moment(): assert np.array_equal(result, expected_result) result = dyn.tl.calc_2nd_moment(X, Y, W, normalize_W=True, center=False) - expected_result = np.array([[12., 20.], [16., 24.], [12., 20.]]) + expected_result = np.array([[12.0, 20.0], [16.0, 24.0], [12.0, 20.0]]) assert np.array_equal(result, expected_result) @@ -53,20 +52,20 @@ def test_cell_growth_rate(adata): @pytest.mark.skip(reason="umap compatability issue with numpy, pynndescent and pytest") def test_dynamics(): adata = dyn.sample_data.scNT_seq_neuron_labeling() - adata.obs['label_time'] = 2 # this is the labeling time + adata.obs["label_time"] = 2 # this is the labeling time adata = adata[:, adata.var.activity_genes] - adata.obs['time'] = adata.obs['time'] / 60 + adata.obs["time"] = adata.obs["time"] / 60 adata1 = adata.copy() preprocessor = dyn.pp.Preprocessor(cell_cycle_score_enable=True) - preprocessor.preprocess_adata(adata1, recipe='monocle', tkey='label_time', experiment_type='one-shot') + preprocessor.preprocess_adata(adata1, recipe="monocle", tkey="label_time", experiment_type="one-shot") dyn.tl.dynamics(adata1) assert "velocity_N" in adata.layers.keys() adata2 = adata.copy() preprocessor = dyn.pp.Preprocessor(cell_cycle_score_enable=True) - preprocessor.preprocess_adata(adata2, recipe='monocle', tkey='label_time', experiment_type='kin') + preprocessor.preprocess_adata(adata2, recipe="monocle", tkey="label_time", experiment_type="kin") dyn.tl.dynamics(adata2) assert "velocity_N" in adata.layers.keys() @@ -175,7 +174,9 @@ def test_DDRTree_pseudotime(adata): dyn.tl.pseudotime_velocity(adata, pseudotime="Pseudotime") assert "velocity_S" in adata.layers.keys() - ax = dyn.pl.plot_dim_reduced_direct_graph(adata, graph=adata.uns["directed_velocity_tree"], save_show_or_return="return") + ax = dyn.pl.plot_dim_reduced_direct_graph( + adata, graph=adata.uns["directed_velocity_tree"], save_show_or_return="return" + ) assert isinstance(ax, list) @@ -302,6 +303,7 @@ def test_fp_operator(): def test_triangles(): import igraph as ig + g = ig.Graph(edges=[(0, 1), (1, 2), (2, 0), (2, 3), (3, 0)]) result = dyn.tools.graph_operators.triangles(g) @@ -319,7 +321,9 @@ def test_cell_and_gene_confidence(adata): dyn.tl.cell_wise_confidence(adata, method=method) assert method + "_velocity_confidence" in adata.obs.keys() - dyn.tl.confident_cell_velocities(adata, group="Cell_type", lineage_dict={'Proliferating Progenitor': ['Schwann Cell']}) + dyn.tl.confident_cell_velocities( + adata, group="Cell_type", lineage_dict={"Proliferating Progenitor": ["Schwann Cell"]} + ) assert "gene_wise_confidence" in adata.uns.keys() @@ -378,8 +382,8 @@ def test_broken_neighbors_check_recompute(): # Test for utils def smallest_distance_bf(coords): res = float("inf") - for (i, c1) in enumerate(coords): - for (j, c2) in enumerate(coords): + for i, c1 in enumerate(coords): + for j, c2 in enumerate(coords): if i == j: continue else: @@ -422,9 +426,10 @@ def test_norm_loglikelihood(): def test_fit_linreg(): - from dynamo.estimation.csc.utils_velocity import fit_linreg, fit_linreg_robust from sklearn.datasets import make_regression + from dynamo.estimation.csc.utils_velocity import fit_linreg, fit_linreg_robust + X0, y0 = make_regression(n_samples=100, n_features=1, noise=0.5, random_state=0) X1, y1 = make_regression(n_samples=100, n_features=1, noise=0.5, random_state=2) X = np.vstack([X0.T, X1.T]) diff --git a/tests/test_vf.py b/tests/test_vf.py index a72418b86..1f2163d09 100644 --- a/tests/test_vf.py +++ b/tests/test_vf.py @@ -13,7 +13,7 @@ def test_vector_calculus_rank_vf(adata): adata = adata.copy() - progenitor = adata.obs_names[adata.obs.Cell_type.isin(['Proliferating Progenitor', 'Pigment Progenitor'])] + progenitor = adata.obs_names[adata.obs.Cell_type.isin(["Proliferating Progenitor", "Pigment Progenitor"])] dyn.vf.velocities(adata, basis="umap", init_cells=progenitor) assert "velocities_umap" in adata.uns.keys() @@ -23,17 +23,27 @@ def test_vector_calculus_rank_vf(adata): ax = dyn.pl.speed(adata, basis="umap", save_show_or_return="return") assert isinstance(ax, plt.Axes) - dyn.vf.jacobian(adata, basis="umap", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=progenitor) + dyn.vf.jacobian( + adata, basis="umap", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=progenitor + ) assert "jacobian_umap" in adata.uns.keys() ax = dyn.pl.jacobian(adata, basis="umap", j_basis="umap", save_show_or_return="return") assert isinstance(ax, matplotlib.gridspec.GridSpec) ax = dyn.pl.jacobian_heatmap( - adata, basis="umap", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=[2, 3], save_show_or_return="return") + adata, + basis="umap", + regulators=["ptmaa", "rpl5b"], + effectors=["ptmaa", "rpl5b"], + cell_idx=[2, 3], + save_show_or_return="return", + ) assert isinstance(ax, matplotlib.gridspec.GridSpec) - dyn.vf.hessian(adata, basis="umap", regulators=["rpl5b"], coregulators=["ptmaa"], effector=["ptmaa"], cell_idx=progenitor) + dyn.vf.hessian( + adata, basis="umap", regulators=["rpl5b"], coregulators=["ptmaa"], effector=["ptmaa"], cell_idx=progenitor + ) assert "hessian_umap" in adata.uns.keys() dyn.vf.laplacian(adata, hkey="hessian_umap", basis="umap") @@ -46,7 +56,8 @@ def test_vector_calculus_rank_vf(adata): assert isinstance(ax, matplotlib.gridspec.GridSpec) ax = dyn.pl.sensitivity_heatmap( - adata, basis="umap", regulators=["rpl5b"], effectors=["ptmaa"], cell_idx=[2, 3], save_show_or_return="return") + adata, basis="umap", regulators=["rpl5b"], effectors=["ptmaa"], cell_idx=[2, 3], save_show_or_return="return" + ) assert isinstance(ax, matplotlib.gridspec.GridSpec) dyn.vf.acceleration(adata, basis="pca") @@ -146,14 +157,25 @@ def test_networks(adata): adata = adata.copy() - dyn.vf.jacobian(adata, basis="pca", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=np.arange(adata.n_obs)) + dyn.vf.jacobian( + adata, basis="pca", regulators=["ptmaa", "rpl5b"], effectors=["ptmaa", "rpl5b"], cell_idx=np.arange(adata.n_obs) + ) edges_list = dyn.vf.build_network_per_cluster(adata, cluster="Cell_type", genes=adata.var_names) assert isinstance(edges_list, dict) - network = nx.from_pandas_edgelist(edges_list['Unknown'], 'regulator', 'target', edge_attr='weight', - create_using=nx.DiGraph()) - ax = dyn.pl.arcPlot(adata, cluster="Cell_type", cluster_name="Unknown", edges_list=None, network=network, color="M_s", save_show_or_return="return") + network = nx.from_pandas_edgelist( + edges_list["Unknown"], "regulator", "target", edge_attr="weight", create_using=nx.DiGraph() + ) + ax = dyn.pl.arcPlot( + adata, + cluster="Cell_type", + cluster_name="Unknown", + edges_list=None, + network=network, + color="M_s", + save_show_or_return="return", + ) assert isinstance(ax, dyn.plot.networks.ArcPlot) res = dyn.vf.adj_list_to_matrix(adj_list=edges_list["Neuron"]) From 4025eea740fc96a13bffa7d9642e09d206204dbb Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 20:17:39 +0000 Subject: [PATCH 03/27] reformat with isort==5.12.0 black==22.6.0 --- dynamo/estimation/csc/velocity.py | 58 ++++---------------------- dynamo/external/hodge.py | 6 +-- dynamo/external/scribe.py | 6 ++- dynamo/plot/dynamics.py | 19 +-------- dynamo/plot/utils.py | 21 ++++++---- dynamo/preprocessing/gene_selection.py | 6 +-- dynamo/preprocessing/utils.py | 4 +- dynamo/simulation/ODE.py | 4 +- dynamo/tools/cell_velocities.py | 8 +--- dynamo/tools/deprecated.py | 5 +-- dynamo/tools/dynamics.py | 18 +------- dynamo/tools/markers.py | 5 +-- dynamo/tools/moments.py | 14 +------ dynamo/tools/pseudotime.py | 9 ++-- dynamo/tools/utils.py | 11 +---- dynamo/vectorfield/scPotential.py | 24 +---------- dynamo/vectorfield/scVectorField.py | 7 +--- 17 files changed, 53 insertions(+), 172 deletions(-) diff --git a/dynamo/estimation/csc/velocity.py b/dynamo/estimation/csc/velocity.py index 3130faf3c..30769184c 100755 --- a/dynamo/estimation/csc/velocity.py +++ b/dynamo/estimation/csc/velocity.py @@ -662,14 +662,7 @@ def fit( bs, bf, ) = zip(*res) - ( - gamma, - gamma_intercept, - gamma_r2, - gamma_logLL, - bs, - bf, - ) = ( + (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( np.array(gamma), np.array(gamma_intercept), np.array(gamma_r2), @@ -755,14 +748,7 @@ def fit( bs, bf, ) = zip(*res) - ( - gamma, - gamma_intercept, - gamma_r2, - gamma_logLL, - bs, - bf, - ) = ( + (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( np.array(gamma), np.array(gamma_intercept), np.array(gamma_r2), @@ -850,11 +836,7 @@ def fit( uu_m, uu_v, _ = calc_12_mom_labeling(self.data["uu"], self.t) if cores == 1: for i in tqdm(range(n_genes), desc="estimating alpha"): - ( - alpha[i], - alpha_b[i], - alpha_r2[i], - ) = fit_alpha_degradation( + (alpha[i], alpha_b[i], alpha_r2[i],) = fit_alpha_degradation( t_uniq, uu_m[i], self.parameters["gamma"][i], @@ -1046,10 +1028,7 @@ def fit( total, ), ) - ( - self.aux_param["total0"], - self.parameters["gamma"], - ) = ( + (self.aux_param["total0"], self.parameters["gamma"],) = ( total0, gamma, ) @@ -1086,14 +1065,7 @@ def fit( if issparse(self.data["ul"]) else np.zeros_like(self.data["ul"].shape) ) - ( - t_uniq, - gamma, - gamma_k, - gamma_intercept, - gamma_r2, - gamma_logLL, - ) = ( + (t_uniq, gamma, gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( np.unique(self.t), np.zeros(n_genes), np.zeros(n_genes), @@ -1141,12 +1113,7 @@ def fit( _, gamma_logLL, ) = zip(*res1) - ( - gamma_k, - gamma_intercept, - gamma_r2, - gamma_logLL, - ) = ( + (gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( np.array(gamma_k), np.array(gamma_intercept), np.array(gamma_r2), @@ -1500,11 +1467,7 @@ def fit( # gamma_3 = solve_gamma(np.max(self.t), self.data['uu'][i, self.t == np.max(self.t)], tmp) # sci-fate gamma[i] = gamma_2 # print('Steady state, stimulation, sci-fate like gamma values are ', gamma_1, '; ', gamma_2, '; ', gamma_3) - ( - self.parameters["gamma"], - self.aux_param["U0"], - self.parameters["beta"], - ) = ( + (self.parameters["gamma"], self.aux_param["U0"], self.parameters["beta"],) = ( gamma, U, np.ones(gamma.shape), @@ -1521,12 +1484,7 @@ def fit( if self.asspt_prot.lower() == "ss" and n_genes > 0: self.parameters["eta"] = np.ones(n_genes) - ( - delta, - delta_intercept, - delta_r2, - delta_logLL, - ) = ( + (delta, delta_intercept, delta_r2, delta_logLL,) = ( np.zeros(n_genes), np.zeros(n_genes), np.zeros(n_genes), diff --git a/dynamo/external/hodge.py b/dynamo/external/hodge.py index 4eb359fbd..8864434bf 100644 --- a/dynamo/external/hodge.py +++ b/dynamo/external/hodge.py @@ -218,11 +218,7 @@ def func(x): W.dot(ddhodge_div), W.dot(potential_), ) - ( - adata.obs[prefix + "ddhodge_sampled"], - adata.obs[prefix + "ddhodge_div"], - adata.obs[prefix + "potential"], - ) = ( + (adata.obs[prefix + "ddhodge_sampled"], adata.obs[prefix + "ddhodge_div"], adata.obs[prefix + "potential"],) = ( False, 0, 0, diff --git a/dynamo/external/scribe.py b/dynamo/external/scribe.py index 41831f98a..587140d2f 100644 --- a/dynamo/external/scribe.py +++ b/dynamo/external/scribe.py @@ -86,7 +86,11 @@ def scribe( str_format = ( "upper" if adata.var_names[0].isupper() - else "lower" if adata.var_names[0].islower() else "title" if adata.var_names[0].istitle() else "other" + else "lower" + if adata.var_names[0].islower() + else "title" + if adata.var_names[0].istitle() + else "other" ) motifAnnotations_hgnc = pd.read_csv(motif_ref, sep="\t") diff --git a/dynamo/plot/dynamics.py b/dynamo/plot/dynamics.py index 6f10babf5..3824b213d 100755 --- a/dynamo/plot/dynamics.py +++ b/dynamo/plot/dynamics.py @@ -1629,14 +1629,7 @@ def dynamics( mom.integrate(t) mom_data = mom.get_all_central_moments() if has_splicing else mom.get_nosplice_central_moments() if true_param_prefix is not None: - ( - true_a, - true_b, - true_alpha_a, - true_alpha_i, - true_beta, - true_gamma, - ) = ( + (true_a, true_b, true_alpha_a, true_alpha_i, true_beta, true_gamma,) = ( ( vel_params_df.loc[gene_name, true_param_prefix + "a"] if true_param_prefix + "a" in vel_params_df.columns @@ -1883,15 +1876,7 @@ def dynamics( np.log1p(sl), ) - ( - alpha, - beta, - gamma, - ul0, - sl0, - uu0, - half_life, - ) = vel_params_df.loc[ + (alpha, beta, gamma, ul0, sl0, uu0, half_life,) = vel_params_df.loc[ gene_name, [ prefix + "alpha", diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index d4caedc09..d60bd1c35 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -246,7 +246,9 @@ def calculate_colors( else ( np.nanpercentile(values, vmin * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmin) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmin + else np.nanpercentile(values, vmin) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmin ) ) _vmax = ( @@ -255,7 +257,9 @@ def calculate_colors( else ( np.nanpercentile(values, vmax * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmax) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmax + else np.nanpercentile(values, vmax) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmax ) ) @@ -663,7 +667,9 @@ def _matplotlib_points( else ( np.nanpercentile(values, vmin * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmin) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmin + else np.nanpercentile(values, vmin) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmin ) ) _vmax = ( @@ -672,7 +678,9 @@ def _matplotlib_points( else ( np.nanpercentile(values, vmax * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) - else np.nanpercentile(values, vmax) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmax + else np.nanpercentile(values, vmax) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmax ) ) @@ -934,10 +942,7 @@ def _datashade_points( data["label"] == "other", ) reorder_data = data.copy(deep=True) - ( - reorder_data.iloc[: sum(background_ids), :], - reorder_data.iloc[sum(background_ids) :, :], - ) = ( + (reorder_data.iloc[: sum(background_ids), :], reorder_data.iloc[sum(background_ids) :, :],) = ( data.loc[background_ids, :], data.loc[highlight_ids, :], ) diff --git a/dynamo/preprocessing/gene_selection.py b/dynamo/preprocessing/gene_selection.py index 12fbe1db1..3016e725c 100644 --- a/dynamo/preprocessing/gene_selection.py +++ b/dynamo/preprocessing/gene_selection.py @@ -247,11 +247,7 @@ def calc_dispersion_by_svr( adata_ori.uns[key] = {"mean": mean, "cv": cv, "svr_gamma": svr_gamma} prefix = "" if layer == "X" else layer + "_" - ( - adata.var[prefix + "log_m"], - adata.var[prefix + "log_cv"], - adata.var[prefix + "score"], - ) = ( + (adata.var[prefix + "log_m"], adata.var[prefix + "log_cv"], adata.var[prefix + "score"],) = ( np.nan, np.nan, -np.inf, diff --git a/dynamo/preprocessing/utils.py b/dynamo/preprocessing/utils.py index 789811306..7c428b65e 100755 --- a/dynamo/preprocessing/utils.py +++ b/dynamo/preprocessing/utils.py @@ -703,7 +703,9 @@ def default_layer(adata: anndata.AnnData) -> str: default_layer = ( "M_s" if "M_s" in adata.layers.keys() - else "X_spliced" if "X_spliced" in adata.layers.keys() else "spliced" + else "X_spliced" + if "X_spliced" in adata.layers.keys() + else "spliced" ) else: default_layer = ( diff --git a/dynamo/simulation/ODE.py b/dynamo/simulation/ODE.py index 7bb1c8388..b286051ee 100755 --- a/dynamo/simulation/ODE.py +++ b/dynamo/simulation/ODE.py @@ -495,7 +495,9 @@ def neurongenesis( dx[:, 5] = a * (x[:, 0] ** n) / (1 + x[:, 0] ** n + x[:, 1] ** n) - k * x[:, 5] dx[:, 6] = a_e * (eta**n * x[:, 5] ** n) / (1 + eta**n * x[:, 5] ** n + x[:, 7] ** n) - k * x[:, 6] dx[:, 7] = a_e * (eta**n * x[:, 5] ** n) / (1 + x[:, 6] ** n + eta**n * x[:, 5] ** n) - k * x[:, 7] - dx[:, 8] = a * (eta**n * x[:, 5] ** n * x[:, 6] ** n) / (1 + eta**n * x[:, 5] ** n * x[:, 6] ** n) - k * x[:, 8] + dx[:, 8] = ( + a * (eta**n * x[:, 5] ** n * x[:, 6] ** n) / (1 + eta**n * x[:, 5] ** n * x[:, 6] ** n) - k * x[:, 8] + ) dx[:, 9] = a * (x[:, 7] ** n) / (1 + x[:, 7] ** n) - k * x[:, 9] dx[:, 10] = a_e * (x[:, 8] ** n) / (1 + x[:, 8] ** n) - k * x[:, 10] dx[:, 11] = a * (eta_m**n * x[:, 7] ** n) / (1 + eta_m**n * x[:, 7] ** n) - k * x[:, 11] diff --git a/dynamo/tools/cell_velocities.py b/dynamo/tools/cell_velocities.py index 38482dbeb..baff9cb89 100755 --- a/dynamo/tools/cell_velocities.py +++ b/dynamo/tools/cell_velocities.py @@ -456,13 +456,7 @@ def cell_velocities( if calc_rnd_vel: permute_rows_nsign(V) - ( - T_rnd, - delta_X_rnd, - X_grid_rnd, - V_grid_rnd, - D_rnd, - ) = kernels_from_velocyto_scvelo( + (T_rnd, delta_X_rnd, X_grid_rnd, V_grid_rnd, D_rnd,) = kernels_from_velocyto_scvelo( X, X_embedding, V, diff --git a/dynamo/tools/deprecated.py b/dynamo/tools/deprecated.py index 643858b83..3fef4f922 100644 --- a/dynamo/tools/deprecated.py +++ b/dynamo/tools/deprecated.py @@ -1837,10 +1837,7 @@ def moment_model(adata, subset_adata, _group, cur_grp, log_unnormalized, tkey): else: if log_unnormalized and "X_total" not in subset_adata.layers.keys(): if issparse(subset_adata.layers["total"]): - ( - subset_adata.layers["new"].data, - subset_adata.layers["total"].data, - ) = ( + (subset_adata.layers["new"].data, subset_adata.layers["total"].data,) = ( np.log1p(subset_adata.layers["new"].data), np.log1p(subset_adata.layers["total"].data), ) diff --git a/dynamo/tools/dynamics.py b/dynamo/tools/dynamics.py index 39d505a1b..490b574b0 100755 --- a/dynamo/tools/dynamics.py +++ b/dynamo/tools/dynamics.py @@ -294,13 +294,7 @@ def dynamics( raise ValueError(f"\nPlease run `dyn.pp.receipe_monocle(adata)` before running this function!") if tkey is None: tkey = adata.uns["pp"]["tkey"] - ( - experiment_type, - has_splicing, - has_labeling, - splicing_labeling, - has_protein, - ) = ( + (experiment_type, has_splicing, has_labeling, splicing_labeling, has_protein,) = ( adata.uns["pp"]["experiment_type"], adata.uns["pp"]["has_splicing"], adata.uns["pp"]["has_labeling"], @@ -733,15 +727,7 @@ def dynamics( est_method = "direct" data_type = "smoothed" if use_smoothed else "sfs" - ( - params, - half_life, - cost, - logLL, - param_ranges, - cur_X_data, - cur_X_fit_data, - ) = kinetic_model( + (params, half_life, cost, logLL, param_ranges, cur_X_data, cur_X_fit_data,) = kinetic_model( subset_adata, tkey, model, diff --git a/dynamo/tools/markers.py b/dynamo/tools/markers.py index df7c9c11b..82c783fb9 100755 --- a/dynamo/tools/markers.py +++ b/dynamo/tools/markers.py @@ -729,10 +729,7 @@ def diff_test_helper( data: pd.DataFrame, fullModelFormulaStr: str = "~cr(time, df=3)", reducedModelFormulaStr: str = "~1", -) -> Union[ - Tuple[Literal["fail"], Literal["NB2"], Literal[1]], - Tuple[Literal["ok"], Literal["NB2"], np.ndarray], -]: +) -> Union[Tuple[Literal["fail"], Literal["NB2"], Literal[1]], Tuple[Literal["ok"], Literal["NB2"], np.ndarray],]: """A helper function to generate required data fields for differential gene expression test. Args: diff --git a/dynamo/tools/moments.py b/dynamo/tools/moments.py index de18fb51f..83079fd6c 100755 --- a/dynamo/tools/moments.py +++ b/dynamo/tools/moments.py @@ -126,12 +126,7 @@ def moments( with warnings.catch_warnings(): warnings.simplefilter("ignore") if group is None: - ( - kNN, - knn_indices, - knn_dists, - _, - ) = umap_conn_indices_dist_embedding( + (kNN, knn_indices, knn_dists, _,) = umap_conn_indices_dist_embedding( X, n_neighbors=np.min((n_neighbors, adata.n_obs - 1)), return_mapper=False, @@ -151,12 +146,7 @@ def moments( for cur_grp in uniq_grp: cur_cells = cells_group == cur_grp cur_X = X[cur_cells, :] - ( - cur_kNN, - cur_knn_indices, - cur_knn_dists, - _, - ) = umap_conn_indices_dist_embedding( + (cur_kNN, cur_knn_indices, cur_knn_dists, _,) = umap_conn_indices_dist_embedding( cur_X, n_neighbors=np.min((n_neighbors, sum(cur_cells) - 1)), return_mapper=False, diff --git a/dynamo/tools/pseudotime.py b/dynamo/tools/pseudotime.py index 3ca3e2bb5..0d3e87a7e 100755 --- a/dynamo/tools/pseudotime.py +++ b/dynamo/tools/pseudotime.py @@ -97,9 +97,12 @@ def order_cells( root_cell = select_root_cell(adata, Z=Z, root_state=root_state, init_cells=init_cells, reverse=reverse) cc_ordering = get_order_from_DDRTree(dp=dp, mst=mst, root_cell=root_cell) - (cellPairwiseDistances, pr_graph_cell_proj_dist, pr_graph_cell_proj_closest_vertex, pr_graph_cell_proj_tree) = ( - project2MST(mst, Z, Y, project_point_to_line_segment) - ) + ( + cellPairwiseDistances, + pr_graph_cell_proj_dist, + pr_graph_cell_proj_closest_vertex, + pr_graph_cell_proj_tree, + ) = project2MST(mst, Z, Y, project_point_to_line_segment) adata.uns["cell_order"]["root_cell"] = root_cell adata.uns["cell_order"]["centers_order"] = cc_ordering["orders"].values diff --git a/dynamo/tools/utils.py b/dynamo/tools/utils.py index 6d717bdc9..9abae007a 100755 --- a/dynamo/tools/utils.py +++ b/dynamo/tools/utils.py @@ -1306,16 +1306,7 @@ def get_data_for_kin_params_estimation( ) NTR_vel = True - ( - U, - Ul, - S, - Sl, - P, - US, - U2, - S2, - ) = ( + (U, Ul, S, Sl, P, US, U2, S2,) = ( None, None, None, diff --git a/dynamo/vectorfield/scPotential.py b/dynamo/vectorfield/scPotential.py index d9cee1bee..a51c3f59d 100755 --- a/dynamo/vectorfield/scPotential.py +++ b/dynamo/vectorfield/scPotential.py @@ -183,19 +183,7 @@ def fit( "A": A, } elif method == "Bhattacharya": - ( - _, - _, - _, - _, - numPaths, - numTimeSteps, - pot_path, - path_tag, - attractors_pot, - x_path, - y_path, - ) = path_integral( + (_, _, _, _, numPaths, numTimeSteps, pot_path, path_tag, attractors_pot, x_path, y_path,) = path_integral( self.VecFld["Function"], x_lim=x_lim, y_lim=y_lim, @@ -228,15 +216,7 @@ def fit( self.VecFld["Function"], self.VecFld["DiffusionMatrix"], ) - ( - boundary, - n_points, - fixed_point_only, - find_fixed_points, - refpoint, - stable, - saddle, - ) = ( + (boundary, n_points, fixed_point_only, find_fixed_points, refpoint, stable, saddle,) = ( self.parameters["boundary"], self.parameters["n_points"], self.parameters["fixed_point_only"], diff --git a/dynamo/vectorfield/scVectorField.py b/dynamo/vectorfield/scVectorField.py index 1e42baee4..8d185e796 100755 --- a/dynamo/vectorfield/scVectorField.py +++ b/dynamo/vectorfield/scVectorField.py @@ -1012,12 +1012,7 @@ def train(self, **kwargs) -> VecFldDict: if self.normalize: X_norm, V_norm, T_norm, norm_dict = norm(self.data["X"], self.data["V"], self.data["Grid"]) - ( - self.data["X"], - self.data["V"], - self.data["Grid"], - self.norm_dict, - ) = ( + (self.data["X"], self.data["V"], self.data["Grid"], self.norm_dict,) = ( X_norm, V_norm, T_norm, From 97ed400cad157a6c8a6d2f375a7036f890c22ffb Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 20:19:18 +0000 Subject: [PATCH 04/27] update requirements to accomodate python3.8 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index de7c56e53..34f2ce896 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ scipy>=1.4.1 scikit-learn>=0.19.1,<1.5.0 anndata>=0.8.0 loompy>=3.0.5 -matplotlib>=3.9.0 +matplotlib>=3.7.5 setuptools numdifftools>=0.9.39 umap-learn>=0.5.1 From 674f01305a8a01ba4644f084b5234f3faefedae5 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 21:23:37 +0000 Subject: [PATCH 05/27] chore: fix the circular import issue --- dynamo/preprocessing/cell_cycle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dynamo/preprocessing/cell_cycle.py b/dynamo/preprocessing/cell_cycle.py index 9f951b400..c23505092 100644 --- a/dynamo/preprocessing/cell_cycle.py +++ b/dynamo/preprocessing/cell_cycle.py @@ -9,7 +9,7 @@ import pandas as pd from scipy.sparse import issparse -from ..tools.utils import einsum_correlation, log1p_ +from ..tools.utils import log1p_ from ..utils import LoggerManager, copy_adata @@ -30,7 +30,7 @@ def group_corr(adata: anndata.AnnData, layer: Union[str, None], gene_list: list) between input gene_list and the adata.var_names. corr contains the correlation coefficient of each gene with the mean expression of all genes in the list. """ - + from ..tools.utils import einsum_correlation # returns list of correlations of each gene within a list of genes with the total expression of the group tmp = adata.var_names.intersection(gene_list) # get the location of gene names From bd9d7e76921a6461b9612ee4f2f3f42b7ed8cd65 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 21:26:42 +0000 Subject: [PATCH 06/27] chore: Import log1p_ function in cell_cycle.py --- dynamo/preprocessing/cell_cycle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dynamo/preprocessing/cell_cycle.py b/dynamo/preprocessing/cell_cycle.py index c23505092..79ff49e99 100644 --- a/dynamo/preprocessing/cell_cycle.py +++ b/dynamo/preprocessing/cell_cycle.py @@ -9,7 +9,6 @@ import pandas as pd from scipy.sparse import issparse -from ..tools.utils import log1p_ from ..utils import LoggerManager, copy_adata @@ -30,7 +29,7 @@ def group_corr(adata: anndata.AnnData, layer: Union[str, None], gene_list: list) between input gene_list and the adata.var_names. corr contains the correlation coefficient of each gene with the mean expression of all genes in the list. """ - from ..tools.utils import einsum_correlation + from ..tools.utils import einsum_correlation, log1p_ # returns list of correlations of each gene within a list of genes with the total expression of the group tmp = adata.var_names.intersection(gene_list) # get the location of gene names From bd82f3b979c410ad62de4ef8b46d0cfd581280de Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Thu, 4 Jul 2024 05:29:16 +0800 Subject: [PATCH 07/27] Update python-package.yml --- .github/workflows/python-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 475d15062..4bb3e0bcb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.8, 3.9] steps: - uses: actions/checkout@v2 @@ -40,4 +40,4 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics # - name: Test with pytest # run: | -# pytest \ No newline at end of file +# pytest From b3acc5d212fc7ebd11438870e1729d436c949fb2 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 21:32:15 +0000 Subject: [PATCH 08/27] fix log1p_ importing --- dynamo/preprocessing/cell_cycle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dynamo/preprocessing/cell_cycle.py b/dynamo/preprocessing/cell_cycle.py index 79ff49e99..02eadf339 100644 --- a/dynamo/preprocessing/cell_cycle.py +++ b/dynamo/preprocessing/cell_cycle.py @@ -105,7 +105,7 @@ def group_score(adata: anndata.AnnData, layer: Union[str, None], gene_list: List Returns: The Z-scored expression data. """ - + from ..tools.utils import log1p_ tmp = adata.var_names.intersection(gene_list) # use indices intersect_genes = [adata.var_names.get_loc(i) for i in tmp] From 52f9542751b66f712ff61c47774c57ba4dc98e33 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 21:48:14 +0000 Subject: [PATCH 09/27] chore: Import 'update_dict' inside function to fix circular import issue --- dynamo/preprocessing/cell_cycle.py | 2 ++ dynamo/preprocessing/deprecated.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dynamo/preprocessing/cell_cycle.py b/dynamo/preprocessing/cell_cycle.py index 02eadf339..540f9dd52 100644 --- a/dynamo/preprocessing/cell_cycle.py +++ b/dynamo/preprocessing/cell_cycle.py @@ -30,6 +30,7 @@ def group_corr(adata: anndata.AnnData, layer: Union[str, None], gene_list: list) the mean expression of all genes in the list. """ from ..tools.utils import einsum_correlation, log1p_ + # returns list of correlations of each gene within a list of genes with the total expression of the group tmp = adata.var_names.intersection(gene_list) # get the location of gene names @@ -106,6 +107,7 @@ def group_score(adata: anndata.AnnData, layer: Union[str, None], gene_list: List The Z-scored expression data. """ from ..tools.utils import log1p_ + tmp = adata.var_names.intersection(gene_list) # use indices intersect_genes = [adata.var_names.get_loc(i) for i in tmp] diff --git a/dynamo/preprocessing/deprecated.py b/dynamo/preprocessing/deprecated.py index 122b39a1c..00cfbbd7f 100644 --- a/dynamo/preprocessing/deprecated.py +++ b/dynamo/preprocessing/deprecated.py @@ -26,7 +26,6 @@ main_info_insert_adata_obsm, main_warning, ) -from ..tools.utils import update_dict from ..utils import copy_adata from .cell_cycle import cell_cycle_scores from .gene_selection import calc_dispersion_by_svr @@ -1785,6 +1784,7 @@ def _select_genes_monocle_legacy( for downstream analysis. adata will be subsetted with only the genes pass filter if keep_unflitered is set to be False. """ + from ..tools.utils import update_dict filter_bool = ( adata.var["pass_basic_filter"] From 2fffc20f7d36c7b628327be2c80795c8ce3be944 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Wed, 3 Jul 2024 23:29:35 +0000 Subject: [PATCH 10/27] fixing more circular imports --- dynamo/prediction/utils.py | 4 +++- dynamo/simulation/evaluation.py | 4 +--- dynamo/vectorfield/stochastic_process.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dynamo/prediction/utils.py b/dynamo/prediction/utils.py index b1a8f0650..7064b3c34 100644 --- a/dynamo/prediction/utils.py +++ b/dynamo/prediction/utils.py @@ -8,7 +8,6 @@ from tqdm import tqdm from ..dynamo_logger import main_warning -from ..tools.utils import log1p_, nearest_neighbors from ..utils import isarray, normalize # import scipy.sparse as sp @@ -335,6 +334,8 @@ def estimate_sigma( Returns: The estimated diffusion matrix. """ + from ..tools.utils import nearest_neighbors + if nbr_idx is None: nbr_idx = nearest_neighbors(X, X, k=num_nbrs) @@ -450,6 +451,7 @@ def fetch_exprs( Returns: The expression data for the given genes and time points. """ + from ..tools.utils import log1p_ if type(genes) != list: genes = list(genes) diff --git a/dynamo/simulation/evaluation.py b/dynamo/simulation/evaluation.py index ca3f3d82d..042946504 100755 --- a/dynamo/simulation/evaluation.py +++ b/dynamo/simulation/evaluation.py @@ -1,8 +1,6 @@ import numpy as np from sklearn.metrics import mean_squared_error -from ..tools.utils import einsum_correlation - def evaluate(reference: np.ndarray, prediction: np.ndarray, metric: str = "cosine") -> float: """Function to evaluate the vector field related reference quantities vs. that from vector field prediction. @@ -17,7 +15,7 @@ def evaluate(reference: np.ndarray, prediction: np.ndarray, metric: str = "cosin Returns: res: The score between the reference vs. reconstructed quantities based on the metric. """ - + from ..tools.utils import einsum_correlation if metric == "cosine": true_normalized = reference / (np.linalg.norm(reference, axis=1).reshape(-1, 1) + 1e-20) predict_normalized = prediction / (np.linalg.norm(prediction, axis=1).reshape(-1, 1) + 1e-20) diff --git a/dynamo/vectorfield/stochastic_process.py b/dynamo/vectorfield/stochastic_process.py index 75f6c63ae..b3316db72 100644 --- a/dynamo/vectorfield/stochastic_process.py +++ b/dynamo/vectorfield/stochastic_process.py @@ -10,7 +10,7 @@ generate_neighbor_keys, k_nearest_neighbors, ) -from ..tools.utils import log1p_ + from .utils import VecFldDict, vecfld_from_adata, vector_field_function @@ -47,6 +47,8 @@ def diffusionMatrix( the diffusion matrix for each cell. A column `diffusion` corresponds to the square root of the sum of all elements for each cell's diffusion matrix will also be added. """ + + from ..tools.utils import log1p_ if X_data is None or V_data is not None: if genes is not None: From 84d956be2a30fe4bf98de6793d63d004283b808d Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Thu, 4 Jul 2024 00:38:16 +0000 Subject: [PATCH 11/27] Update node version to 22.x and 24.x in GitHub Actions workflow file --- .github/workflows/python-docker.yml | 2 +- dynamo/vectorfield/stochastic_process.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/python-docker.yml b/.github/workflows/python-docker.yml index 2da5531f0..9ed47dbab 100644 --- a/.github/workflows/python-docker.yml +++ b/.github/workflows/python-docker.yml @@ -29,7 +29,7 @@ jobs: strategy: matrix: os: [ubuntu-lastest, windows-2016] - node-version: [12.x, 14.x] + node-version: [22.x, 24.x] steps: - uses: actions/checkout@v1 diff --git a/dynamo/vectorfield/stochastic_process.py b/dynamo/vectorfield/stochastic_process.py index b3316db72..256bdb0ec 100644 --- a/dynamo/vectorfield/stochastic_process.py +++ b/dynamo/vectorfield/stochastic_process.py @@ -10,7 +10,6 @@ generate_neighbor_keys, k_nearest_neighbors, ) - from .utils import VecFldDict, vecfld_from_adata, vector_field_function From d2b8c534e52c1e5c3880af6032f520b9c49497f2 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Thu, 4 Jul 2024 00:39:00 +0000 Subject: [PATCH 12/27] change import order in init since dynamo relies on init order to work --- dynamo/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dynamo/__init__.py b/dynamo/__init__.py index d5735723c..870f8c30e 100755 --- a/dynamo/__init__.py +++ b/dynamo/__init__.py @@ -10,7 +10,19 @@ # # __version__ = get_dynamo_version() -from . import configuration, est, ext, mv, pd, pl, pp, sample_data, shiny, sim, tl, vf +from . import pp +from . import est +from . import tl +from . import vf +from . import pd +from . import pl +from . import mv +from . import shiny +from . import sim +from . import sample_data +from . import configuration +from . import ext + from .data_io import * from .dynamo_logger import ( Logger, From facac04c0887129d32fa39075b0f419a84db83b9 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 04:05:16 +0000 Subject: [PATCH 13/27] chore: Update resource data URLs from dropbox to figshare --- dynamo/external/scifate.py | 8 ++++---- dynamo/external/scribe.py | 8 ++++---- dynamo/preprocessing/utils.py | 1 + dynamo/sample_data.py | 29 +++++++++++++++++------------ 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/dynamo/external/scifate.py b/dynamo/external/scifate.py index 76eeaee61..1f45df13a 100644 --- a/dynamo/external/scifate.py +++ b/dynamo/external/scifate.py @@ -21,8 +21,8 @@ def scifate_glmnet( cell_filter_UMI: int = 10000, core_n_lasso: int = 1, core_n_filtering: int = 1, - motif_ref: str = "https://www.dropbox.com/s/s8em539ojl55kgf/motifAnnotations_hgnc.csv?dl=1", - TF_link_ENCODE_ref: str = "https://www.dropbox.com/s/bjuope41pte7mf4/df_gene_TF_link_ENCODE.csv?dl=1", + motif_ref: str = "https://figshare.com/ndownloader/files/47439455", + TF_link_ENCODE_ref: str = "https://figshare.com/ndownloader/files/47439458", nt_layers: list = ["X_new", "X_total"], ) -> AnnData: """Perform scifate analysis using glmnet. @@ -58,11 +58,11 @@ def scifate_glmnet( motif_ref: The path to the TF binding motif data as described above. It provides the list of TFs gene names and is used to process adata object to generate the TF expression and target new expression matrix for glmnet based TF-target synthesis rate linkage analysis. But currently it is not used for motif based filtering. - By default, it is a dropbox link that store the data from us. Other motif reference can bed downloaded from + By default, it is a cloud link that store the data from us. Other motif reference can bed downloaded from RcisTarget: https://resources.aertslab.org/cistarget/. For human motif matrix, it can be downloaded from June's shared folder: https://shendure-web.gs.washington.edu/content/members/cao1025/public/nobackup/sci_fate/data/hg19-tss-centered-10kb-7species.mc9nr.feather - TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a dropbox link from us that stores the + TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a cloud link from us that stores the data. Other data can be downloaded from: https://amp.pharm.mssm.edu/Harmonizome/dataset/ENCODE+Transcription+Factor+Targets. nt_layers: The layers that will be used for the network inference. Note that the layers can be changed flexibly. diff --git a/dynamo/external/scribe.py b/dynamo/external/scribe.py index 587140d2f..43675803e 100644 --- a/dynamo/external/scribe.py +++ b/dynamo/external/scribe.py @@ -19,12 +19,12 @@ def scribe( Targets: Union[list, None] = None, gene_filter_rate: float = 0.1, cell_filter_UMI: int = 10000, - motif_ref: str = "https://www.dropbox.com/s/s8em539ojl55kgf/motifAnnotations_hgnc.csv?dl=1", + motif_ref: str = "https://figshare.com/ndownloader/files/47439455", nt_layers: list = ["X_new", "X_total"], normalize: bool = False, do_CLR: bool = True, drop_zero_cells: bool = True, - TF_link_ENCODE_ref: str = "https://www.dropbox.com/s/bjuope41pte7mf4/df_gene_TF_link_ENCODE.csv?dl=1", + TF_link_ENCODE_ref: str = "https://figshare.com/ndownloader/files/47439458", ) -> AnnData: """Apply Scribe to calculate causal network from spliced/unspliced, metabolic labeling based and other "real" time series datasets. @@ -49,7 +49,7 @@ def scribe( cell_filter_UMI: Minimum number of UMIs for cell filtering. motif_ref: It provides the list of TFs gene names and is used to parse the data to get the list of TFs and Targets for the causal network inference from those TFs to Targets. But currently the motif based filtering - is not implemented. By default, it is a dropbox link that store the data from us. Other motif reference can + is not implemented. By default, it is a cloud link that store the data from us. Other motif reference can bed downloaded from RcisTarget: https://resources.aertslab.org/cistarget/. For human motif matrix, it can be downloaded from June's shared folder: https://shendure-web.gs.washington.edu/content/members/cao1025/public/nobackup/sci_fate/data/hg19-tss- @@ -64,7 +64,7 @@ def scribe( target. This can signify the relationship between potential regulators and targets, speed up the calculation, but at the risk of ignoring strong inhibition effects from certain regulators to targets. do_CLR: Whether to perform context likelihood relatedness analysis on the reconstructed causal network - TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a dropbox link from us that stores the + TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a cloud link from us that stores the data. Other data can be downloaded from: https://amp.pharm.mssm.edu/Harmonizome/dataset/ENCODE+Transcription+Factor+Targets. diff --git a/dynamo/preprocessing/utils.py b/dynamo/preprocessing/utils.py index 7c428b65e..766501383 100755 --- a/dynamo/preprocessing/utils.py +++ b/dynamo/preprocessing/utils.py @@ -840,6 +840,7 @@ def relative2abs( """ if ERCC_annotation is None: + #TODO: outdated link. consider replacing or removing it. ERCC_annotation = pd.read_csv( "https://www.dropbox.com/s/cmiuthdw5tt76o5/ERCC_specification.txt?dl=1", sep="\t", diff --git a/dynamo/sample_data.py b/dynamo/sample_data.py index b0c8fef1e..ef3ee1174 100755 --- a/dynamo/sample_data.py +++ b/dynamo/sample_data.py @@ -65,27 +65,32 @@ def get_adata(url: str, filename: Optional[str] = None) -> Optional[AnnData]: # add our toy sample data def Gillespie(): + #TODO: add data here pass def HL60(): + #TODO: add data here pass def NASCseq(): + #TODO: add data here pass def scSLAMseq(): + #TODO: add data here pass def scifate(): + #TODO: add data here pass def scNT_seq_neuron_splicing( - url: str = "https://www.dropbox.com/s/g1afqdcsczgyj2m/neuron_splicing_4_11.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439605", filename: str = "neuron_splicing.h5ad", ) -> AnnData: """The neuron splicing data is from Qiu, et al (2020). @@ -98,7 +103,7 @@ def scNT_seq_neuron_splicing( def scNT_seq_neuron_labeling( - url: str = "https://www.dropbox.com/s/lk9cl63yd28mfuq/neuron_labeling.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439629", filename: str = "neuron_labeling.h5ad", ) -> AnnData: """The neuron splicing data is from Qiu, et al (2020). @@ -115,7 +120,7 @@ def cite_seq(): def zebrafish( - url: str = "https://www.dropbox.com/scl/fi/3zt89ee0j5twxk4ttzmij/zebrafish.h5ad?rlkey=phwg0b7aqiizd9kf69l2kciak&dl=1", + url: str = "https://figshare.com/ndownloader/files/47420257", filename: str = "zebrafish.h5ad", ) -> AnnData: """The zebrafish is from Saunders, et al (2019). @@ -180,7 +185,7 @@ def hgForebrainGlutamatergic( def chromaffin( - url: str = "https://www.dropbox.com/s/awevuz836tlclvw/onefilepercell_A1_unique_and_others_J2CH1.loom?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439620", filename: str = "onefilepercell_A1_unique_and_others_J2CH1.loom", ) -> AnnData: # """The chromaffin dataset used in http://pklab.med.harvard.edu/velocyto/notebooks/R/chromaffin2.nb.html @@ -224,7 +229,7 @@ def pancreatic_endocrinogenesis( def DentateGyrus_scvelo( - url: str = "https://www.dropbox.com/s/3w1wzb0b68fhdsw/dentategyrus_scv.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439623", filename: str = "dentategyrus_scv.h5ad", ) -> AnnData: """The Dentate Gyrus dataset used in https://github.com/theislab/scvelo_notebooks/tree/master/data/DentateGyrus. @@ -238,10 +243,10 @@ def DentateGyrus_scvelo( def scEU_seq_rpe1( - url: str = "https://www.dropbox.com/s/25enev458c8egn7/rpe1.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439641", filename: str = "rpe1.h5ad", ): - """Download rpe1 dataset from Battich, et al (2020) via Dropbox link. + """Download rpe1 dataset from Battich, et al (2020) via a cloud link. This data consists of 13,913 genes across 2,930 cells. """ @@ -251,10 +256,10 @@ def scEU_seq_rpe1( def scEU_seq_organoid( - url: str = "https://www.dropbox.com/s/es7sroy5ceb7wwz/organoid.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439632", filename: str = "organoid.h5ad", ): - """Download organoid dataset from Battich, et al (2020) via Dropbox link. + """Download organoid dataset from Battich, et al (2020) via a cloud link. This data consists of 9,157 genes across 3,831 cells. """ @@ -264,7 +269,7 @@ def scEU_seq_organoid( def hematopoiesis( - url: str = "https://www.dropbox.com/s/n9mx9trv1h78q0r/hematopoiesis_v1.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439635", # url: str = "https://pitt.box.com/shared/static/kyh3s4wrxdywupn9wk9r2j27vzlvk8vf.h5ad", # with box # url: str = "https://pitt.box.com/shared/static/efqa8icu1m6d1ghfcc3s9tj0j91pky1h.h5ad", # v0: umap_ori version filename: str = "hematopoiesis.h5ad", @@ -276,7 +281,7 @@ def hematopoiesis( def hematopoiesis_raw( - url: str = "https://www.dropbox.com/s/rvkxvq8694xnxz3/hsc_raw_with_metadata.h5ad?dl=1", + url: str = "https://figshare.com/ndownloader/files/47439626", # url: str = "https://pitt.box.com/shared/static/bv7q0kgxjncc5uoget5wvmi700xwntje.h5ad", # with box filename: str = "hematopoiesis_raw.h5ad", ) -> AnnData: @@ -287,7 +292,7 @@ def hematopoiesis_raw( def human_tfs( - url: str = "https://www.dropbox.com/scl/fi/pyocgrhvglg6p7q8yf9ol/human_tfs.txt?rlkey=kbc8vfzf72f8ez0xldrb5nb2d&dl=1", + url: str = "https://figshare.com/ndownloader/files/47439617", filename: str = "human_tfs.txt", ) -> pd.DataFrame: """Download human transcription factors.""" From 6b8840c5cb14042bb2d57d4519451d2da9ba7eb4 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 04:25:29 +0000 Subject: [PATCH 14/27] chore: update register_cmap to register --- dynamo/plot/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index d60bd1c35..8128f90e8 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -212,7 +212,7 @@ def calculate_colors( with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True) + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, override_builtin=True) if values.shape[0] != points.shape[0]: raise ValueError( @@ -633,7 +633,7 @@ def _matplotlib_points( with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True) + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, override_builtin=True) if values.shape[0] != points.shape[0]: raise ValueError( From e31a5f4d7a8ae9419c1396517da0f086789383c2 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 05:25:39 +0000 Subject: [PATCH 15/27] chore: Update scipy.sparse import in normalization and transform modules --- dynamo/preprocessing/normalization.py | 2 +- dynamo/preprocessing/transform.py | 2 +- tests/test_pl.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dynamo/preprocessing/normalization.py b/dynamo/preprocessing/normalization.py index f6d9d8e7f..ad2dc9b8a 100644 --- a/dynamo/preprocessing/normalization.py +++ b/dynamo/preprocessing/normalization.py @@ -11,7 +11,7 @@ import numpy.typing as npt import pandas as pd from scipy.sparse import csr_matrix -from scipy.sparse.base import issparse +from scipy.sparse import issparse from ..configuration import DKM from ..dynamo_logger import ( diff --git a/dynamo/preprocessing/transform.py b/dynamo/preprocessing/transform.py index 524abf6d4..72a4c72d9 100644 --- a/dynamo/preprocessing/transform.py +++ b/dynamo/preprocessing/transform.py @@ -9,7 +9,7 @@ import anndata import numpy as np from anndata import AnnData -from scipy.sparse.base import issparse +from scipy.sparse import issparse from ..configuration import DKM from ..dynamo_logger import main_debug, main_info_insert_adata_uns diff --git a/tests/test_pl.py b/tests/test_pl.py index 816873792..465bb8bb7 100644 --- a/tests/test_pl.py +++ b/tests/test_pl.py @@ -99,12 +99,13 @@ def test_nxviz7_circosplot(adata): def test_scatters_markers_ezplots(): adata = dyn.sample_data.hematopoiesis() - - ax = dyn.pl.cell_cycle_scores(adata, save_show_or_return="return") - assert isinstance(ax, plt.Axes) + import matplotlib.pyplot as plt ax = dyn.pl.pca(adata, color="cell_type", save_show_or_return="return") assert isinstance(ax, plt.Axes) + + ax = dyn.pl.cell_cycle_scores(adata, save_show_or_return="return") + assert isinstance(ax, plt.Axes) ax = dyn.pl.umap(adata, color="cell_type", save_show_or_return="return") assert isinstance(ax, plt.Axes) From 9a5833fb73769058fc4fb150d93ec705cf97f16e Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 06:52:23 +0000 Subject: [PATCH 16/27] update deprecated functions in matplotlib and numpy --- dynamo/plot/heatmaps.py | 2 +- dynamo/plot/utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dynamo/plot/heatmaps.py b/dynamo/plot/heatmaps.py index b6aee78e0..1c1461f9e 100644 --- a/dynamo/plot/heatmaps.py +++ b/dynamo/plot/heatmaps.py @@ -127,7 +127,7 @@ def kde2d( h /= 4 ax = pd.DataFrame((gx - x[:, np.newaxis]) / h[0]).T ay = pd.DataFrame((gy - y[:, np.newaxis]) / h[1]).T - z = (np.matrix(dnorm(ax)) * np.matrix(dnorm(ay).T)) / (nx * h[0] * h[1]) + z = (np.array(dnorm(ax)) @ np.array(dnorm(ay).T)) / (nx * h[0] * h[1]) return gx, gy, z diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 8128f90e8..fb21020ac 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -207,12 +207,12 @@ def calculate_colors( elif values is not None: main_debug("drawing points by values") color_type = "values" - cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap)) + cmap_ = copy.copy(matplotlib.colormaps[cmap]) cmap_.set_bad("lightgray") with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, override_builtin=True) + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_) if values.shape[0] != points.shape[0]: raise ValueError( @@ -276,7 +276,7 @@ def calculate_colors( mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) mappable.set_array(values) - cmap = matplotlib.cm.get_cmap(cmap) + cmap = matplotlib.colormaps[cmap] colors = cmap(values) # No color (just pick the midpoint of the cmap) else: @@ -633,7 +633,7 @@ def _matplotlib_points( with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, override_builtin=True) + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_) if values.shape[0] != points.shape[0]: raise ValueError( From 2e15f82c1f0f679a478d0761410813aa581f1b43 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 07:13:47 +0000 Subject: [PATCH 17/27] chore: check if cmap exists before registering --- dynamo/plot/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index fb21020ac..bd73f6cba 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -212,7 +212,8 @@ def calculate_colors( with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_) + if cmap_.name not in plt.colormaps(): + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) if values.shape[0] != points.shape[0]: raise ValueError( @@ -451,7 +452,7 @@ def _matplotlib_points( ) if color_key is None: main_debug("color_key is None") - cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap)) + cmap = copy.copy(matplotlib.colormaps[color_key_cmap]) cmap.set_bad("lightgray") colors = None @@ -628,12 +629,13 @@ def _matplotlib_points( # Color by values elif values is not None: main_debug("drawing points by values") - cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap)) + cmap_ = copy.copy(matplotlib.colormaps[cmap]) cmap_.set_bad("lightgray") with warnings.catch_warnings(): warnings.simplefilter("ignore") - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_) + if cmap_.name not in plt.colormaps(): + matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) if values.shape[0] != points.shape[0]: raise ValueError( From e79aed16c49a4ed897f3fa76101a9edec4adf9cc Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 07:15:58 +0000 Subject: [PATCH 18/27] fix: SettingWithCopyWarning in pandas dataframe --- dynamo/vectorfield/vector_calculus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dynamo/vectorfield/vector_calculus.py b/dynamo/vectorfield/vector_calculus.py index 76858514b..1f2879a2f 100644 --- a/dynamo/vectorfield/vector_calculus.py +++ b/dynamo/vectorfield/vector_calculus.py @@ -737,7 +737,7 @@ def sensitivity( S_det = [np.linalg.det(S[:, :, i]) for i in np.arange(S.shape[2])] adata.obs["sensitivity_det_" + basis] = np.nan - adata.obs["sensitivity_det_" + basis][cell_idx] = S_det + adata.obs.loc[cell_idx, "sensitivity_det_" + basis] = S_det if store_in_adata: skey = "sensitivity" if basis is None else "sensitivity_" + basis adata.uns[skey] = ret_dict From d28b8acc6d0a43186fe4980fbe5a16910b1d9800 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 07:18:06 +0000 Subject: [PATCH 19/27] fix: The `scale` parameter has been renamed and will be removed in v0.15.0. Pass `density_norm='width'` for the same effect. --- dynamo/plot/markers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dynamo/plot/markers.py b/dynamo/plot/markers.py index c6bd0012d..09c31b157 100644 --- a/dynamo/plot/markers.py +++ b/dynamo/plot/markers.py @@ -259,7 +259,7 @@ def bubble( linewidth=None, palette=color_key, inner="box", - scale="width", + density_norm="width", cut=0, ax=axes[igene], alpha=alpha, From 9ed280c1ead75d8f3a8867ee3dc67af6eb74fc23 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Fri, 5 Jul 2024 08:10:54 +0000 Subject: [PATCH 20/27] chore: update get_cmap to colormaps[ ] --- dynamo/plot/markers.py | 2 +- dynamo/plot/scatters.py | 4 ++-- dynamo/plot/topography.py | 4 ++-- dynamo/plot/utils.py | 8 ++++---- dynamo/vectorfield/vector_calculus.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dynamo/plot/markers.py b/dynamo/plot/markers.py index 09c31b157..2dd4f73c0 100644 --- a/dynamo/plot/markers.py +++ b/dynamo/plot/markers.py @@ -217,7 +217,7 @@ def bubble( ) if color_key is None: - cmap_ = matplotlib.cm.get_cmap(color_key_cmap) + cmap_ = matplotlib.colormaps[color_key_cmap] cmap_.set_bad("lightgray") unique_labels = np.unique(clusters) num_labels = unique_labels.shape[0] diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 83d16b9ea..8eb7a0f24 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -705,7 +705,7 @@ def _plot_basis_layer(cur_b, cur_l): if stack_colors: main_debug("stack colors: changing cmap") _cmap = stack_colors_cmaps[ax_index % len(stack_colors_cmaps)] - max_color = matplotlib.cm.get_cmap(_cmap)(float("inf")) + max_color = matplotlib.colormaps[_cmap](float("inf")) legend_circle = Line2D( [0], [0], @@ -2276,7 +2276,7 @@ def scatters_single_input( main_debug("stack colors: changing cmap") cur_title = stack_colors_title cmap = stack_colors_cmaps[(ax_index - 1) % len(stack_colors_cmaps)] - max_color = matplotlib.cm.get_cmap(cmap)(float("inf")) + max_color = matplotlib.colormaps[cmap](float("inf")) # TODO: consider remove the legend because it is not helpful legend_circle = Line2D( [0], diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index bfbe86a75..51f0fb3b3 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -321,7 +321,7 @@ def plot_fixed_points_2d( if ax is None: ax = plt.gca() - cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap + cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap for i in range(len(Xss)): cur_ftype = ftype[i] marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) @@ -457,7 +457,7 @@ def plot_fixed_points( vecfld_dict["confidence"], ) - cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap + cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index bd73f6cba..2910e4b2b 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -125,7 +125,7 @@ def calculate_colors( ) if color_key is None: main_debug("color_key is None") - cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap)) + cmap = copy.copy(matplotlib.colormaps[color_key_cmap]) cmap.set_bad("lightgray") colors = None @@ -814,7 +814,7 @@ def _matplotlib_points( cb.locator = MaxNLocator(nbins=3, integer=True) cb.update_ticks() - cmap = matplotlib.cm.get_cmap(cmap) + cmap = matplotlib.colormaps[cmap] colors = cmap(values) # No color (just pick the midpoint of the cmap) else: @@ -919,7 +919,7 @@ def _datashade_points( aggregation = canvas.points(data, "x", "y", agg=ds.count_cat("label")) result = tf.shade(aggregation, how="eq_hist") elif color_key is None: - cmap = matplotlib.cm.get_cmap(color_key_cmap) + cmap = matplotlib.colormaps[color_key_cmap] cmap.set_bad("lightgray") # add plotnonfinite=True to canvas.points @@ -960,7 +960,7 @@ def _datashade_points( # Color by values elif values is not None: - cmap_ = matplotlib.cm.get_cmap(cmap) + cmap_ = matplotlib.colormaps[cmap] cmap_.set_bad("lightgray") if values.shape[0] != points.shape[0]: diff --git a/dynamo/vectorfield/vector_calculus.py b/dynamo/vectorfield/vector_calculus.py index 1f2879a2f..76858514b 100644 --- a/dynamo/vectorfield/vector_calculus.py +++ b/dynamo/vectorfield/vector_calculus.py @@ -737,7 +737,7 @@ def sensitivity( S_det = [np.linalg.det(S[:, :, i]) for i in np.arange(S.shape[2])] adata.obs["sensitivity_det_" + basis] = np.nan - adata.obs.loc[cell_idx, "sensitivity_det_" + basis] = S_det + adata.obs["sensitivity_det_" + basis][cell_idx] = S_det if store_in_adata: skey = "sensitivity" if basis is None else "sensitivity_" + basis adata.uns[skey] = ret_dict From 02dbd745881b1585287f08d6b8f3076a6e9645b6 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 19:55:15 +0000 Subject: [PATCH 21/27] chore: `cloud` -> `figshare` --- dynamo/external/scifate.py | 4 ++-- dynamo/external/scribe.py | 4 ++-- dynamo/sample_data.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dynamo/external/scifate.py b/dynamo/external/scifate.py index 1f45df13a..db67a1061 100644 --- a/dynamo/external/scifate.py +++ b/dynamo/external/scifate.py @@ -58,11 +58,11 @@ def scifate_glmnet( motif_ref: The path to the TF binding motif data as described above. It provides the list of TFs gene names and is used to process adata object to generate the TF expression and target new expression matrix for glmnet based TF-target synthesis rate linkage analysis. But currently it is not used for motif based filtering. - By default, it is a cloud link that store the data from us. Other motif reference can bed downloaded from + By default, it is a figshare link that store the data from us. Other motif reference can bed downloaded from RcisTarget: https://resources.aertslab.org/cistarget/. For human motif matrix, it can be downloaded from June's shared folder: https://shendure-web.gs.washington.edu/content/members/cao1025/public/nobackup/sci_fate/data/hg19-tss-centered-10kb-7species.mc9nr.feather - TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a cloud link from us that stores the + TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a figshare link from us that stores the data. Other data can be downloaded from: https://amp.pharm.mssm.edu/Harmonizome/dataset/ENCODE+Transcription+Factor+Targets. nt_layers: The layers that will be used for the network inference. Note that the layers can be changed flexibly. diff --git a/dynamo/external/scribe.py b/dynamo/external/scribe.py index 43675803e..518d22c6a 100644 --- a/dynamo/external/scribe.py +++ b/dynamo/external/scribe.py @@ -49,7 +49,7 @@ def scribe( cell_filter_UMI: Minimum number of UMIs for cell filtering. motif_ref: It provides the list of TFs gene names and is used to parse the data to get the list of TFs and Targets for the causal network inference from those TFs to Targets. But currently the motif based filtering - is not implemented. By default, it is a cloud link that store the data from us. Other motif reference can + is not implemented. By default, it is a figshare link that store the data from us. Other motif reference can bed downloaded from RcisTarget: https://resources.aertslab.org/cistarget/. For human motif matrix, it can be downloaded from June's shared folder: https://shendure-web.gs.washington.edu/content/members/cao1025/public/nobackup/sci_fate/data/hg19-tss- @@ -64,7 +64,7 @@ def scribe( target. This can signify the relationship between potential regulators and targets, speed up the calculation, but at the risk of ignoring strong inhibition effects from certain regulators to targets. do_CLR: Whether to perform context likelihood relatedness analysis on the reconstructed causal network - TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a cloud link from us that stores the + TF_link_ENCODE_ref: The path to the TF chip-seq data. By default, it is a figshare link from us that stores the data. Other data can be downloaded from: https://amp.pharm.mssm.edu/Harmonizome/dataset/ENCODE+Transcription+Factor+Targets. diff --git a/dynamo/sample_data.py b/dynamo/sample_data.py index ef3ee1174..9a2641b17 100755 --- a/dynamo/sample_data.py +++ b/dynamo/sample_data.py @@ -246,7 +246,7 @@ def scEU_seq_rpe1( url: str = "https://figshare.com/ndownloader/files/47439641", filename: str = "rpe1.h5ad", ): - """Download rpe1 dataset from Battich, et al (2020) via a cloud link. + """Download rpe1 dataset from Battich, et al (2020) via a figshare link. This data consists of 13,913 genes across 2,930 cells. """ @@ -259,7 +259,7 @@ def scEU_seq_organoid( url: str = "https://figshare.com/ndownloader/files/47439632", filename: str = "organoid.h5ad", ): - """Download organoid dataset from Battich, et al (2020) via a cloud link. + """Download organoid dataset from Battich, et al (2020) via a figshare link. This data consists of 9,157 genes across 3,831 cells. """ From b582a4c662cfe317da73baf0ec8aee97ff175e4a Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 20:02:49 +0000 Subject: [PATCH 22/27] chore: import matplotlib as mpl --- dynamo/configuration.py | 41 ++++++++++++++++++++------------------- dynamo/plot/markers.py | 3 ++- dynamo/plot/scatters.py | 6 +++--- dynamo/plot/topography.py | 5 +++-- dynamo/plot/utils.py | 21 ++++++++++---------- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/dynamo/configuration.py b/dynamo/configuration.py index 033c5c484..0de47e714 100755 --- a/dynamo/configuration.py +++ b/dynamo/configuration.py @@ -3,6 +3,7 @@ import colorcet import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -442,26 +443,26 @@ def update_data_store_mode(mode: str) -> None: # register cmap with warnings.catch_warnings(): warnings.simplefilter("ignore") - if "zebrafish" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="zebrafish", cmap=zebrafish_cmap) - if "fire" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="fire", cmap=fire_cmap) - if "darkblue" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="darkblue", cmap=darkblue_cmap) - if "darkgreen" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="darkgreen", cmap=darkgreen_cmap) - if "darkred" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="darkred", cmap=darkred_cmap) - if "darkpurple" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="darkpurple", cmap=darkpurple_cmap) - if "div_blue_black_red" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="div_blue_black_red", cmap=div_blue_black_red_cmap) - if "div_blue_red" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="div_blue_red", cmap=div_blue_red_cmap) - if "glasbey_white" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="glasbey_white", cmap=glasbey_white_cmap) - if "glasbey_dark" not in matplotlib.colormaps(): - matplotlib.colormaps.register(name="glasbey_dark", cmap=glasbey_dark_cmap) + if "zebrafish" not in mpl.colormaps(): + mpl.colormaps.register(name="zebrafish", cmap=zebrafish_cmap) + if "fire" not in mpl.colormaps(): + mpl.colormaps.register(name="fire", cmap=fire_cmap) + if "darkblue" not in mpl.colormaps(): + mpl.colormaps.register(name="darkblue", cmap=darkblue_cmap) + if "darkgreen" not in mpl.colormaps(): + mpl.colormaps.register(name="darkgreen", cmap=darkgreen_cmap) + if "darkred" not in mpl.colormaps(): + mpl.colormaps.register(name="darkred", cmap=darkred_cmap) + if "darkpurple" not in mpl.colormaps(): + mpl.colormaps.register(name="darkpurple", cmap=darkpurple_cmap) + if "div_blue_black_red" not in mpl.colormaps(): + mpl.colormaps.register(name="div_blue_black_red", cmap=div_blue_black_red_cmap) + if "div_blue_red" not in mpl.colormaps(): + mpl.colormaps.register(name="div_blue_red", cmap=div_blue_red_cmap) + if "glasbey_white" not in mpl.colormaps(): + mpl.colormaps.register(name="glasbey_white", cmap=glasbey_white_cmap) + if "glasbey_dark" not in mpl.colormaps(): + mpl.colormaps.register(name="glasbey_dark", cmap=glasbey_dark_cmap) _themes = { diff --git a/dynamo/plot/markers.py b/dynamo/plot/markers.py index 2dd4f73c0..96eee6491 100644 --- a/dynamo/plot/markers.py +++ b/dynamo/plot/markers.py @@ -10,6 +10,7 @@ import numpy.typing as npt import pandas as pd from anndata import AnnData +import matplotlib as mpl from matplotlib.axes import Axes from matplotlib.figure import Figure from scipy.sparse import issparse @@ -217,7 +218,7 @@ def bubble( ) if color_key is None: - cmap_ = matplotlib.colormaps[color_key_cmap] + cmap_ = mpl.colormaps[color_key_cmap] cmap_.set_bad("lightgray") unique_labels = np.unique(clusters) num_labels = unique_labels.shape[0] diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 8eb7a0f24..8be93e828 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -9,7 +9,7 @@ from typing_extensions import Literal import anndata -import matplotlib.cm +import matplotlib as mpl import numpy as np import pandas as pd from anndata import AnnData @@ -705,7 +705,7 @@ def _plot_basis_layer(cur_b, cur_l): if stack_colors: main_debug("stack colors: changing cmap") _cmap = stack_colors_cmaps[ax_index % len(stack_colors_cmaps)] - max_color = matplotlib.colormaps[_cmap](float("inf")) + max_color = mpl.colormaps[_cmap](float("inf")) legend_circle = Line2D( [0], [0], @@ -2276,7 +2276,7 @@ def scatters_single_input( main_debug("stack colors: changing cmap") cur_title = stack_colors_title cmap = stack_colors_cmaps[(ax_index - 1) % len(stack_colors_cmaps)] - max_color = matplotlib.colormaps[cmap](float("inf")) + max_color = mpl.colormaps[cmap](float("inf")) # TODO: consider remove the legend because it is not helpful legend_circle = Line2D( [0], diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 51f0fb3b3..d7c6a35ce 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -6,6 +6,7 @@ except ImportError: from typing_extensions import Literal +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt @@ -321,7 +322,7 @@ def plot_fixed_points_2d( if ax is None: ax = plt.gca() - cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap + cm = mpl.colormaps[_cmap] if type(_cmap) is str else _cmap for i in range(len(Xss)): cur_ftype = ftype[i] marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) @@ -457,7 +458,7 @@ def plot_fixed_points( vecfld_dict["confidence"], ) - cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap + cm = mpl.colormaps[_cmap] if type(_cmap) is str else _cmap colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 2910e4b2b..802c982ed 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -8,6 +8,7 @@ from warnings import warn import matplotlib +import matplotlib as mpl import matplotlib.patheffects as PathEffects import matplotlib.pyplot as plt import numba @@ -125,7 +126,7 @@ def calculate_colors( ) if color_key is None: main_debug("color_key is None") - cmap = copy.copy(matplotlib.colormaps[color_key_cmap]) + cmap = copy.copy(mpl.colormaps[color_key_cmap]) cmap.set_bad("lightgray") colors = None @@ -207,13 +208,13 @@ def calculate_colors( elif values is not None: main_debug("drawing points by values") color_type = "values" - cmap_ = copy.copy(matplotlib.colormaps[cmap]) + cmap_ = copy.copy(mpl.colormaps[cmap]) cmap_.set_bad("lightgray") with warnings.catch_warnings(): warnings.simplefilter("ignore") if cmap_.name not in plt.colormaps(): - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) + mpl.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) if values.shape[0] != points.shape[0]: raise ValueError( @@ -277,7 +278,7 @@ def calculate_colors( mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) mappable.set_array(values) - cmap = matplotlib.colormaps[cmap] + cmap = mpl.colormaps[cmap] colors = cmap(values) # No color (just pick the midpoint of the cmap) else: @@ -452,7 +453,7 @@ def _matplotlib_points( ) if color_key is None: main_debug("color_key is None") - cmap = copy.copy(matplotlib.colormaps[color_key_cmap]) + cmap = copy.copy(mpl.colormaps[color_key_cmap]) cmap.set_bad("lightgray") colors = None @@ -629,13 +630,13 @@ def _matplotlib_points( # Color by values elif values is not None: main_debug("drawing points by values") - cmap_ = copy.copy(matplotlib.colormaps[cmap]) + cmap_ = copy.copy(mpl.colormaps[cmap]) cmap_.set_bad("lightgray") with warnings.catch_warnings(): warnings.simplefilter("ignore") if cmap_.name not in plt.colormaps(): - matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) + mpl.colormaps.register(name=cmap_.name, cmap=cmap_, force=False) if values.shape[0] != points.shape[0]: raise ValueError( @@ -814,7 +815,7 @@ def _matplotlib_points( cb.locator = MaxNLocator(nbins=3, integer=True) cb.update_ticks() - cmap = matplotlib.colormaps[cmap] + cmap = mpl.colormaps[cmap] colors = cmap(values) # No color (just pick the midpoint of the cmap) else: @@ -919,7 +920,7 @@ def _datashade_points( aggregation = canvas.points(data, "x", "y", agg=ds.count_cat("label")) result = tf.shade(aggregation, how="eq_hist") elif color_key is None: - cmap = matplotlib.colormaps[color_key_cmap] + cmap = mpl.colormaps[color_key_cmap] cmap.set_bad("lightgray") # add plotnonfinite=True to canvas.points @@ -960,7 +961,7 @@ def _datashade_points( # Color by values elif values is not None: - cmap_ = matplotlib.colormaps[cmap] + cmap_ = mpl.colormaps[cmap] cmap_.set_bad("lightgray") if values.shape[0] != points.shape[0]: From 0be3cf0fca7041507bcd53faf63d1179bdceeb85 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 20:21:32 +0000 Subject: [PATCH 23/27] chore: resolve warnings --- dynamo/configuration.py | 2 +- dynamo/data_io.py | 4 ++-- dynamo/estimation/csc/utils_velocity.py | 2 +- dynamo/estimation/csc/velocity.py | 3 ++- dynamo/estimation/fit_jacobian.py | 6 +++--- dynamo/plot/dynamics.py | 1 + 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/dynamo/configuration.py b/dynamo/configuration.py index 0de47e714..887b0f4eb 100755 --- a/dynamo/configuration.py +++ b/dynamo/configuration.py @@ -352,7 +352,7 @@ def use_default_var_if_none(val: Any, key: str, replace_val: Optional[Any] = Non Returns: `val` or config value set in DynamoAdataConfig according to the method description above. """ - if not key in DynamoAdataConfig.config_key_to_values: + if key not in DynamoAdataConfig.config_key_to_values: assert KeyError("Config %s not exist in DynamoAdataConfig." % (key)) if val == replace_val: config_val = DynamoAdataConfig.config_key_to_values[key] diff --git a/dynamo/data_io.py b/dynamo/data_io.py index 2822ce590..b893b8dd3 100755 --- a/dynamo/data_io.py +++ b/dynamo/data_io.py @@ -138,7 +138,7 @@ def load_NASC_seq( tot_RNA = None cells_raw, cells = None, None - for f in tqdm(files, desc=f"reading rmse output files:"): + for f in tqdm(files, desc="reading rmse output files:"): tmp = pd.read_csv(f, index_col=0, sep="\t") if tot_RNA is None: @@ -251,7 +251,7 @@ def aggregate_adata(file_list: list) -> AnnData: if len(valid_cells) == 0 or len(valid_genes) == 0: raise Exception( - f"we don't find any gene or cell names shared across different adata objects." f"Please check your data. " + "we don't find any gene or cell names shared across different adata objects."+"Please check your data. " ) layer_dict = {} diff --git a/dynamo/estimation/csc/utils_velocity.py b/dynamo/estimation/csc/utils_velocity.py index 3040f61e0..1df548b5d 100755 --- a/dynamo/estimation/csc/utils_velocity.py +++ b/dynamo/estimation/csc/utils_velocity.py @@ -302,7 +302,7 @@ def fit_linreg_robust( f"estimation method {est_method} is not implemented. " f"Currently supported linear regression methods include `rlm` and `ransac`." ) - except: + except Exception as e: if intercept: ym = np.mean(yy) xm = np.mean(xx) diff --git a/dynamo/estimation/csc/velocity.py b/dynamo/estimation/csc/velocity.py index 30769184c..74a14c6a6 100755 --- a/dynamo/estimation/csc/velocity.py +++ b/dynamo/estimation/csc/velocity.py @@ -2,6 +2,7 @@ from multiprocessing.dummy import Pool as ThreadPool from warnings import warn +import numpy as np from scipy.sparse import csr_matrix from tqdm import tqdm @@ -909,7 +910,7 @@ def fit( for i in tqdm(range(n_genes), desc="estimating gamma"): try: gamma[i], u0[i] = fit_first_order_deg_lsq(t_uniq, uu_m[i]) - except: + except Exception as e: gamma[i], u0[i] = 0, 0 self.parameters["gamma"], self.aux_param["uu0"] = gamma, u0 alpha = np.zeros(n_genes) diff --git a/dynamo/estimation/fit_jacobian.py b/dynamo/estimation/fit_jacobian.py index 3e0d2083e..8704518b7 100644 --- a/dynamo/estimation/fit_jacobian.py +++ b/dynamo/estimation/fit_jacobian.py @@ -171,7 +171,7 @@ def fit_hill_grad( msd_min = msd p_opt_min = [A, K, n, g] - except: + except Exception as e: #TODO: not a good practice pass if p_opt_min is None: @@ -236,7 +236,7 @@ def fit_hill_inh_grad( msd_min = msd p_opt_min = p_opt - except: + except Exception as e: #TODO: not a good practice pass return {"A": p_opt_min[0], "K": p_opt_min[1], "n": np.exp(p_opt_min[2]), "g": p_opt_min[3]}, msd_min @@ -296,7 +296,7 @@ def fit_hill_act_grad( msd_min = msd p_opt_min = p_opt - except: + except Exception as e: # not a good practice pass return {"A": p_opt_min[0], "K": p_opt_min[1], "n": np.exp(p_opt_min[2]), "g": p_opt_min[3]}, msd_min diff --git a/dynamo/plot/dynamics.py b/dynamo/plot/dynamics.py index 3824b213d..e77296739 100755 --- a/dynamo/plot/dynamics.py +++ b/dynamo/plot/dynamics.py @@ -7,6 +7,7 @@ except ImportError: from typing_extensions import Literal +import numpy as np import pandas as pd from anndata import AnnData from matplotlib.figure import Figure From 6fd158d4593b77fda15cb77337b42911011e68ed Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 20:25:53 +0000 Subject: [PATCH 24/27] trial: Update python-version matrix in GitHub workflow --- .github/workflows/python-plain-run-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-plain-run-test.yml b/.github/workflows/python-plain-run-test.yml index 97d305895..12c62349f 100644 --- a/.github/workflows/python-plain-run-test.yml +++ b/.github/workflows/python-plain-run-test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9] + python-version: [3.8, 3.9, 3.10, 3.11, 3.12] steps: - uses: actions/checkout@v2 From 06729a84b42ed44cf073b5b14ebe6def3c252342 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 20:29:19 +0000 Subject: [PATCH 25/27] fix weird 3.10 version issue: if i use 3.10, it would be recognized as 3.1 by github action --- .github/workflows/python-plain-run-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-plain-run-test.yml b/.github/workflows/python-plain-run-test.yml index 12c62349f..476230b9c 100644 --- a/.github/workflows/python-plain-run-test.yml +++ b/.github/workflows/python-plain-run-test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, 3.10, 3.11, 3.12] + python-version: [3.8, 3.9, 3.10.14, 3.11, 3.12] steps: - uses: actions/checkout@v2 From 2736e39ce47b2958c9c67bb30ced241d3abd62be Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 20:35:58 +0000 Subject: [PATCH 26/27] trial: let the build continue on error --- .github/workflows/python-plain-run-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-plain-run-test.yml b/.github/workflows/python-plain-run-test.yml index 476230b9c..b82ffe087 100644 --- a/.github/workflows/python-plain-run-test.yml +++ b/.github/workflows/python-plain-run-test.yml @@ -15,6 +15,7 @@ jobs: matrix: python-version: [3.8, 3.9, 3.10.14, 3.11, 3.12] + continue-on-error: true steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 3339f3d246fbdfb7f8340a7efc7f9b6b41452595 Mon Sep 17 00:00:00 2001 From: Sijie Chen Date: Sat, 6 Jul 2024 21:11:21 +0000 Subject: [PATCH 27/27] trial: remove python 3.12 tests --- .github/workflows/python-package.yml | 2 +- .github/workflows/python-plain-run-test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4bb3e0bcb..0ca83aff8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9] + python-version: [3.8, 3.9, 3.10.14, 3.11] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/python-plain-run-test.yml b/.github/workflows/python-plain-run-test.yml index b82ffe087..c48b72e97 100644 --- a/.github/workflows/python-plain-run-test.yml +++ b/.github/workflows/python-plain-run-test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, 3.10.14, 3.11, 3.12] + python-version: [3.8, 3.9, 3.10.14, 3.11] continue-on-error: true steps: