diff --git a/dynamo/configuration.py b/dynamo/configuration.py index 3f6e55eac..7e74123db 100755 --- a/dynamo/configuration.py +++ b/dynamo/configuration.py @@ -137,6 +137,33 @@ def allowed_X_layer_names(): def init_uns_pp_namespace(adata: AnnData): adata.uns[DynamoAdataKeyManager.UNS_PP_KEY] = {} + def get_vf_dict(adata, basis="", vf_key="VecFld"): + if basis is not None: + if len(basis) > 0: + vf_key = "%s_%s" % (vf_key, basis) + + if vf_key not in adata.uns.keys(): + raise ValueError( + f"Vector field function {vf_key} is not included in the adata object! " + f"Try firstly running dyn.vf.VectorField(adata, basis='{basis}')" + ) + + vf_dict = adata.uns[vf_key] + return vf_dict + + def vecfld_from_adata(adata, basis="", vf_key="VecFld"): + vf_dict = DynamoAdataKeyManager.get_vf_dict(adata, basis=basis, vf_key=vf_key) + + method = vf_dict["method"] + if method.lower() == "sparsevfc": + func = lambda x: vector_field_function(x, vf_dict) + elif method.lower() == "dynode": + func = lambda x: dynode_vector_field_function(x, vf_dict) + else: + raise ValueError(f"current only support two methods, SparseVFC and dynode") + + return vf_dict, func + # TODO discuss alias naming convention DKM = DynamoAdataKeyManager @@ -786,5 +813,7 @@ def set_pub_style_mpltex(): # initialize DynamoSaveConfig and DynamoVisConfig mode defaults DynamoAdataConfig.update_data_store_mode("full") -main_info("setting visualization default mode in dynamo. Your customized matplotlib settings might be overritten.") +main_info( + "[dynamo import initialization] setting visualization default mode in dynamo. Your customized matplotlib settings might be overritten." +) DynamoVisConfig.set_default_mode() diff --git a/dynamo/data_io.py b/dynamo/data_io.py index 10fd41d13..756228610 100755 --- a/dynamo/data_io.py +++ b/dynamo/data_io.py @@ -300,6 +300,8 @@ def cleanup(adata, del_prediction=False, del_2nd_moments=False): adata.uns.pop("kinetics_heatmap") if "hdbscan" in adata.uns_keys(): adata.uns.pop("hdbscan") + if "highest_frac_genes" in adata.uns_keys(): + adata.uns.pop("highest_frac_genes") VF_keys = [i if i.startswith("VecFld") else None for i in adata.uns_keys()] for i in VF_keys: diff --git a/dynamo/plot/networks.py b/dynamo/plot/networks.py index 7b570f956..10f1125eb 100644 --- a/dynamo/plot/networks.py +++ b/dynamo/plot/networks.py @@ -1,10 +1,14 @@ +# nxviz prints long warning messages regarding nxviz's own version update information +import warnings + import networkx as nx import numpy as np import nxviz as nv -import nxviz.annotate import pandas as pd from matplotlib.axes import Axes +warnings.filterwarnings("ignore", module="nxviz") + from ..tools.utils import flatten, index_gene, update_dict from .utils import save_fig, set_colorbar from .utils_graph import ArcPlot diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 46bed7424..f34d6b04d 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -1418,7 +1418,7 @@ def streamline_plot( cmap: Optional[str] = None, color_key: Union[dict, list] = None, color_key_cmap: Optional[str] = None, - background: Optional[str] = "white", + background: Optional[str] = None, ncols: int = 4, pointsize: Union[None, float] = None, figsize: tuple = (6, 4), diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 84f228943..8303bcc7b 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -832,7 +832,8 @@ def _plot_basis_layer(cur_b, cur_l): if deaxis: deaxis_all(ax) - ax.set_title(cur_title) + ax.set_title(cur_title, color=font_color) + ax.tick_params(axis="both", colors=font_color) axes_list.append(ax) color_list.append(color_out) diff --git a/dynamo/vectorfield/utils.py b/dynamo/vectorfield/utils.py index 1f621fb41..0c72e84d0 100644 --- a/dynamo/vectorfield/utils.py +++ b/dynamo/vectorfield/utils.py @@ -13,6 +13,7 @@ from scipy.spatial.distance import cdist, pdist from tqdm import tqdm +from ..configuration import DKM from ..dynamo_logger import LoggerManager, main_info from ..tools.utils import ( form_triu_matrix, @@ -22,6 +23,8 @@ ) from .FixedPoints import FixedPoints +get_vf_dict = DKM.get_vf_dict + def is_outside_domain(x, domain): x = x[None, :] if x.ndim == 1 else x @@ -225,21 +228,6 @@ def con_K_div_cur_free(x, y, sigma=0.8, eta=0.5): return G, df_kernel, cf_kernel -def get_vf_dict(adata, basis="", vf_key="VecFld"): - if basis is not None: - if len(basis) > 0: - vf_key = "%s_%s" % (vf_key, basis) - - if vf_key not in adata.uns.keys(): - raise ValueError( - f"Vector field function {vf_key} is not included in the adata object! " - f"Try firstly running dyn.vf.VectorField(adata, basis='{basis}')" - ) - - vf_dict = adata.uns[vf_key] - return vf_dict - - def vecfld_from_adata(adata, basis="", vf_key="VecFld"): vf_dict = get_vf_dict(adata, basis=basis, vf_key=vf_key)