Skip to content

Commit

Permalink
Merge branch 'main' into use-correct-docstrfmt-version
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Oct 19, 2024
2 parents b6f4a2a + cf03e25 commit 9c95537
Show file tree
Hide file tree
Showing 21 changed files with 563 additions and 422 deletions.
2 changes: 1 addition & 1 deletion datasets/e2e/pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Flower Datasets with PyTorch"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
flwr-datasets = { path = "./../../", extras = ["vision"] }
torch = "^1.12.0"
torchvision = "^0.14.1"
Expand Down
2 changes: 1 addition & 1 deletion datasets/e2e/scikit-learn/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Flower Datasets with scikit-learn"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
flwr-datasets = { path = "./../../", extras = ["vision"] }
scikit-learn = "^1.2.0"
parameterized = "==0.9.0"
2 changes: 1 addition & 1 deletion datasets/e2e/tensorflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Flower Datasets with TensorFlow"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.9,<3.11"
flwr-datasets = { path = "./../../", extras = ["vision"] }
tensorflow-cpu = "^2.9.1, !=2.11.1"
tensorflow-io-gcs-filesystem = "<0.35.0"
Expand Down
6 changes: 5 additions & 1 deletion datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ def _load_mocked_dataset_by_partial_download(
The dataset with the requested samples.
"""
dataset = datasets.load_dataset(
dataset_name, name=subset_name, split=split_name, streaming=True
dataset_name,
name=subset_name,
split=split_name,
streaming=True,
trust_remote_code=True,
)
dataset_list = []
# It's a list of dict such that each dict represent a single sample of the dataset
Expand Down
14 changes: 9 additions & 5 deletions datasets/flwr_datasets/visualization/bar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from matplotlib import colors as mcolors
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure


# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
Expand Down Expand Up @@ -82,10 +83,11 @@ def _plot_bar(
if "stacked" not in plot_kwargs:
plot_kwargs["stacked"] = True

axis = dataframe.plot(
axis_df: Axes = dataframe.plot(
ax=axis,
**plot_kwargs,
)
assert axis_df is not None, "axis is None after plotting using DataFrame.plot()"

if legend:
if legend_kwargs is None:
Expand All @@ -104,20 +106,22 @@ def _plot_bar(
shift = min(0.05 + max_len_label_str / 100, 0.15)
legend_kwargs["bbox_to_anchor"] = (1.0 + shift, 0.5)

handles, legend_labels = axis.get_legend_handles_labels()
_ = axis.figure.legend(
handles, legend_labels = axis_df.get_legend_handles_labels()
figure = axis_df.figure
assert isinstance(figure, Figure), "figure extraction from axes is not a Figure"
_ = figure.legend(
handles=handles[::-1], labels=legend_labels[::-1], **legend_kwargs
)

# Heuristic to make the partition id on xticks non-overlapping
if partition_id_axis == "x":
xticklabels = axis.get_xticklabels()
xticklabels = axis_df.get_xticklabels()
if len(xticklabels) > 20:
# Make every other xtick label not visible
for i, label in enumerate(xticklabels):
if i % 2 == 1:
label.set_visible(False)
return axis
return axis_df


def _initialize_figsize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


# pylint: disable=too-many-arguments,too-many-locals
# mypy: disable-error-code="call-overload"
def plot_comparison_label_distribution(
partitioner_list: list[Partitioner],
label_name: Union[str, list[str]],
Expand Down Expand Up @@ -153,7 +154,11 @@ def plot_comparison_label_distribution(
figsize = _initialize_comparison_figsize(figsize, num_partitioners)
axes_sharing = _initialize_axis_sharing(size_unit, plot_type, partition_id_axis)
fig, axes = plt.subplots(
1, num_partitioners, layout="constrained", figsize=figsize, **axes_sharing
nrows=1,
ncols=num_partitioners,
figsize=figsize,
layout="constrained",
**axes_sharing,
)

if titles is None:
Expand Down
6 changes: 4 additions & 2 deletions datasets/flwr_datasets/visualization/label_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,7 @@ def plot_label_distributions(
plot_kwargs,
legend_kwargs,
)
assert axis is not None
return axis.figure, axis, dataframe
assert axis is not None, "axis is None after plotting"
figure = axis.figure
assert isinstance(figure, Figure), "figure extraction from axes is not a Figure"
return figure, axis, dataframe
4 changes: 2 additions & 2 deletions datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ exclude = [
]

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
numpy = "^1.21.0"
datasets = ">=2.14.6 <2.20.0"
datasets = ">=2.14.6 <=3.1.0"
pillow = { version = ">=6.2.1", optional = true }
soundfile = { version = ">=0.12.1", optional = true }
librosa = { version = ">=0.10.0.post2", optional = true }
Expand Down
Loading

0 comments on commit 9c95537

Please sign in to comment.