Skip to content

Commit

Permalink
Bugfix partitions Load from Hub (#380)
Browse files Browse the repository at this point in the history
A fix for loading the dataset when `n_rows_to_load` is specified. An
offset of 1 was missing which caused an issue when there was only 1
partition (npartitions from the enumerate started from 0 and an attempt
to call `df.head()` with 0 partitions was made)
  • Loading branch information
PhilippeMoussalli authored Aug 23, 2023
1 parent 828c8ab commit 50f3a97
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
3 changes: 2 additions & 1 deletion components/load_from_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def load(self) -> dd.DataFrame:
# 4) Optional: only return specific amount of rows
if self.n_rows_to_load is not None:
partitions_length = 0
for npartitions, partition in enumerate(dask_df.partitions):
npartitions = 1
for npartitions, partition in enumerate(dask_df.partitions, start=1):
if partitions_length >= self.n_rows_to_load:
logger.info(f"""Required number of partitions to load\n
{self.n_rows_to_load} is {npartitions}""")
Expand Down
25 changes: 15 additions & 10 deletions examples/pipelines/finetune_stable_diffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

logger = logging.getLogger(__name__)
# General configs
pipeline_name = "Test fondant pipeline"
pipeline_description = "A test pipeline"
pipeline_name = "stable_diffusion_pipeline"
pipeline_description = (
"Pipeline to prepare and collect data for finetuning stable diffusion"
)

load_component_column_mapping = {"image": "images_data", "text": "captions_data"}

Expand All @@ -25,7 +27,7 @@
"dataset_name": "logo-wizard/modern-logo-dataset",
"column_name_mapping": load_component_column_mapping,
"image_column_names": ["image"],
"nb_rows_to_load": None,
"n_rows_to_load": None,
},
)

Expand Down Expand Up @@ -81,12 +83,15 @@
number_of_gpus=1,
)

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)
pipeline = Pipeline(
pipeline_name=pipeline_name,
base_path="/home/philippe/Scripts/express/local_artifact/new",
)

pipeline.add_op(load_from_hub_op)
pipeline.add_op(image_resolution_extraction_op, dependencies=load_from_hub_op)
pipeline.add_op(image_embedding_op, dependencies=image_resolution_extraction_op)
pipeline.add_op(laion_retrieval_op, dependencies=image_embedding_op)
pipeline.add_op(download_images_op, dependencies=laion_retrieval_op)
pipeline.add_op(caption_images_op, dependencies=download_images_op)
pipeline.add_op(write_to_hub, dependencies=caption_images_op)
# pipeline.add_op(image_resolution_extraction_op, dependencies=load_from_hub_op)
# pipeline.add_op(image_embedding_op, dependencies=image_resolution_extraction_op)
# pipeline.add_op(laion_retrieval_op, dependencies=image_embedding_op)
# pipeline.add_op(download_images_op, dependencies=laion_retrieval_op)
# pipeline.add_op(caption_images_op, dependencies=download_images_op)
# pipeline.add_op(write_to_hub, dependencies=caption_images_op)

0 comments on commit 50f3a97

Please sign in to comment.