Don't validate returned Pandas dataframe strictly (#226)
This PR drops columns from the pandas dataframe returned by the user that are not defined in the component spec. Previously, returning additional columns would raise an error (see #223 (comment)).
RobbeSneyders authored Jun 23, 2023
1 parent a711541 commit d072434
Showing 3 changed files with 11 additions and 11 deletions.
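
Before the file diffs, a minimal sketch of the new behavior: columns the user returns that are not declared in the component spec are dropped rather than rejected. The spec_produces mapping, the drop_undeclared_columns helper, and the (subset, field) column layout below are assumptions for illustration only, not the actual fondant API.

import pandas as pd

# Hypothetical "produces" section of a component spec: subset -> declared fields.
# (spec_produces and drop_undeclared_columns are illustrative names, not fondant API.)
spec_produces = {"images": {"data", "width", "height"}}


def drop_undeclared_columns(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Keep only the (subset, field) columns that the spec declares."""
    declared = [
        (subset, field)
        for (subset, field) in dataframe.columns
        if subset in spec_produces and field in spec_produces[subset]
    ]
    return dataframe[declared]


# A user component may now return extra columns; they are dropped instead of raising.
df = pd.DataFrame(
    {
        ("images", "data"): [b"\x00"],
        ("images", "width"): [640],
        ("images", "height"): [480],
        ("debug", "note"): ["extra column not declared in the spec"],
    }
)
print(drop_undeclared_columns(df).columns.tolist())
# [('images', 'data'), ('images', 'width'), ('images', 'height')]
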
components/download_images/src/main.py (1 addition, 1 deletion)
@@ -126,7 +126,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
result_type="expand",
)

-return dataframe[[("images", "data"), ("images", "width"), ("images", "height")]]
+return dataframe


if __name__ == "__main__":
@@ -68,11 +68,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
)
mask = mask.to_numpy()

-dataframe = dataframe[mask]
-
-dataframe = dataframe.drop(("text", "data"), axis=1)
-
-return dataframe
+return dataframe[mask]


if __name__ == "__main__":
fondant/component.py (9 additions, 5 deletions)
@@ -173,9 +173,6 @@ def optional_fondant_arguments() -> t.List[str]:
return ["input_manifest_path"]

def _load_or_create_manifest(self) -> Manifest:
-# create initial manifest
-# TODO ideally get rid of args.metadata by including them in the storage args
-
component_id = self.spec.name.lower().replace(" ", "_")
manifest = Manifest.create(
base_path=self.metadata["base_path"],
@@ -277,6 +274,15 @@ def wrapped_transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
tuple(column.split("_")) for column in dataframe.columns
)
dataframe = self.transform(dataframe)
+# Drop columns not in the produces section of the component spec
+dataframe.drop(
+    columns=[
+        (subset, field)
+        for (subset, field) in dataframe.columns
+        if subset not in self.spec.produces
+        or field not in self.spec.produces[subset].fields
+    ]
+)
dataframe.columns = [
"_".join(column) for column in dataframe.columns.to_flat_index()
]
@@ -300,9 +306,7 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame:
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
-print(field.type.value)
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
-# dtype=f"{field.type.value}[pyarrow]"
dtype=pd.ArrowDtype(field.type.value)
)
meta_df = pd.DataFrame(meta_dict).set_index("id")
