Skip to content

Commit

Permalink
chore: import matplotlib as mpl
Browse files Browse the repository at this point in the history
  • Loading branch information
chansigit committed Jul 6, 2024
1 parent 02dbd74 commit b582a4c
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 36 deletions.
41 changes: 21 additions & 20 deletions dynamo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import colorcet
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -442,26 +443,26 @@ def update_data_store_mode(mode: str) -> None:
# register cmap
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if "zebrafish" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="zebrafish", cmap=zebrafish_cmap)
if "fire" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="fire", cmap=fire_cmap)
if "darkblue" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="darkblue", cmap=darkblue_cmap)
if "darkgreen" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="darkgreen", cmap=darkgreen_cmap)
if "darkred" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="darkred", cmap=darkred_cmap)
if "darkpurple" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="darkpurple", cmap=darkpurple_cmap)
if "div_blue_black_red" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="div_blue_black_red", cmap=div_blue_black_red_cmap)
if "div_blue_red" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="div_blue_red", cmap=div_blue_red_cmap)
if "glasbey_white" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="glasbey_white", cmap=glasbey_white_cmap)
if "glasbey_dark" not in matplotlib.colormaps():
matplotlib.colormaps.register(name="glasbey_dark", cmap=glasbey_dark_cmap)
if "zebrafish" not in mpl.colormaps():
mpl.colormaps.register(name="zebrafish", cmap=zebrafish_cmap)
if "fire" not in mpl.colormaps():
mpl.colormaps.register(name="fire", cmap=fire_cmap)
if "darkblue" not in mpl.colormaps():
mpl.colormaps.register(name="darkblue", cmap=darkblue_cmap)
if "darkgreen" not in mpl.colormaps():
mpl.colormaps.register(name="darkgreen", cmap=darkgreen_cmap)
if "darkred" not in mpl.colormaps():
mpl.colormaps.register(name="darkred", cmap=darkred_cmap)
if "darkpurple" not in mpl.colormaps():
mpl.colormaps.register(name="darkpurple", cmap=darkpurple_cmap)
if "div_blue_black_red" not in mpl.colormaps():
mpl.colormaps.register(name="div_blue_black_red", cmap=div_blue_black_red_cmap)
if "div_blue_red" not in mpl.colormaps():
mpl.colormaps.register(name="div_blue_red", cmap=div_blue_red_cmap)
if "glasbey_white" not in mpl.colormaps():
mpl.colormaps.register(name="glasbey_white", cmap=glasbey_white_cmap)
if "glasbey_dark" not in mpl.colormaps():
mpl.colormaps.register(name="glasbey_dark", cmap=glasbey_dark_cmap)


_themes = {
Expand Down
3 changes: 2 additions & 1 deletion dynamo/plot/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy.typing as npt
import pandas as pd
from anndata import AnnData
import matplotlib as mpl
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.sparse import issparse
Expand Down Expand Up @@ -217,7 +218,7 @@ def bubble(
)

if color_key is None:
cmap_ = matplotlib.colormaps[color_key_cmap]
cmap_ = mpl.colormaps[color_key_cmap]
cmap_.set_bad("lightgray")
unique_labels = np.unique(clusters)
num_labels = unique_labels.shape[0]
Expand Down
6 changes: 3 additions & 3 deletions dynamo/plot/scatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import Literal

import anndata
import matplotlib.cm
import matplotlib as mpl
import numpy as np
import pandas as pd
from anndata import AnnData
Expand Down Expand Up @@ -705,7 +705,7 @@ def _plot_basis_layer(cur_b, cur_l):
if stack_colors:
main_debug("stack colors: changing cmap")
_cmap = stack_colors_cmaps[ax_index % len(stack_colors_cmaps)]
max_color = matplotlib.colormaps[_cmap](float("inf"))
max_color = mpl.colormaps[_cmap](float("inf"))
legend_circle = Line2D(
[0],
[0],
Expand Down Expand Up @@ -2276,7 +2276,7 @@ def scatters_single_input(
main_debug("stack colors: changing cmap")
cur_title = stack_colors_title
cmap = stack_colors_cmaps[(ax_index - 1) % len(stack_colors_cmaps)]
max_color = matplotlib.colormaps[cmap](float("inf"))
max_color = mpl.colormaps[cmap](float("inf"))
# TODO: consider remove the legend because it is not helpful
legend_circle = Line2D(
[0],
Expand Down
5 changes: 3 additions & 2 deletions dynamo/plot/topography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
except ImportError:
from typing_extensions import Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -321,7 +322,7 @@ def plot_fixed_points_2d(
if ax is None:
ax = plt.gca()

cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap
cm = mpl.colormaps[_cmap] if type(_cmap) is str else _cmap
for i in range(len(Xss)):
cur_ftype = ftype[i]
marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)])
Expand Down Expand Up @@ -457,7 +458,7 @@ def plot_fixed_points(
vecfld_dict["confidence"],
)

cm = matplotlib.colormaps[_cmap] if type(_cmap) is str else _cmap
cm = mpl.colormaps[_cmap] if type(_cmap) is str else _cmap
colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))]
text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype]

Expand Down
21 changes: 11 additions & 10 deletions dynamo/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from warnings import warn

import matplotlib
import matplotlib as mpl
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numba
Expand Down Expand Up @@ -125,7 +126,7 @@ def calculate_colors(
)
if color_key is None:
main_debug("color_key is None")
cmap = copy.copy(matplotlib.colormaps[color_key_cmap])
cmap = copy.copy(mpl.colormaps[color_key_cmap])
cmap.set_bad("lightgray")
colors = None

Expand Down Expand Up @@ -207,13 +208,13 @@ def calculate_colors(
elif values is not None:
main_debug("drawing points by values")
color_type = "values"
cmap_ = copy.copy(matplotlib.colormaps[cmap])
cmap_ = copy.copy(mpl.colormaps[cmap])
cmap_.set_bad("lightgray")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
if cmap_.name not in plt.colormaps():
matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False)
mpl.colormaps.register(name=cmap_.name, cmap=cmap_, force=False)

if values.shape[0] != points.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -277,7 +278,7 @@ def calculate_colors(
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(values)

cmap = matplotlib.colormaps[cmap]
cmap = mpl.colormaps[cmap]
colors = cmap(values)
# No color (just pick the midpoint of the cmap)
else:
Expand Down Expand Up @@ -452,7 +453,7 @@ def _matplotlib_points(
)
if color_key is None:
main_debug("color_key is None")
cmap = copy.copy(matplotlib.colormaps[color_key_cmap])
cmap = copy.copy(mpl.colormaps[color_key_cmap])
cmap.set_bad("lightgray")
colors = None

Expand Down Expand Up @@ -629,13 +630,13 @@ def _matplotlib_points(
# Color by values
elif values is not None:
main_debug("drawing points by values")
cmap_ = copy.copy(matplotlib.colormaps[cmap])
cmap_ = copy.copy(mpl.colormaps[cmap])
cmap_.set_bad("lightgray")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
if cmap_.name not in plt.colormaps():
matplotlib.colormaps.register(name=cmap_.name, cmap=cmap_, force=False)
mpl.colormaps.register(name=cmap_.name, cmap=cmap_, force=False)

if values.shape[0] != points.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -814,7 +815,7 @@ def _matplotlib_points(
cb.locator = MaxNLocator(nbins=3, integer=True)
cb.update_ticks()

cmap = matplotlib.colormaps[cmap]
cmap = mpl.colormaps[cmap]
colors = cmap(values)
# No color (just pick the midpoint of the cmap)
else:
Expand Down Expand Up @@ -919,7 +920,7 @@ def _datashade_points(
aggregation = canvas.points(data, "x", "y", agg=ds.count_cat("label"))
result = tf.shade(aggregation, how="eq_hist")
elif color_key is None:
cmap = matplotlib.colormaps[color_key_cmap]
cmap = mpl.colormaps[color_key_cmap]
cmap.set_bad("lightgray")
# add plotnonfinite=True to canvas.points

Expand Down Expand Up @@ -960,7 +961,7 @@ def _datashade_points(

# Color by values
elif values is not None:
cmap_ = matplotlib.colormaps[cmap]
cmap_ = mpl.colormaps[cmap]
cmap_.set_bad("lightgray")

if values.shape[0] != points.shape[0]:
Expand Down

0 comments on commit b582a4c

Please sign in to comment.