Skip to content

Commit

Permalink
feat: added colors parameter to allow user to select the colour of ea…
Browse files Browse the repository at this point in the history
…ch histogram
  • Loading branch information
jjhw3 committed Oct 3, 2023
1 parent 82c1e2d commit 269d545
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import warnings
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Union, List

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -68,6 +68,7 @@ def histplot(
histtype="step",
xerr=False,
label=None,
colors=None,
sort=None,
edges=True,
binticks=False,
Expand Down Expand Up @@ -125,6 +126,8 @@ def histplot(
Size of xerr if ``histtype == 'errorbar'``. If ``True``, bin-width will be used.
label : str or list, optional
Label for legend entry.
colors : any valid mpl color or iterable thereof, optional
The color for each histogram
sort: {'label'/'l', 'yield'/'y'}, optional
Append '_r' for reverse.
edges : bool, default: True, optional
Expand Down Expand Up @@ -296,6 +299,16 @@ def histplot(
else:
_labels = [str(lab) for lab in label]

_colors: List
if colors is None:
_colors = [None] * len(plottables)
elif isinstance(colors, str):
_colors = [colors] * len(plottables)
elif not np.iterable(colors):
_colors = [str(colors)] * len(plottables)
else:
_colors = colors

def iterable_not_string(arg):
return isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str)

Expand Down Expand Up @@ -425,6 +438,8 @@ def iterable_not_string(arg):

_plot_info = plottables[i].to_stairs()
_plot_info["baseline"] = None if not edges else 0
if _colors[i] is not None:
_plot_info["color"] = _colors[i]
_s = ax.stairs(
**_plot_info,
label=_step_label,
Expand Down Expand Up @@ -454,8 +469,11 @@ def iterable_not_string(arg):
elif histtype == "fill":
for i in range(len(plottables)):
_kwargs = _chunked_kwargs[i]
_plot_info = {}
if _colors[i] is not None:
_plot_info["color"] = _colors[i]
_f = ax.stairs(
**plottables[i].to_stairs(), label=_labels[i], fill=True, **_kwargs
**_plot_info, **plottables[i].to_stairs(), label=_labels[i], fill=True, **_kwargs
)
return_artists.append(StairsArtists(_f, None, None))
_artist = _f
Expand All @@ -482,6 +500,8 @@ def iterable_not_string(arg):
if yerr is False:
_plot_info["yerr"] = None
_plot_info["xerr"] = _xerr
if _colors[i] is not None:
_plot_info["color"] = _colors[i]
_e = ax.errorbar(
**_plot_info,
label=_labels[i],
Expand Down
Binary file added tests/baseline/test_color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,17 @@ def test_histplot_inputs_pass(h, yerr, htype):
fig, ax = plt.subplots()
hep.histplot(h, bins, yerr=yerr, histtype=htype)
plt.close(fig)


@pytest.mark.mpl_image_compare(style="default", remove_text=True)
def test_color():
hs, bins = [[2, 3, 4], [5, 4, 3]], [0, 1, 2, 3]
fig, axs = plt.subplots(3, 2, figsize=(8, 12))
colors = ['green', 'blue']
axs = axs.flatten()

for i, htype in enumerate(["step", "fill", "errorbar"]):
hep.histplot(hs[0], bins, yerr=True, histtype=htype, ax=axs[i * 2], alpha=0.7, color='red')
hep.histplot(hs, bins, yerr=True, histtype=htype, ax=axs[i * 2 + 1], alpha=0.7, color=colors)

return fig

0 comments on commit 269d545

Please sign in to comment.