Skip to content

Commit

Permalink
tests: add tests for datasets (#10321)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Feb 26, 2024
1 parent f38490c commit 6cb37ac
Show file tree
Hide file tree
Showing 7 changed files with 651 additions and 43 deletions.
20 changes: 3 additions & 17 deletions dvc/commands/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,6 @@ def display(cls, name: str, dataset: "Dataset", action: str = "Adding"):
ui.write(action, ui.rich_text(name, "cyan"), text, styled=True)

def run(self):
from urllib.parse import urlsplit

d = vars(self.args)
url_obj = urlsplit(self.args.url)
if url_obj.scheme == "dvcx":
d["type"] = "dvcx"
elif url_obj.scheme.startswith("dvc"):
d["type"] = "dvc"
protos = tuple(url_obj.scheme.split("+"))
if not protos or protos == ("dvc",) or protos == ("dvc", "ssh"):
d["url"] = url_obj.netloc + url_obj.path
else:
d["url"] = url_obj._replace(scheme=protos[1]).geturl()
else:
d["type"] = "url"

existing = self.repo.datasets.get(self.args.name)
with self.repo.scm_context:
if not self.args.force and existing:
Expand All @@ -73,8 +57,9 @@ def run(self):
f"{self.args.name} already exists in {path}, "
"use the --force to overwrite"
)
dataset = self.repo.datasets.add(**d)
dataset = self.repo.datasets.add(**vars(self.args))
self.display(self.args.name, dataset)
return 0


class CmdDatasetUpdate(CmdBase):
Expand Down Expand Up @@ -133,6 +118,7 @@ def run(self):
)
return 1
self.display(self.args.name, dataset, new)
return 0


def add_parser(subparsers, parent_parser):
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def close(self):

def _reset(self):
self.scm._reset()
self.datasets._reset()
self.state.close()
if "dvcfs" in self.__dict__:
self.dvcfs.close()
Expand Down
82 changes: 61 additions & 21 deletions dvc/repo/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections.abc import Iterator, Mapping
from datetime import datetime
from functools import cached_property
Expand All @@ -11,10 +12,12 @@
from dvc.dvcfile import Lockfile, ProjectFile
from dvc.exceptions import DvcException
from dvc.log import logger
from dvc.types import StrPath
from dvc_data.hashfile.meta import Meta

if TYPE_CHECKING:
from dql.dataset import DatasetRecord # type: ignore[import]
from dvcx.dataset import DatasetVersion # type: ignore[import]
from typing_extensions import Self

from dvc.repo import Repo
Expand All @@ -23,6 +26,50 @@
logger = logger.getChild(__name__)


def parse_url_and_type(url: str):
    """Classify a dataset URL and normalize it for the matching dataset type.

    Resolution order:

    * an existing local path           -> type ``dvc``, URL unchanged
    * ``dvcx://...``                   -> type ``dvcx``, URL unchanged
    * any other non-``dvc*`` scheme    -> type ``url``, URL unchanged
    * ``dvc://...`` or no scheme       -> type ``dvc``, netloc + path only
    * ``dvc+<proto>://...``            -> type ``dvc``, scheme replaced by
      ``<proto>``

    Args:
        url: The URL (or local path) given for the dataset.

    Returns:
        A dict with the keys ``"type"`` and ``"url"``.
    """
    from urllib.parse import urlsplit

    if os.path.exists(url):
        return {"type": "dvc", "url": url}

    url_obj = urlsplit(url)
    scheme = url_obj.scheme
    if scheme == "dvcx":
        return {"type": "dvcx", "url": url}
    if scheme and not scheme.startswith("dvc"):
        return {"type": "url", "url": url}

    # ``str.split`` always returns at least one element (``[""]`` for an
    # empty scheme), so checking for an empty result never matches. Test
    # the length instead: without a ``+<proto>`` part there is no second
    # element to index, and the URL is reduced to netloc + path.
    protos = scheme.split("+")
    if len(protos) < 2:
        url = url_obj.netloc + url_obj.path
    else:
        url = url_obj._replace(scheme=protos[1]).geturl()
    return {"type": "dvc", "url": url}


def _get_dataset_record(name: str) -> "DatasetRecord":
    """Look up the remote dataset record called ``name`` in the dvcx catalog.

    Raises:
        DvcException: if ``dvcx.catalog`` cannot be imported.
    """
    from dvc.exceptions import DvcException

    try:
        from dvcx.catalog import get_catalog  # type: ignore[import]
    except ImportError as exc:
        raise DvcException("dvcx is not installed") from exc

    return get_catalog().get_remote_dataset(name)


def _get_dataset_info(
    name: str, record: Optional["DatasetRecord"] = None, version: Optional[int] = None
) -> "DatasetVersion":
    """Return the version entry of a dataset record.

    When ``record`` is not supplied it is fetched by ``name`` via
    ``_get_dataset_record``. When ``version`` is not supplied (or is
    falsy), the record's ``latest_version`` is used.
    """
    if not record:
        record = _get_dataset_record(name)
    assert record

    resolved = version if version else record.latest_version
    # Versions are expected to be positive integers.
    assert resolved and resolved >= 1
    return record.get_version(resolved)


def default_str(v) -> str:
    """Pass ``v`` through the ``default_if_none("")`` converter."""
    converter = default_if_none("")
    return converter(v)

Expand Down Expand Up @@ -160,22 +207,9 @@ def update(
version: Optional[int] = None,
**kwargs,
) -> "Self":
if not record:
try:
from dvcx.catalog import get_catalog # type: ignore[import]

except ImportError as exc:
raise DvcException("dvcx is not installed") from exc

name, _version = self.name_version
version = _version or version
catalog = get_catalog()
record = catalog.get_remote_dataset(name)

assert record is not None
ver = version or record.latest_version
assert ver
version_info = record.get_version(ver)
name, _version = self.name_version
version = version or _version
version_info = _get_dataset_info(name, record=record, version=version)
lock = DVCXDatasetLock(
**self.spec.to_dict(),
version=version_info.version,
Expand Down Expand Up @@ -264,6 +298,8 @@ def _lock(self) -> dict[str, Optional[dict[str, Any]]]:
datasets_lock = self.repo.index._datasets_lock

def find(path, name) -> Optional[dict[str, Any]]:
# only look for `name` in the lock file next to the
# corresponding `dvc.yaml` file
lock = datasets_lock.get(path, [])
return next((dataset for dataset in lock if dataset["name"] == name), None)

Expand Down Expand Up @@ -348,10 +384,14 @@ def _build_dataset(
raise ValueError(f"unknown dataset type: {spec.type!r}")

def add(
self, url: str, name: str, manifest_path: str = "dvc.yaml", **kwargs: Any
self,
url: str,
name: str,
manifest_path: StrPath = "dvc.yaml",
**kwargs: Any,
) -> Dataset:
kwargs.update({"url": url, "name": name})
dataset = self._build_dataset(manifest_path, kwargs)
spec = kwargs | parse_url_and_type(url) | {"name": name}
dataset = self._build_dataset(os.path.abspath(manifest_path), spec)
dataset = dataset.update(self.repo)

self.dump(dataset)
Expand All @@ -366,13 +406,13 @@ def update(self, name, **kwargs) -> tuple[Dataset, Dataset]:
self[name] = new
return dataset, new

def _dump_spec(self, manifest_path: str, spec: Spec) -> None:
def _dump_spec(self, manifest_path: StrPath, spec: Spec) -> None:
spec_data = spec.to_dict()
assert spec_data.keys() & {"type", "name", "url"}
project_file = ProjectFile(self.repo, manifest_path)
project_file.dump_dataset(spec_data)

def _dump_lock(self, manifest_path: str, lock: Lock) -> None:
def _dump_lock(self, manifest_path: StrPath, lock: Lock) -> None:
lock_data = lock.to_dict()
assert lock_data.keys() & {"type", "name", "url"}
lockfile = Lockfile(self.repo, Path(manifest_path).with_suffix(".lock"))
Expand Down
5 changes: 0 additions & 5 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,6 @@ def outs(self) -> Iterator["Output"]:
for stage in self.stages:
yield from stage.outs

@cached_property
def datasets(self) -> dict[str, dict[str, Any]]:
datasets = chain.from_iterable(self._datasets.values())
return {dataset["name"]: dataset for dataset in datasets}

@cached_property
def out_data_keys(self) -> dict[str, set["DataIndexKey"]]:
by_workspace: dict[str, set["DataIndexKey"]] = defaultdict(set)
Expand Down
Loading

0 comments on commit 6cb37ac

Please sign in to comment.