Skip to content

Commit

Permalink
fix(datasets) Fix the scale of value axis when plotting in absolute s…
Browse files Browse the repository at this point in the history
…izes (#4255)
  • Loading branch information
adam-narozniak authored Oct 3, 2024
1 parent 1465048 commit 52686d9
Showing 1 changed file with 51 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Comparison of label distribution plotting."""


from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
Expand All @@ -33,10 +33,10 @@
def plot_comparison_label_distribution(
partitioner_list: list[Partitioner],
label_name: Union[str, list[str]],
plot_type: str = "bar",
size_unit: str = "percent",
max_num_partitions: Optional[Union[int]] = 30,
partition_id_axis: str = "y",
plot_type: Literal["bar", "heatmap"] = "bar",
size_unit: Literal["percent", "absolute"] = "percent",
max_num_partitions: Optional[int] = 30,
partition_id_axis: Literal["x", "y"] = "y",
figsize: Optional[tuple[float, float]] = None,
subtitle: str = "Comparison of Per Partition Label Distribution",
titles: Optional[list[str]] = None,
Expand All @@ -55,14 +55,14 @@ def plot_comparison_label_distribution(
List of partitioners to be compared.
label_name : Union[str, List[str]]
Column name or list of column names identifying labels for each partitioner.
plot_type : str
plot_type : Literal["bar", "heatmap"]
Type of plot, either "bar" or "heatmap".
size_unit : str
size_unit : Literal["percent", "absolute"]
"absolute" for raw counts, or "percent" to normalize values to 100%.
max_num_partitions : Optional[int]
Maximum number of partitions to include in the plot. If None, all partitions
are included.
partition_id_axis : str
partition_id_axis : Literal["x", "y"]
Axis on which the partition IDs will be marked, either "x" or "y".
figsize : Optional[Tuple[float, float]]
Size of the figure. If None, a default size is calculated.
Expand Down Expand Up @@ -151,7 +151,10 @@ def plot_comparison_label_distribution(
f"{type(label_name)}"
)
figsize = _initialize_comparison_figsize(figsize, num_partitioners)
fig, axes = plt.subplots(1, num_partitioners, layout="constrained", figsize=figsize)
axes_sharing = _initialize_axis_sharing(size_unit, plot_type, partition_id_axis)
fig, axes = plt.subplots(
1, num_partitioners, layout="constrained", figsize=figsize, **axes_sharing
)

if titles is None:
titles = ["" for _ in range(num_partitioners)]
Expand Down Expand Up @@ -201,11 +204,12 @@ def plot_comparison_label_distribution(
axis.set_xlabel("")
axis.set_ylabel("")
axis.set_title(titles[idx])
for axis in axes[1:]:
axis.set_yticks([])
_set_tick_on_value_axes(axes, partition_id_axis, size_unit)

# Set up figure xlabel and ylabel
xlabel, ylabel = _initialize_comparison_xy_labels(plot_type, partition_id_axis)
xlabel, ylabel = _initialize_comparison_xy_labels(
plot_type, size_unit, partition_id_axis
)
fig.supxlabel(xlabel)
fig.supylabel(ylabel)
fig.suptitle(subtitle)
Expand All @@ -226,11 +230,13 @@ def _initialize_comparison_figsize(


def _initialize_comparison_xy_labels(
plot_type: str, partition_id_axis: str
plot_type: Literal["bar", "heatmap"],
size_unit: Literal["percent", "absolute"],
partition_id_axis: Literal["x", "y"],
) -> tuple[str, str]:
if plot_type == "bar":
xlabel = "Partition ID"
ylabel = "Class distribution"
ylabel = "Class distribution" if size_unit == "percent" else "Class Count"
elif plot_type == "heatmap":
xlabel = "Partition ID"
ylabel = "Label"
Expand All @@ -243,3 +249,34 @@ def _initialize_comparison_xy_labels(
xlabel, ylabel = ylabel, xlabel

return xlabel, ylabel


def _initialize_axis_sharing(
    size_unit: Literal["percent", "absolute"],
    plot_type: Literal["bar", "heatmap"],
    partition_id_axis: Literal["x", "y"],
) -> dict[str, bool]:
    """Select the axis-sharing keyword arguments for the comparison subplots.

    The result is meant to be unpacked into the ``plt.subplots`` call that
    creates one subplot per partitioner.

    Parameters
    ----------
    size_unit : Literal["percent", "absolute"]
        Unit in which the values are plotted.
    plot_type : Literal["bar", "heatmap"]
        Type of the plot.
    partition_id_axis : Literal["x", "y"]
        Axis on which the partition IDs are marked.

    Returns
    -------
    axes_sharing : dict[str, bool]
        An empty dictionary when ``size_unit`` is "percent" or ``plot_type``
        is "heatmap"; otherwise ``{"sharey": True}`` for partition IDs on the
        x-axis, ``{"sharex": True}`` for partition IDs on the y-axis, and
        ``{"sharex": False, "sharey": False}`` for any other axis value.
    """
    # Do not intervene when the size_unit is percent OR the plot_type is
    # heatmap: either condition alone leaves the subplots' defaults untouched.
    if size_unit == "percent" or plot_type == "heatmap":
        return {}
    # From here on the value axis is the one that does not carry the
    # partition IDs, and that is the axis shared across the subplots.
    if partition_id_axis == "x":
        return {"sharey": True}
    if partition_id_axis == "y":
        return {"sharex": True}
    # Unrecognized axis name: explicitly request no sharing.
    return {"sharex": False, "sharey": False}


def _set_tick_on_value_axes(
    axes: list[Axes],
    partition_id_axis: Literal["x", "y"],
    size_unit: Literal["percent", "absolute"],
) -> None:
    """Clear the y-ticks of every subplot after the first, unless they are kept.

    Parameters
    ----------
    axes : list[Axes]
        Subplot axes in plotting order; the first one always keeps its ticks.
    partition_id_axis : Literal["x", "y"]
        Axis on which the partition IDs are marked.
    size_unit : Literal["percent", "absolute"]
        Unit in which the values are plotted.
    """
    if partition_id_axis == "x" and size_unit == "absolute":
        # Exclude this case due to the sharing of the y-axis (and thus of the
        # y-ticks): they must remain set, and the numbers are displayed only
        # on the first plot.
        return
    for axis in axes[1:]:
        axis.set_yticks([])

0 comments on commit 52686d9

Please sign in to comment.