Skip to content

Commit

Permalink
Fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchtr committed Mar 26, 2024
1 parent 1dfec32 commit 70a18dc
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 100 deletions.
13 changes: 5 additions & 8 deletions src/fondant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from types import ModuleType

from fondant.core.schema import CloudCredentialsMount
from fondant.dataset import Dataset, Workspace
from fondant.dataset import Dataset

if t.TYPE_CHECKING:
from fondant.component import Component
Expand Down Expand Up @@ -610,12 +610,9 @@ def run_local(args):
# use working directory from cli command
# if args.working_directory exists

workspace = getattr(args, "workspace", None)
if workspace is None:
workspace = Workspace(
name="dummy_workspace",
base_path=".artifacts",
) # TODO: handle in #887 -> retrieve global workspace or init default one
working_directory = getattr(args, "working_directory", None)
if working_directory is None:
working_directory = "./.fondant"

if args.extra_volumes:
extra_volumes.extend(args.extra_volumes)
Expand All @@ -628,7 +625,7 @@ def run_local(args):
runner = DockerRunner()
runner.run(
dataset=dataset,
workspace=workspace,
working_directory=working_directory,
extra_volumes=extra_volumes,
build_args=args.build_arg,
auth_provider=args.auth_provider,
Expand Down
52 changes: 29 additions & 23 deletions src/fondant/core/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def retrieve_from_filesystem(uri: str) -> Resource:
def create(
cls,
*,
dataset_name: t.Optional[str] = None,
dataset_name: t.Optional[str] = "",
run_id: str,
component_id: t.Optional[str] = None,
cache_key: t.Optional[str] = None,
Expand Down Expand Up @@ -272,33 +272,36 @@ def evolve( # : PLR0912 (too many branches)
evolved_manifest.update_metadata(key="component_id", value=component_id)
evolved_manifest.update_metadata(key="run_id", value=run_id)

if working_directory:
evolved_manifest = self.evolve_manifest_index_and_field_locations(
component_id,
evolved_manifest,
operation_spec,
run_id,
working_directory,
)
evolved_manifest = self.evolve_manifest_index_and_field_locations(
component_id,
evolved_manifest,
operation_spec,
run_id,
working_directory,
)

return evolved_manifest

def evolve_manifest_index_and_field_locations( # noqa PLR0913
self,
component_id,
evolved_manifest,
operation_spec,
run_id,
working_dir,
component_id: str,
evolved_manifest: "Manifest",
operation_spec: OperationSpec,
run_id: str,
working_dir: t.Optional[str] = None,
):
# TODO: check when we should change the index?
"""Evolve the manifest index and field locations based on the component spec."""
# Update index location as this is always rewritten
evolved_manifest.add_or_update_field(
Field(
if working_dir:
field = Field.create(
name="index",
location=f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}",
),
)
working_dir=working_dir,
run_id=run_id,
component_id=component_id,
dataset_name=self.dataset_name,
)
evolved_manifest.add_or_update_field(field, overwrite=False)

# Remove all previous fields if the component changes the index
if operation_spec.previous_index:
for field_name in evolved_manifest.fields:
Expand All @@ -308,10 +311,13 @@ def evolve_manifest_index_and_field_locations( # noqa PLR0913
# If field was not part of the input manifest, add field to output manifest.
# If field was part of the input manifest and got produced by the component, update
# the manifest field.
field.location = (
f"{working_dir}/{self.dataset_name}/{run_id}/{component_id}"
evolved_field = field.update_location(
working_dir=working_dir,
run_id=run_id,
component_id=component_id,
dataset_name=self.dataset_name,
)
evolved_manifest.add_or_update_field(field, overwrite=True)
evolved_manifest.add_or_update_field(evolved_field, overwrite=True)

return evolved_manifest

Expand Down
29 changes: 28 additions & 1 deletion src/fondant/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
self,
name: str,
type: Type = Type("null"),
location: str = "",
location: t.Optional[str] = None,
) -> None:
self.name = name
self.type = type
Expand All @@ -274,6 +274,33 @@ def __repr__(self):
def __eq__(self, other):
return vars(self) == vars(other)

@classmethod
def create( # noqa: PLR0913
cls,
name: str,
run_id: str,
component_id: str,
dataset_name: str,
working_dir: t.Optional[str] = None,
):
"""Create a Field instance with the correct location based on the provided parameters."""
if working_dir:
location = f"{working_dir}/{dataset_name}/{run_id}/{component_id}"
return Field(name=name, location=location)
return Field(name=name)

def update_location(
self,
run_id: str,
component_id: str,
dataset_name: str,
working_dir: t.Optional[str] = None,
):
"""Update the location of the field based on the provided parameters."""
if working_dir:
self.location = f"{working_dir}/{dataset_name}/{run_id}/{component_id}"
return self


def validate_partition_size(arg_value):
if arg_value in ["disable", None, "None"]:
Expand Down
15 changes: 5 additions & 10 deletions src/fondant/core/schemas/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
"type": "object",
"properties": {
"dataset_name": {
"type": "string"
"type": ["string", "null"]
},
"manifest_location": {
"type": "string"
"type": ["string", "null"]
},
"run_id": {
"type": "string"
"type": ["string", "null"]
},
"component_id": {
"type": ["string", "null"]
Expand All @@ -29,31 +29,26 @@
"location": {
"type": "string"
}
},
"required": [
"location"
]
}
},
"fields": {
"$ref": "#/definitions/fields"
}
},
"required": [
"metadata",
"index",
"fields"
],
"definitions": {
"field": {
"type": "object",
"properties": {
"location": {
"type": "string",
"type": ["string", "null"],
"pattern": "/.*"
}
},
"required": [
"location",
"type"
]
},
Expand Down
4 changes: 2 additions & 2 deletions src/fondant/dataset/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def compile(
self,
dataset: Dataset,
*,
working_directory: t.Optional[str],
working_directory: t.Optional[str] = None,
output_path: str = "docker-compose.yml",
extra_volumes: t.Union[t.Optional[list], t.Optional[str]] = None,
build_args: t.Optional[t.List[str]] = None,
Expand Down Expand Up @@ -276,7 +276,7 @@ def _generate_spec(
command.extend(
[
"--input_manifest_path",
f"{dataset.manifest.manifest_location()}",
f"{dataset.manifest.manifest_location}",
],
)

Expand Down
13 changes: 6 additions & 7 deletions src/fondant/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,8 @@ class Dataset:
def __init__(
self,
manifest: Manifest,
name: t.Optional[str] = None,
description: t.Optional[str] = None,
):
if name is not None:
self.name = self._validate_dataset_name(name)

self.description = description
self._graph: t.OrderedDict[str, t.Any] = OrderedDict()
self.task_without_dependencies_added = False
Expand All @@ -482,7 +478,7 @@ def __init__(
def _validate_dataset_name(name: str) -> str:
pattern = r"^[a-z0-9][a-z0-9_-]*$"
if not re.match(pattern, name):
msg = f"The workspace name violates the pattern {pattern}"
msg = f"The dataset name violates the pattern {pattern}"
raise InvalidWorkspaceDefinition(msg)
return name

Expand All @@ -492,6 +488,11 @@ def get_run_id(name) -> str:
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
return f"{name}-{timestamp}"

@property
def name(self) -> str:
"""The name of the dataset."""
return self.manifest.dataset_name

def register_operation(
self,
operation: ComponentOp,
Expand Down Expand Up @@ -820,8 +821,6 @@ def apply(
Returns:
An intermediate dataset.
"""
# TODO: add method call to retrieve workspace context, and make passing workspace optional

operation = ComponentOp.from_ref(
ref,
fields=self.fields,
Expand Down
4 changes: 2 additions & 2 deletions tests/examples/example_modules/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@


def create_dataset_with_args(name):
return Dataset(name)
return Dataset.create("load_from_parquet", dataset_name=name)


def create_dataset():
return Dataset("test_dataset")
return Dataset.create("load_from_parquet", dataset_name="test_dataset")


def not_implemented():
Expand Down
Loading

0 comments on commit 70a18dc

Please sign in to comment.