Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove highest_frac_genes by cleanup() before saving and remove long nxviz warning #370

Closed
wants to merge 12 commits into from
31 changes: 30 additions & 1 deletion dynamo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,33 @@ def allowed_X_layer_names():
def init_uns_pp_namespace(adata: AnnData):
adata.uns[DynamoAdataKeyManager.UNS_PP_KEY] = {}

def get_vf_dict(adata, basis="", vf_key="VecFld"):
if basis is not None:
if len(basis) > 0:
vf_key = "%s_%s" % (vf_key, basis)

if vf_key not in adata.uns.keys():
raise ValueError(
f"Vector field function {vf_key} is not included in the adata object! "
f"Try firstly running dyn.vf.VectorField(adata, basis='{basis}')"
)

vf_dict = adata.uns[vf_key]
return vf_dict

def vecfld_from_adata(adata, basis="", vf_key="VecFld"):
vf_dict = DynamoAdataKeyManager.get_vf_dict(adata, basis=basis, vf_key=vf_key)

method = vf_dict["method"]
if method.lower() == "sparsevfc":
func = lambda x: vector_field_function(x, vf_dict)
elif method.lower() == "dynode":
func = lambda x: dynode_vector_field_function(x, vf_dict)
else:
raise ValueError(f"current only support two methods, SparseVFC and dynode")

return vf_dict, func


# TODO discuss alias naming convention
DKM = DynamoAdataKeyManager
Expand Down Expand Up @@ -786,5 +813,7 @@ def set_pub_style_mpltex():

# initialize DynamoSaveConfig and DynamoVisConfig mode defaults
DynamoAdataConfig.update_data_store_mode("full")
main_info("setting visualization default mode in dynamo. Your customized matplotlib settings might be overritten.")
main_info(
"[dynamo import initialization] setting visualization default mode in dynamo. Your customized matplotlib settings might be overritten."
)
DynamoVisConfig.set_default_mode()
2 changes: 2 additions & 0 deletions dynamo/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def cleanup(adata, del_prediction=False, del_2nd_moments=False):
adata.uns.pop("kinetics_heatmap")
if "hdbscan" in adata.uns_keys():
adata.uns.pop("hdbscan")
if "highest_frac_genes" in adata.uns_keys():
adata.uns.pop("highest_frac_genes")

VF_keys = [i if i.startswith("VecFld") else None for i in adata.uns_keys()]
for i in VF_keys:
Expand Down
6 changes: 5 additions & 1 deletion dynamo/plot/networks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# nxviz prints long warning messages regarding nxviz's own version update information
import warnings

import networkx as nx
import numpy as np
import nxviz as nv
import nxviz.annotate
import pandas as pd
from matplotlib.axes import Axes

warnings.filterwarnings("ignore", module="nxviz")

from ..tools.utils import flatten, index_gene, update_dict
from .utils import save_fig, set_colorbar
from .utils_graph import ArcPlot
Expand Down
2 changes: 1 addition & 1 deletion dynamo/plot/scVectorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def streamline_plot(
cmap: Optional[str] = None,
color_key: Union[dict, list] = None,
color_key_cmap: Optional[str] = None,
background: Optional[str] = "white",
background: Optional[str] = None,
ncols: int = 4,
pointsize: Union[None, float] = None,
figsize: tuple = (6, 4),
Expand Down
3 changes: 2 additions & 1 deletion dynamo/plot/scatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,8 @@ def _plot_basis_layer(cur_b, cur_l):
if deaxis:
deaxis_all(ax)

ax.set_title(cur_title)
ax.set_title(cur_title, color=font_color)
ax.tick_params(axis="both", colors=font_color)

axes_list.append(ax)
color_list.append(color_out)
Expand Down
18 changes: 3 additions & 15 deletions dynamo/vectorfield/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scipy.spatial.distance import cdist, pdist
from tqdm import tqdm

from ..configuration import DKM
from ..dynamo_logger import LoggerManager, main_info
from ..tools.utils import (
form_triu_matrix,
Expand All @@ -22,6 +23,8 @@
)
from .FixedPoints import FixedPoints

get_vf_dict = DKM.get_vf_dict


def is_outside_domain(x, domain):
x = x[None, :] if x.ndim == 1 else x
Expand Down Expand Up @@ -225,21 +228,6 @@ def con_K_div_cur_free(x, y, sigma=0.8, eta=0.5):
return G, df_kernel, cf_kernel


def get_vf_dict(adata, basis="", vf_key="VecFld"):
if basis is not None:
if len(basis) > 0:
vf_key = "%s_%s" % (vf_key, basis)

if vf_key not in adata.uns.keys():
raise ValueError(
f"Vector field function {vf_key} is not included in the adata object! "
f"Try firstly running dyn.vf.VectorField(adata, basis='{basis}')"
)

vf_dict = adata.uns[vf_key]
return vf_dict


def vecfld_from_adata(adata, basis="", vf_key="VecFld"):
vf_dict = get_vf_dict(adata, basis=basis, vf_key=vf_key)

Expand Down