diff --git a/src/mplhep/plot.py b/src/mplhep/plot.py index 16dcb7d1..6992e47d 100644 --- a/src/mplhep/plot.py +++ b/src/mplhep/plot.py @@ -550,8 +550,32 @@ def iterable_not_string(arg): transform=ax.get_xaxis_transform(), ) - elif flow == "show" and (underflow > 0.0 or overflow > 0.0): - xticks = [label.get_text() for label in ax.get_xticklabels()] + elif flow == "show": + underflow_xticklabel = f"<{flow_bins[1]:.2g}" + overflow_xticklabel = f">{flow_bins[-2]:.2g}" + + # Loop over shared x axes to get xticks and xticklabels + xticks, xticklabels = np.array([]), [] + shared_axes = ax.get_shared_x_axes().get_siblings(ax) + shared_axes = [ + _ax for _ax in shared_axes if _ax.get_position().x0 == ax.get_position().x0 + ] + for _ax in shared_axes: + _xticks = _ax.get_xticks() + _xticklabels = [label.get_text() for label in _ax.get_xticklabels()] + + # Check if underflow/overflow xtick already exists + if ( + underflow_xticklabel in _xticklabels + or overflow_xticklabel in _xticklabels + ): + xticks = _xticks + xticklabels = _xticklabels + break + elif len(_xticklabels) > 0: + xticks = _xticks + xticklabels = _xticklabels + lw = ax.spines["bottom"].get_linewidth() _edges = plottables[0].edges _centers = plottables[0].centers @@ -559,59 +583,92 @@ def iterable_not_string(arg): 20 * ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width ) - if underflow > 0.0: - xticks[0] = "" - xticks[1] = f"<{flow_bins[2]}" - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels(xticks) - ax.plot( - [_edges[0], _edges[1]], - [0, 0], - color="white", - zorder=5, - ls="--", - lw=lw, - transform=ax.get_xaxis_transform(), - clip_on=False, - ) - ax.scatter( - _centers[0], - 0, - _marker_size, - marker=align_marker("d", valign="center"), - edgecolor="black", - zorder=5, - clip_on=False, - facecolor="white", - transform=ax.get_xaxis_transform(), - ) - if overflow > 0.0: - xticks[-1] = "" - xticks[-2] = f">{flow_bins[-3]}" - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels(xticks) - ax.plot( - [_edges[-2], _edges[-1]], - [0, 0], - color="white", - zorder=5, - ls="--", - lw=lw, - transform=ax.get_xaxis_transform(), - clip_on=False, - ) - ax.scatter( - _centers[-1], - 0, - _marker_size, - marker=align_marker("d", valign="center"), - edgecolor="black", - zorder=5, - clip_on=False, - facecolor="white", - transform=ax.get_xaxis_transform(), - ) + if underflow > 0.0 or underflow_xticklabel in xticklabels: + # Replace any existing xticks in underflow region with underflow bin center + _mask = xticks > flow_bins[1] + xticks = np.insert(xticks[_mask], 0, _centers[0]) + xticklabels = [underflow_xticklabel] + [ + xlab for i, xlab in enumerate(xticklabels) if _mask[i] + ] + + # Don't draw markers on the top of the top axis + top_axis = max(shared_axes, key=lambda a: a.get_position().y0) + + # Draw on all shared axes + for _ax in shared_axes: + _ax.set_xticks(xticks) + _ax.set_xticklabels(xticklabels) + for h in [0, 1]: + # Don't draw marker on the top of the top axis + if _ax == top_axis and h == 1: + continue + + _ax.plot( + [_edges[0], _edges[1]], + [h, h], + color="white", + zorder=5, + ls="--", + lw=lw, + transform=_ax.get_xaxis_transform(), + clip_on=False, + ) + + _ax.scatter( + _centers[0], + h, + _marker_size, + marker=align_marker("d", valign="center"), + edgecolor="black", + zorder=5, + clip_on=False, + facecolor="white", + transform=_ax.get_xaxis_transform(), + ) + if overflow > 0.0 or overflow_xticklabel in xticklabels: + # Replace any existing xticks in overflow region with overflow bin center + _mask = xticks < flow_bins[-2] + xticks = np.insert(xticks[_mask], sum(_mask), _centers[-1]) + xticklabels = [xlab for i, xlab in enumerate(xticklabels) if _mask[i]] + [ + overflow_xticklabel + ] + + # Don't draw markers on the top of the top axis + top_axis = max(shared_axes, key=lambda a: a.get_position().y0) + + # Draw on all shared axes + for _ax in shared_axes: + _ax.set_xticks(xticks) + _ax.set_xticklabels(xticklabels) + + for h in [0, 1]: + # Don't draw marker on the top of the top axis + if _ax == top_axis and h == 1: + continue + + _ax.plot( + [_edges[-2], _edges[-1]], + [h, h], + color="white", + zorder=5, + ls="--", + lw=lw, + transform=_ax.get_xaxis_transform(), + clip_on=False, + ) + + _ax.scatter( + _centers[-1], + h, + _marker_size, + marker=align_marker("d", valign="center"), + edgecolor="black", + zorder=5, + clip_on=False, + facecolor="white", + transform=_ax.get_xaxis_transform(), + ) return return_artists diff --git a/tests/baseline/test_histplot_flow.png b/tests/baseline/test_histplot_flow.png index a4682a51..b4023b0a 100644 Binary files a/tests/baseline/test_histplot_flow.png and b/tests/baseline/test_histplot_flow.png differ diff --git a/tests/baseline/test_histplot_hist_flow_no_variances.png b/tests/baseline/test_histplot_hist_flow_no_variances.png index dc23efc0..cbb4ba7b 100644 Binary files a/tests/baseline/test_histplot_hist_flow_no_variances.png and b/tests/baseline/test_histplot_hist_flow_no_variances.png differ diff --git a/tests/baseline/test_histplot_hist_flow_variances.png b/tests/baseline/test_histplot_hist_flow_variances.png index 33204359..7b08b5fd 100644 Binary files a/tests/baseline/test_histplot_hist_flow_variances.png and b/tests/baseline/test_histplot_hist_flow_variances.png differ diff --git a/tests/baseline/test_histplot_uproot_flow.png b/tests/baseline/test_histplot_uproot_flow.png index dd0b90c6..9f04c1e0 100644 Binary files a/tests/baseline/test_histplot_uproot_flow.png and b/tests/baseline/test_histplot_uproot_flow.png differ