Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: flow bins xticklabel correction #502

Merged
merged 7 commits into from
Jun 10, 2024
166 changes: 112 additions & 54 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,68 +550,126 @@ def iterable_not_string(arg):
transform=ax.get_xaxis_transform(),
)

elif flow == "show" and (underflow > 0.0 or overflow > 0.0):
xticks = [label.get_text() for label in ax.get_xticklabels()]
elif flow == "show":
underflow_xticklabel = f"<{flow_bins[1]:.2f}"
overflow_xticklabel = f">{flow_bins[-2]:.2f}"

# Loop over shared x axes to get xticks and xticklabels
xticks, xticklabels = np.array([]), []
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
shared_axes = [
_ax for _ax in shared_axes if _ax.get_position().x0 == ax.get_position().x0
]
for _ax in shared_axes:
_xticks = _ax.get_xticks()
_xticklabels = [label.get_text() for label in _ax.get_xticklabels()]

# Check if underflow/overflow xtick already exists
if (
underflow_xticklabel in _xticklabels
or overflow_xticklabel in _xticklabels
):
xticks = _xticks
xticklabels = _xticklabels
break
elif len(_xticklabels) > 0:
xticks = _xticks
xticklabels = _xticklabels

lw = ax.spines["bottom"].get_linewidth()
_edges = plottables[0].edges
_centers = plottables[0].centers
_marker_size = (
20
* ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
)
if underflow > 0.0:
xticks[0] = ""
xticks[1] = f"<{flow_bins[2]}"
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(xticks)

ax.plot(
[_edges[0], _edges[1]],
[0, 0],
color="white",
zorder=5,
ls="--",
lw=lw,
transform=ax.get_xaxis_transform(),
clip_on=False,
)
ax.scatter(
_centers[0],
0,
_marker_size,
marker=align_marker("d", valign="center"),
edgecolor="black",
zorder=5,
clip_on=False,
facecolor="white",
transform=ax.get_xaxis_transform(),
)
if overflow > 0.0:
xticks[-1] = ""
xticks[-2] = f">{flow_bins[-3]}"
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(xticks)
ax.plot(
[_edges[-2], _edges[-1]],
[0, 0],
color="white",
zorder=5,
ls="--",
lw=lw,
transform=ax.get_xaxis_transform(),
clip_on=False,
)
ax.scatter(
_centers[-1],
0,
_marker_size,
marker=align_marker("d", valign="center"),
edgecolor="black",
zorder=5,
clip_on=False,
facecolor="white",
transform=ax.get_xaxis_transform(),
)
if underflow > 0.0 or underflow_xticklabel in xticklabels:
# Replace any existing xticks in underflow region with underflow bin center
_mask = xticks > flow_bins[1]
xticks = np.insert(xticks[_mask], 0, _centers[0])
xticklabels = [underflow_xticklabel] + [
xlab for i, xlab in enumerate(xticklabels) if _mask[i]
]

# Don't draw markers on the top of the top axis
top_axis = max(shared_axes, key=lambda a: a.get_position().y0)

# Draw on all shared axes
for _ax in shared_axes:
for h in [0, 1]:
_ax.set_xticks(xticks)
_ax.set_xticklabels(xticklabels)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you move the skip condition up here, wouldn't that prevent drawing the dashed line on top?

_ax.plot(
[_edges[0], _edges[1]],
[h, h],
color="white",
zorder=5,
ls="--",
lw=lw,
transform=_ax.get_xaxis_transform(),
clip_on=False,
)

# Don't draw marker on the top of the top axis
if _ax == top_axis and h == 1:
continue

_ax.scatter(
_centers[0],
h,
_marker_size,
marker=align_marker("d", valign="center"),
edgecolor="black",
zorder=5,
clip_on=False,
facecolor="white",
transform=_ax.get_xaxis_transform(),
)
if overflow > 0.0 or overflow_xticklabel in xticklabels:
# Replace any existing xticks in overflow region with overflow bin center
_mask = xticks < flow_bins[-2]
xticks = np.insert(xticks[_mask], sum(_mask), _centers[-1])
xticklabels = [xlab for i, xlab in enumerate(xticklabels) if _mask[i]] + [
overflow_xticklabel
]

# Don't draw markers on the top of the top axis
top_axis = max(shared_axes, key=lambda a: a.get_position().y0)

# Draw on all shared axes
for _ax in shared_axes:
for h in [0, 1]:
_ax.set_xticks(xticks)
_ax.set_xticklabels(xticklabels)

_ax.plot(
[_edges[-2], _edges[-1]],
[h, h],
color="white",
zorder=5,
ls="--",
lw=lw,
transform=_ax.get_xaxis_transform(),
clip_on=False,
)

# Don't draw marker on the top of the top axis
if _ax == top_axis and h == 1:
continue

_ax.scatter(
_centers[-1],
h,
_marker_size,
marker=align_marker("d", valign="center"),
edgecolor="black",
zorder=5,
clip_on=False,
facecolor="white",
transform=_ax.get_xaxis_transform(),
)

return return_artists

Expand Down
Loading