Initialise dataset from manifest #911

10 changes: 2 additions & 8 deletions examples/sample_pipeline/pipeline.py
@@ -8,26 +8,20 @@
import pyarrow as pa

from fondant.component import PandasTransformComponent
from fondant.dataset import Workspace, lightweight_component, Dataset

BASE_PATH = Path("./.artifacts").resolve()

# Define pipeline
workspace = Workspace(name="dummy-pipeline", base_path=str(BASE_PATH))
from fondant.dataset import lightweight_component, Dataset

# Load from hub component
load_component_column_mapping = {
"text": "text_data",
}

dataset = Dataset.read(
dataset = Dataset.create(
"load_from_parquet",
arguments={
"dataset_uri": "/data/sample.parquet",
"column_name_mapping": load_component_column_mapping,
},
produces={"text_data": pa.string()},
workspace=workspace,
)

dataset = dataset.apply("./components/dummy_component")
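The sample pipeline above is the clearest illustration of the API change: the `Workspace` object disappears and `Dataset.create` replaces `Dataset.read`. Below is a minimal sketch of how the reworked example might be run end to end; it assumes the remaining `DockerRunner.run` keyword arguments (shown with extra options in the cli.py changes further down) are optional.

```python
# Minimal sketch of the post-#911 flow, based on the diffs in this PR.
# Assumption: DockerRunner.run's other keyword arguments (extra_volumes,
# build_args, auth_provider) have defaults and can be omitted.
import pyarrow as pa

from fondant.dataset import Dataset
from fondant.dataset.runner import DockerRunner

# Dataset.create replaces Dataset.read(..., workspace=...); no Workspace needed.
dataset = Dataset.create(
    "load_from_parquet",
    arguments={
        "dataset_uri": "/data/sample.parquet",
        "column_name_mapping": {"text": "text_data"},
    },
    produces={"text_data": pa.string()},
)
dataset = dataset.apply("./components/dummy_component")

# The runner now receives a plain working directory instead of a Workspace.
DockerRunner().run(
    dataset=dataset,
    working_directory="./.artifacts/dataset",
)
```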
41 changes: 29 additions & 12 deletions src/fondant/cli.py
@@ -29,7 +29,7 @@
from types import ModuleType

from fondant.core.schema import CloudCredentialsMount
from fondant.dataset import Dataset, Workspace
from fondant.dataset import Dataset

if t.TYPE_CHECKING:
from fondant.component import Component
@@ -607,15 +607,10 @@ def run_local(args):
from fondant.dataset.runner import DockerRunner

extra_volumes = []
# use workspace from cli command
# if args.workspace exists

workspace = getattr(args, "workspace", None)
if workspace is None:
workspace = Workspace(
name="dummy_workspace",
base_path=".artifacts",
) # TODO: handle in #887 -> retrieve global workspace or init default one
working_directory = getattr(args, "working_directory", None)
if working_directory is None:
working_directory = "./.artifacts/dataset"

if args.extra_volumes:
extra_volumes.extend(args.extra_volumes)
@@ -628,7 +623,7 @@ def run_local(args):
runner = DockerRunner()
runner.run(
dataset=dataset,
workspace=workspace,
working_directory=working_directory,
extra_volumes=extra_volumes,
build_args=args.build_arg,
auth_provider=args.auth_provider,
@@ -647,7 +642,12 @@ def run_kfp(args):
ref = args.ref

runner = KubeflowRunner(host=args.host)
runner.run(dataset=ref)

working_directory = getattr(args, "working_directory", None)
if working_directory is None:
working_directory = "./.artifacts/dataset"

runner.run(dataset=ref, working_directory=working_directory)


def run_vertex(args):
@@ -664,7 +664,12 @@ def run_vertex(args):
service_account=args.service_account,
network=args.network,
)
runner.run(input=ref)

working_directory = getattr(args, "working_directory", None)
if working_directory is None:
working_directory = "./.artifacts/dataset"

runner.run(input=ref, working_directory=working_directory)


def run_sagemaker(args):
@@ -676,10 +681,16 @@ def run_sagemaker(args):
ref = args.ref

runner = SagemakerRunner()

working_directory = getattr(args, "working_directory", None)
if working_directory is None:
working_directory = "./.artifacts/dataset"

runner.run(
input=ref,
pipeline_name=args.pipeline_name,
role_arn=args.role_arn,
working_directory=working_directory,
)


@@ -709,6 +720,12 @@ def register_execute(parent_parser):
action="store",
)

parser.add_argument(
"working_directory",
help="""Working directory where the dataset artifacts will be stored""",
action="store",
)

parser.set_defaults(func=execute)


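Each runner entry point in cli.py now repeats the same three-line fallback for the working directory. A small helper like the one below would capture that pattern; this is a refactoring sketch only, and the helper name is hypothetical rather than part of this PR.

```python
# Hypothetical helper consolidating the working-directory fallback repeated in
# run_local, run_kfp, run_vertex and run_sagemaker above.
import argparse


def resolve_working_directory(args: argparse.Namespace) -> str:
    """Return the working directory from the CLI args, or the default used in this PR."""
    working_directory = getattr(args, "working_directory", None)
    if working_directory is None:
        working_directory = "./.artifacts/dataset"
    return working_directory
```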
19 changes: 13 additions & 6 deletions src/fondant/component/data_io.py
@@ -196,16 +196,23 @@ def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):

def _write_dataframe(self, dataframe: dd.DataFrame) -> None:
"""Create dataframe writing task."""
location = (
f"{self.manifest.base_path}/{self.manifest.pipeline_name}/"
f"{self.manifest.run_id}/{self.operation_spec.component_name}"
location = self.manifest.get_dataset_columns_locations(
columns=dataframe.columns,
)

if len(set(location)) > 1:
msg = "Writing to multiple locations is currently not supported."
raise ValueError(
msg,
)

output_location_path = location[0]

# Create directory the dataframe will be written to, since this is not handled by Pandas
# `to_parquet` method.
protocol = fsspec.utils.get_protocol(location)
protocol = fsspec.utils.get_protocol(output_location_path)
fs = fsspec.get_filesystem_class(protocol)
fs().makedirs(location)
fs().makedirs(output_location_path)

schema = {
field.name: field.type.value
@@ -226,7 +233,7 @@ def _write_dataframe(self, dataframe: dd.DataFrame) -> None:
# https://dask.discourse.group/t/improving-pipeline-resilience-when-using-to-parquet-and-preemptible-workers/2141
to_parquet_tasks = [
d.to_parquet(
os.path.join(location, f"part.{i}.parquet"),
os.path.join(output_location_path, f"part.{i}.parquet"),
schema=pa.schema(list(schema.items())),
index=True,
)
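In data_io.py the write location is no longer derived from `base_path` and `pipeline_name`; instead the manifest is asked where the written columns live, and the write is rejected if they resolve to more than one place. The sketch below isolates that guard, assuming `get_dataset_columns_locations` returns one location string per column, which is what the diff implies.

```python
# Sketch of the single-location guard added to _write_dataframe. Assumption:
# manifest.get_dataset_columns_locations(columns=...) yields one location
# string per column.
import typing as t

import fsspec


def prepare_output_location(locations: t.List[str]) -> str:
    if len(set(locations)) > 1:
        msg = "Writing to multiple locations is currently not supported."
        raise ValueError(msg)

    output_location_path = locations[0]

    # Create the target directory up front, since Dask's `to_parquet` does not.
    protocol = fsspec.utils.get_protocol(output_location_path)
    fs = fsspec.get_filesystem_class(protocol)
    fs().makedirs(output_location_path)

    return output_location_path
```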
43 changes: 30 additions & 13 deletions src/fondant/component/executor.py
@@ -7,6 +7,7 @@
import argparse
import json
import logging
import os
import typing as t
from abc import abstractmethod
from distutils.util import strtobool
@@ -61,6 +62,7 @@ def __init__(
user_arguments: t.Dict[str, t.Any],
input_partition_rows: int,
previous_index: t.Optional[str] = None,
working_directory: str,
) -> None:
self.operation_spec = operation_spec
self.cache = cache
@@ -70,6 +72,7 @@ def __init__(
self.user_arguments = user_arguments
self.input_partition_rows = input_partition_rows
self.previous_index = previous_index
self.working_directory = working_directory

@classmethod
def from_args(cls) -> "Executor":
@@ -78,6 +81,12 @@ def from_args(cls) -> "Executor":
parser.add_argument("--operation_spec", type=json.loads)
parser.add_argument("--cache", type=lambda x: bool(strtobool(x)))
parser.add_argument("--input_partition_rows", type=int)
parser.add_argument(
"--working_directory",
type=str,
default="./artifacts/dataset",
)

args, _ = parser.parse_known_args()

if "operation_spec" not in args:
@@ -86,10 +95,17 @@ def from_args(cls) -> "Executor":

operation_spec = OperationSpec.from_dict(args.operation_spec)

working_directory = (
args.working_directory
if hasattr(args, "working_directory")
else os.path.normpath("./artifacts/dataset")
)

return cls.from_spec(
operation_spec,
cache=args.cache,
input_partition_rows=args.input_partition_rows,
working_directory=working_directory,
)

@classmethod
@@ -99,6 +115,7 @@ def from_spec(
*,
cache: bool,
input_partition_rows: int,
working_directory: str,
) -> "Executor":
"""Create an executor from a component spec."""
args_dict = vars(cls._add_and_parse_args(operation_spec))
@@ -126,6 +143,7 @@ def from_spec(
user_arguments=args_dict,
input_partition_rows=input_partition_rows,
previous_index=operation_spec.previous_index,
working_directory=working_directory,
)

@classmethod
@@ -227,7 +245,7 @@ def _get_cache_reference_content(self) -> t.Union[str, None]:
The content of the cache reference file.
"""
manifest_reference_path = (
f"{self.metadata.base_path}/{self.metadata.pipeline_name}/cache/"
f"{self.working_directory}/{self.metadata.dataset_name}/cache/"
f"{self.metadata.cache_key}.txt"
)

@@ -297,6 +315,7 @@ def _run_execution(
output_manifest = input_manifest.evolve(
operation_spec=self.operation_spec,
run_id=self.metadata.run_id,
working_directory=self.working_directory,
)
self._write_data(dataframe=output_df, manifest=output_manifest)

@@ -310,12 +329,12 @@ def execute(self, component_cls: t.Type[Component]) -> None:

Args:
component_cls: The class of the component to execute
working_directory: The working directory where the dataset artifacts will be stored
"""
input_manifest = self._load_or_create_manifest()
base_path = input_manifest.base_path
pipeline_name = input_manifest.pipeline_name

if self.cache and self._is_previous_cached(input_manifest):
logger.info("Caching is currently temporarily disabled.")
cache_reference_content = self._get_cache_reference_content()

if cache_reference_content is not None:
@@ -333,7 +352,6 @@ def execute(self, component_cls: t.Type[Component]) -> None:
output_manifest = None
else:
output_manifest = self._run_execution(component_cls, input_manifest)

else:
logger.info("Caching disabled for the component")
output_manifest = self._run_execution(component_cls, input_manifest)
@@ -342,14 +360,14 @@ def execute(self, component_cls: t.Type[Component]) -> None:
self.upload_manifest(output_manifest, save_path=self.output_manifest_path)

self._upload_cache_reference_content(
base_path=base_path,
pipeline_name=pipeline_name,
working_directory=self.working_directory,
dataset_name=self.metadata.dataset_name,
)

def _upload_cache_reference_content(
self,
base_path: str,
pipeline_name: str,
working_directory: str,
dataset_name: str,
):
"""
Write the cache key containing the reference to the location of the written manifest.
@@ -359,11 +377,11 @@ def _upload_cache_reference_content(
cached component executions.

Args:
base_path: The base path of the pipeline.
pipeline_name: The name of the pipeline.
working_directory: Working directory where the dataset artifacts are stored.
dataset_name: The name of the dataset.
"""
cache_reference_path = (
f"{base_path}/{pipeline_name}/cache/{self.metadata.cache_key}.txt"
f"{working_directory}/{dataset_name}/cache/{self.metadata.cache_key}.txt"
)

logger.info(
@@ -408,8 +426,7 @@ def optional_fondant_arguments() -> t.List[str]:

def _load_or_create_manifest(self) -> Manifest:
return Manifest.create(
pipeline_name=self.metadata.pipeline_name,
base_path=self.metadata.base_path,
dataset_name=self.metadata.dataset_name,
run_id=self.metadata.run_id,
component_id=self.metadata.component_id,
cache_key=self.metadata.cache_key,
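The executor changes above thread a `working_directory` from the CLI (via the new `--working_directory` argument) through to the cache layout, so cache references move from `{base_path}/{pipeline_name}/cache/` to a path rooted in the working directory. A small sketch of the new layout, using the `Metadata` fields that appear in the diff (`dataset_name`, `cache_key`); the surrounding upload and download logic is omitted.

```python
# Sketch of the cache-reference layout introduced in executor.py; only the
# path construction is shown.
def cache_reference_path(working_directory: str, dataset_name: str, cache_key: str) -> str:
    """Location of the text file that records where a cached manifest was written."""
    return f"{working_directory}/{dataset_name}/cache/{cache_key}.txt"


# Example (illustrative values):
# cache_reference_path("./.artifacts/dataset", "sample-dataset", "abc123")
# -> "./.artifacts/dataset/sample-dataset/cache/abc123.txt"
```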