Apply pre-commit
mrchtr committed Mar 22, 2024
1 parent cb4672e commit 1dfec32
Showing 8 changed files with 127 additions and 85 deletions.
5 changes: 1 addition & 4 deletions src/fondant/component/data_io.py
@@ -196,10 +196,7 @@ def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):

def _write_dataframe(self, dataframe: dd.DataFrame) -> None:
"""Create dataframe writing task."""
location = (
f"{self.manifest.base_path}/{self.manifest.pipeline_name}/"
f"{self.manifest.run_id}/{self.operation_spec.component_name}"
)
location = self.manifest.dataset_location

# Create directory the dataframe will be written to, since this is not handled by Pandas
# `to_parquet` method.
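(Illustrative sketch, not part of the commit: one way a dataset location taken from the manifest could be used to write a Dask dataframe, creating the target directory up front as the comment above describes. The helper name and paths below are hypothetical.)

import dask.dataframe as dd
import pandas as pd
from fsspec.core import url_to_fs


def write_to_location(dataframe: dd.DataFrame, location: str) -> None:
    # Make sure the target directory exists before writing, mirroring the
    # behaviour the diff above keeps around `to_parquet`.
    fs, path = url_to_fs(location)
    fs.makedirs(path, exist_ok=True)
    dataframe.to_parquet(location)


# Hypothetical usage with a small in-memory dataframe:
ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)
write_to_location(ddf, "/tmp/my-dataset/run-1/load_component/data")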
9 changes: 3 additions & 6 deletions src/fondant/component/executor.py
@@ -312,8 +312,6 @@ def execute(self, component_cls: t.Type[Component]) -> None:
component_cls: The class of the component to execute
"""
input_manifest = self._load_or_create_manifest()
base_path = input_manifest.base_path
pipeline_name = input_manifest.pipeline_name

if self.cache and self._is_previous_cached(input_manifest):
cache_reference_content = self._get_cache_reference_content()
@@ -342,8 +340,8 @@ def execute(self, component_cls: t.Type[Component]) -> None:
self.upload_manifest(output_manifest, save_path=self.output_manifest_path)

self._upload_cache_reference_content(
base_path=base_path,
pipeline_name=pipeline_name,
base_path="todo",
pipeline_name="todo",
)

def _upload_cache_reference_content(
@@ -408,8 +406,7 @@ def optional_fondant_arguments() -> t.List[str]:

def _load_or_create_manifest(self) -> Manifest:
return Manifest.create(
pipeline_name=self.metadata.pipeline_name,
base_path=self.metadata.base_path,
dataset_name=self.metadata.dataset_name,
run_id=self.metadata.run_id,
component_id=self.metadata.component_id,
cache_key=self.metadata.cache_key,
56 changes: 47 additions & 9 deletions src/fondant/core/manifest.py
@@ -25,17 +25,19 @@ class Metadata:
Args:
dataset_name: the name of the dataset
manifest_location: path to the manifest file itself
run_id: the run id of the pipeline
component_id: the name of the component
cache_key: the cache key of the component.
manifest_location: path to the manifest file itself
dataset_location: path to the stored parquet files
"""

dataset_name: t.Optional[str]
manifest_location: t.Optional[str]
run_id: str
component_id: t.Optional[str]
cache_key: t.Optional[str]
manifest_location: t.Optional[str]
dataset_location: t.Optional[str]

def to_dict(self):
return asdict(self)
@@ -97,6 +99,8 @@ def create(
run_id: str,
component_id: t.Optional[str] = None,
cache_key: t.Optional[str] = None,
manifest_location: t.Optional[str] = None,
dataset_location: t.Optional[str] = None,
) -> "Manifest":
"""Create an empty manifest.
@@ -105,12 +109,16 @@
run_id: The id of the current pipeline run
component_id: The id of the current component being executed
cache_key: The component cache key
manifest_location: location of the manifest.json file itself
dataset_location: location of the dataset parquet files
"""
metadata = Metadata(
dataset_name=dataset_name,
run_id=run_id,
component_id=component_id,
cache_key=cache_key,
manifest_location=manifest_location,
dataset_location=dataset_location,
)

specification = {
@@ -133,9 +141,14 @@ def to_file(self, path: t.Union[str, Path]) -> None:
with fs_open(path, "w", encoding="utf-8", auto_mkdir=True) as file_:
json.dump(self._specification, file_)

def get_location(self):
@property
def manifest_location(self):
return self._specification["metadata"]["manifest_location"]

@property
def dataset_location(self):
return self._specification["metadata"]["dataset_location"]

def copy(self) -> "Manifest":
"""Return a deep copy of itself."""
return self.__class__(copy.deepcopy(self._specification))
@@ -240,7 +253,7 @@ def evolve(  # noqa: PLR0912 (too many branches)
operation_spec: OperationSpec,
*,
run_id: str,
working_dir: str,
working_directory: t.Optional[str] = None,
) -> "Manifest":
"""Evolve the manifest based on the component spec. The resulting
manifest is the expected result if the current manifest is provided
@@ -250,6 +263,7 @@
operation_spec: the operation spec
run_id: the run id to include in the evolved manifest. If no run id is provided,
the run id from the original manifest is propagated.
working_directory: path of the working directory
"""
evolved_manifest = self.copy()

Expand All @@ -258,30 +272,54 @@ def evolve( # : PLR0912 (too many branches)
evolved_manifest.update_metadata(key="component_id", value=component_id)
evolved_manifest.update_metadata(key="run_id", value=run_id)

if working_directory:
evolved_manifest = self.evolve_manifest_index_and_field_locations(
component_id,
evolved_manifest,
operation_spec,
run_id,
working_directory,
)

return evolved_manifest

def evolve_manifest_index_and_field_locations( # noqa PLR0913
self,
component_id,
evolved_manifest,
operation_spec,
run_id,
working_dir,
):
# TODO: check when we should change the index?
# Update index location as this is always rewritten
evolved_manifest.add_or_update_field(
Field(name="index", location=f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}"),
Field(
name="index",
location=f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}",
),
)

# Remove all previous fields if the component changes the index
if operation_spec.previous_index:
for field_name in evolved_manifest.fields:
evolved_manifest.remove_field(field_name)

# Add or update all produced fields defined in the component spec
for name, field in operation_spec.produces_to_dataset.items():
# If field was not part of the input manifest, add field to output manifest.
# If field was part of the input manifest and got produced by the component, update
# the manifest field.
field.location = f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}"
field.location = (
f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}"
)
evolved_manifest.add_or_update_field(field, overwrite=True)

return evolved_manifest

def contains_data(self) -> bool:
"""Check if the manifest contains data. Checks if any dataset fields exists.
Is false in case the dataset manifest was initialised but no data added yet. In this case
the manifest only contains metadata like dataset name and run id."""
the manifest only contains metadata like dataset name and run id.
"""
return bool(self._specification["fields"])

def __repr__(self) -> str:
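(Illustrative sketch, not part of the commit: how the extended `Manifest.create` signature and the new `manifest_location` / `dataset_location` properties above might be used. Argument names follow this diff; the concrete directory, dataset name, and ids are hypothetical.)

from fondant.core.manifest import Manifest

working_dir = "/tmp/artifacts"  # hypothetical working directory
dataset_name = "my-dataset"
run_id = "run-2024-03-22"
component_id = "load_component"

manifest = Manifest.create(
    dataset_name=dataset_name,
    run_id=run_id,
    component_id=component_id,
    manifest_location=f"{working_dir}/{dataset_name}/{run_id}/{component_id}/manifest.json",
    dataset_location=f"{working_dir}/{dataset_name}/{run_id}/{component_id}/data",
)

print(manifest.manifest_location)  # .../load_component/manifest.json
print(manifest.dataset_location)   # .../load_component/data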
70 changes: 40 additions & 30 deletions src/fondant/dataset/compiler.py
@@ -23,7 +23,6 @@
VALID_VERTEX_ACCELERATOR_TYPES,
Dataset,
Image,
Workspace,
)

logger = logging.getLogger(__name__)
@@ -216,11 +215,22 @@ def _generate_spec(
previous_component_cache=component_cache_key,
)

# Generate default values for manifest and dataset location based on working_dir
manifest_location = (
f"{working_directory}/{dataset.name}/{run_id}/{component_id}"
f"/manifest.json"
)
dataset_location = (
f"{working_directory}/{dataset.name}/{run_id}/{component_id}/data"
)

metadata = Metadata(
dataset_name=dataset.name,
run_id=run_id,
component_id=component_id,
cache_key=component_cache_key,
manifest_location=manifest_location,
dataset_location=dataset_location,
)

logger.info(f"Compiling service for {component_id}")
@@ -266,7 +276,7 @@ def _generate_spec(
command.extend(
[
"--input_manifest_path",
f"{dataset.manifest.get_location()}",
f"{dataset.manifest.manifest_location()}",
],
)

@@ -488,21 +498,21 @@ def _resolve_imports(self):
def compile(
self,
dataset: Dataset,
workspace: Workspace,
working_directory: str,
output_path: str,
) -> None:
"""Compile a pipeline to Kubeflow pipeline spec and save it to a specified output path.
Args:
dataset: the dataset to compile
workspace: workspace to operate in
working_directory: path of the working directory
output_path: the path where to save the Kubeflow pipeline spec
"""
# TODO: add method call to retrieve workspace context, and make passing workspace optional

run_id = workspace.get_run_id()
dataset.validate(run_id=run_id, workspace=workspace)
logger.info(f"Compiling {workspace.name} to {output_path}")
run_id = dataset.manifest.run_id
dataset.validate(run_id=run_id)
logger.info(f"Compiling {dataset.name} to {output_path}")

def set_component_exec_args(
*,
@@ -528,7 +538,7 @@ def set_component_exec_args(

return component_op

@self.kfp.dsl.pipeline(name=workspace.name, description=workspace.description)
@self.kfp.dsl.pipeline(name=dataset.name, description=dataset.description)
def kfp_pipeline():
previous_component_task = None
component_cache_key = None
@@ -559,23 +569,21 @@ def kfp_pipeline():
previous_component_cache=component_cache_key,
)
metadata = Metadata(
pipeline_name=workspace.name,
run_id=run_id,
base_path=workspace.base_path,
component_id=component_name,
cache_key=component_cache_key,
)

output_manifest_path = (
f"{workspace.base_path}/{workspace.name}/"
f"{run_id}/{component_name}/manifest.json"
f"{working_directory}/{metadata.dataset_name}/{metadata.run_id}/"
f"{metadata.component_id}/manifest.json",
)
# Set the execution order of the component task to be after the previous
# component task.
if component["dependencies"]:
for dependency in component["dependencies"]:
input_manifest_path = (
f"{workspace.base_path}/{workspace.name}/"
f"{working_directory}/{metadata.dataset_name}/{metadata.run_id}"
f"{run_id}/{dependency}/manifest.json"
)
kubeflow_component_op = set_component_exec_args(
@@ -614,7 +622,7 @@ def kfp_pipeline():

previous_component_task = component_task

logger.info(f"Compiling {workspace.name} to {output_path}")
logger.info(f"Compiling {dataset.name} to {output_path}")

self.kfp.compiler.Compiler().compile(kfp_pipeline, output_path) # type: ignore
logger.info("Pipeline compiled successfully")
@@ -742,6 +750,7 @@ def _resolve_imports(self):
def _build_command(
self,
metadata: Metadata,
working_directory: str,
arguments: t.Dict[str, t.Any],
dependencies: t.List[str] = [],
) -> t.List[str]:
@@ -752,7 +761,7 @@ def _build_command(
command.extend(
[
"--output_manifest_path",
f"{metadata.base_path}/{metadata.pipeline_name}/{metadata.run_id}/"
f"{working_directory}/{metadata.dataset_name}/{metadata.run_id}/"
f"{metadata.component_id}/manifest.json",
],
)
@@ -771,7 +780,7 @@
command.extend(
[
"--input_manifest_path",
f"{metadata.base_path}/{metadata.pipeline_name}/{metadata.run_id}/"
f"{working_directory}/{metadata.dataset_name}/{metadata.run_id}/"
f"{dependency}/manifest.json",
],
)
@@ -843,7 +852,7 @@ def validate_base_path(self, base_path: str) -> None:
def compile(
self,
dataset: Dataset,
workspace: Workspace,
working_directory: str,
output_path: str,
*,
role_arn: t.Optional[str] = None,
@@ -853,20 +862,17 @@
Args:
dataset: the dataset to compile
workspace: workspace to operate in
working_directory: path of the working directory
output_path: the path where to save the sagemaker pipeline spec.
role_arn: the Amazon Resource Name role to use for the processing steps,
if none provided the `sagemaker.get_execution_role()` role will be used.
"""
# TODO: add method call to retrieve workspace context, and make passing workspace optional

self.ecr_client = self.boto3.client("ecr")
self.validate_base_path(workspace.base_path)
self.validate_base_path(working_directory)
self._check_ecr_pull_through_rule()

run_id = workspace.get_run_id()
path = workspace.base_path
dataset.validate(run_id=run_id, workspace=workspace)
run_id = dataset.manifest.run_id
dataset.validate(run_id=run_id)

component_cache_key = None

@@ -880,19 +886,23 @@
)

metadata = Metadata(
pipeline_name=workspace.name,
run_id=run_id,
base_path=path,
component_id=component_name,
cache_key=component_cache_key,
dataset_name=dataset.name,
manifest_location=f"{working_directory}/{dataset.name}/{run_id}/"
f"{component_name}/manifest.json",
dataset_location=f"{working_directory}/{dataset.name}/{run_id}/"
f"{component_name}/data",
)

logger.info(f"Compiling service for {component_name}")

command = self._build_command(
metadata,
component_op.arguments,
component["dependencies"],
metadata=metadata,
working_directory=working_directory,
arguments=component_op.arguments,
dependencies=component["dependencies"],
)
depends_on = [steps[-1]] if component["dependencies"] else []

@@ -932,7 +942,7 @@ def compile(
steps.append(step)

sagemaker_pipeline = self.sagemaker.workflow.pipeline.Pipeline(
name=workspace.name,
name=dataset.name,
steps=steps,
)
with open(output_path, "w") as outfile:
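(Illustrative sketch, not part of the commit: the working-directory-based path layout the compilers above converge on, written out as a small helper. The helper itself is hypothetical and not part of the Fondant API.)

def component_paths(
    working_directory: str,
    dataset_name: str,
    run_id: str,
    component_id: str,
) -> dict:
    # Each compiler in this diff derives its manifest and data locations from the
    # same prefix: <working_directory>/<dataset_name>/<run_id>/<component_id>.
    prefix = f"{working_directory}/{dataset_name}/{run_id}/{component_id}"
    return {
        "manifest_location": f"{prefix}/manifest.json",
        "dataset_location": f"{prefix}/data",
    }


# e.g. component_paths("s3://my-bucket/work", "my-dataset", "run-1", "embed_text")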
(Diffs for the remaining 4 changed files are not shown.)
