add support for tracking remote dataset (#10287)

* add support for tracking remote dataset

This adds support for virtually tracking datasets, such as dvcx datasets,
datasets from remote dvc/git registries, and cloud-versioned remotes.

This PR introduces two commands under the `ds` namespace:
`add` and `update`.

The `dvc ds add` command adds the dataset to `dvc.yaml` and tracks
the sources in the `dvc.lock` file. Similarly, `dvc ds update`
updates the sources in the `dvc.lock` file.

Example Usage:
```console
# tracking dvcx dataset
dvc ds add --name dogs --url dvcx://dogs
# tracking dvc/git registries
dvc ds add --name example --url dvc://[email protected]/iterative/example.git
# cloud-versioned remotes
dvc ds add --name versioning --url s3://cloud-versioning-demo
```
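
For reference, the confirmation printed by `add` is assembled in
`CmdDatasetAdd.display` (see the diff below); for a dvcx dataset it should
look roughly like this (the version number is illustrative):

```console
$ dvc ds add --name dogs --url dvcx://dogs
Adding dogs (dvcx://dogs @ v3)
```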

To update a dataset, specify its name:
```console
dvc ds update <name>
```
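
Based on `CmdDatasetUpdate.display` below, the output reports the old and
new revision (illustrative; assumes the dvcx dataset moved from v3 to v4):

```console
$ dvc ds update dogs
Updating dogs (v3 -> v4)
```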

`dvc ds add` freezes the dataset (in the same sense as `dvc import`/`dvc import-url`).
It keeps a "specification" of the dataset in `dvc.yaml` and also freezes
information about the dataset in the `dvc.lock` file. Both are kept
inside the `datasets` section of the respective files.

This metadata is used in pipelines. You can add a dependency to your
stage using the `ds://` scheme, followed by the name of the dataset in the
`dvc.yaml` file. Since it is used in pipelines, the dataset's `name`
has to be unique in the repository; different datasets with the same name
are not allowed. On `dvc repro`, `dvc` copies the frozen information
about the particular dataset into the `deps` field for the stage in
`dvc.lock`. On successive invocations, `dvc` compares the information
from the `deps` field of the lock with the frozen information in the
`datasets` section and decides whether to rerun the stage.

As the dataset is frozen, `dvc repro` won't rerun the stage until the
dataset is updated via `dvc ds update`.
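
A minimal sketch of the resulting workflow (commands only; behavior per
the description above):

```console
dvc repro            # runs the stage and freezes the dataset info into deps
dvc repro            # no-op: deps still match the frozen `datasets` entry
dvc ds update dogs   # refreshes the lock information for the dataset
dvc repro            # reruns: deps no longer match the updated entry
```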

Here is an example of what `dvc.yaml` looks like:
```yaml
datasets:
- name: dogs
  url: dvcx://dogs
  type: dvcx
- name: example-get-started
  url: [email protected]:iterative/example-get-started.git
  type: dvc
  path: path
- name: dogs2
  url: dvcx://dogs@v2
  type: dvcx
- name: cloud-versioning-demo
  url: s3://cloud-versioning-demo
  type: url
```

And the corresponding `dvc.lock`:
```yaml
schema: '2.0'
datasets:
- name: dogs
  url: dvcx://dogs
  type: dvcx
  version: 3
  created_at: '2023-12-11T10:32:05.942708+00:00'
- name: example-get-started
  url: [email protected]:iterative/example-get-started.git
  type: dvc
  path: path
  rev_lock: df75c16ef61f0772d6e4bb27ba4617b06b4b5398
- name: cloud-versioning-demo
  url: s3://cloud-versioning-demo
  type: url
  meta:
    isdir: true
    size: 323919
    nfiles: 33
  files:
  - relpath: myproject/model.pt
    meta:
      size: 106433
      version_id: 5qrtnhnQ4fBzV73kqqK6pMGhTOzd_IPr
      etag: 3bc0028677ce6fb65bec8090c248b002
# truncated
```

Pipeline stages keep this information in the `dataset` section of each individual dep.

```console
dvc stage add -n train -d ds://dogs python train.py
cat dvc.yaml
```

```yaml
# truncated
stages:
  train:
    cmd: python train.py
    deps:
    - ds://dogs
```

When the stage is reproduced, `dvc.lock` will look something like the following:
```yaml
# truncated
stages:
  train:
    cmd: python train.py
    deps:
    - path: ds://dogs
      dataset:
        name: dogs
        url: dvcx://dogs
        type: dvcx
        version: 3
        created_at: '2023-12-11T10:32:05.942708+00:00'
```
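
Conceptually, the rerun check then reduces to a dict comparison; a rough
sketch (not the actual implementation — see `dvc/dependency/dataset.py`
in the diff below):

```python
def dataset_changed(dep_dataset: dict, registered: dict) -> bool:
    """Decide whether a stage needs to rerun for its dataset dep.

    dep_dataset: the `dataset` entry frozen under the stage's deps in dvc.lock.
    registered:  the current entry in the top-level `datasets` section.
    """
    return dep_dataset != registered
```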

* remodeling

* remove non-cloud-versioned remotes

* improve deserializing; handle invalidation

skshetry authored Feb 26, 2024
1 parent 89537a7 commit 34d80fb
Showing 12 changed files with 667 additions and 11 deletions.
2 changes: 2 additions & 0 deletions dvc/cli/parser.py

```diff
@@ -17,6 +17,7 @@
     dag,
     data,
     data_sync,
+    dataset,
     destroy,
     diff,
     du,
@@ -67,6 +68,7 @@
     dag,
     data,
     data_sync,
+    dataset,
     destroy,
     diff,
     du,
```
205 changes: 205 additions & 0 deletions dvc/commands/dataset.py

@@ -0,0 +1,205 @@

```python
from typing import TYPE_CHECKING, Optional

from dvc.cli import formatter
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import DvcException
from dvc.log import logger

if TYPE_CHECKING:
    from rich.text import Text

    from dvc.repo.datasets import Dataset, FileInfo

logger = logger.getChild(__name__)


def diff_files(old: list["FileInfo"], new: list["FileInfo"]) -> dict[str, list[str]]:
    old_files = {d.relpath: d for d in old}
    new_files = {d.relpath: d for d in new}
    rest = old_files.keys() & new_files.keys()
    return {
        "added": list(new_files.keys() - old_files.keys()),
        "deleted": list(old_files.keys() - new_files.keys()),
        "modified": [p for p in rest if new_files[p] != old_files[p]],
    }


class CmdDatasetAdd(CmdBase):
    @classmethod
    def display(cls, name: str, dataset: "Dataset", action: str = "Adding"):
        from dvc.ui import ui

        assert dataset.lock

        url = dataset.spec.url
        ver: str = ""
        if dataset.type == "dvcx":
            ver = f"v{dataset.lock.version}"
        if dataset.type == "dvc":
            if dataset.lock.path:
                url = f"{url}:/{dataset.lock.path.lstrip('/')}"
            if rev := dataset.lock.rev:
                ver = rev

        ver_part: Optional["Text"] = None
        if ver:
            ver_part = ui.rich_text.assemble(" @ ", (ver, "repr.number"))
        text = ui.rich_text.assemble("(", (url, "repr.url"), ver_part or "", ")")
        ui.write(action, ui.rich_text(name, "cyan"), text, styled=True)

    def run(self):
        from urllib.parse import urlsplit

        d = vars(self.args)
        # Infer the dataset type from the URL scheme.
        url_obj = urlsplit(self.args.url)
        if url_obj.scheme == "dvcx":
            d["type"] = "dvcx"
        elif url_obj.scheme.startswith("dvc"):
            d["type"] = "dvc"
            # Schemes like dvc+https:// embed the real transport after the "+".
            protos = tuple(url_obj.scheme.split("+"))
            if not protos or protos == ("dvc",) or protos == ("dvc", "ssh"):
                d["url"] = url_obj.netloc + url_obj.path
            else:
                d["url"] = url_obj._replace(scheme=protos[1]).geturl()
        else:
            d["type"] = "url"

        existing = self.repo.datasets.get(self.args.name)
        with self.repo.scm_context:
            if not self.args.force and existing:
                path = self.repo.fs.relpath(existing.manifest_path)
                raise DvcException(
                    f"{self.args.name} already exists in {path}, "
                    "use the --force to overwrite"
                )
            dataset = self.repo.datasets.add(**d)
            self.display(self.args.name, dataset)


class CmdDatasetUpdate(CmdBase):
    def display(self, name: str, dataset: "Dataset", new: "Dataset"):
        from dvc.commands.checkout import log_changes
        from dvc.ui import ui

        if not dataset.lock:
            return CmdDatasetAdd.display(name, new, "Updating")
        if dataset == new:
            ui.write("[yellow]Nothing to update[/]", styled=True)
            return

        assert new.lock

        v: Optional[tuple[str, str]] = None
        if dataset.type == "dvcx":
            assert new.type == "dvcx"
            v = (f"v{dataset.lock.version}", f"v{new.lock.version}")
        if dataset.type == "dvc":
            assert new.type == "dvc"
            v = (f"{dataset.lock.rev_lock[:9]}", f"{new.lock.rev_lock[:9]}")

        if v:
            part = ui.rich_text.assemble(
                (v[0], "repr.number"),
                " -> ",
                (v[1], "repr.number"),
            )
        else:
            part = ui.rich_text(dataset.spec.url, "repr.url")
        changes = ui.rich_text.assemble("(", part, ")")
        ui.write("Updating", ui.rich_text(name, "cyan"), changes, styled=True)
        if dataset.type == "url":
            assert new.type == "url"
            stats = diff_files(dataset.lock.files, new.lock.files)
            log_changes(stats)

    def run(self):
        from difflib import get_close_matches

        from dvc.repo.datasets import DatasetNotFoundError
        from dvc.ui import ui

        with self.repo.scm_context:
            try:
                dataset, new = self.repo.datasets.update(**vars(self.args))
            except DatasetNotFoundError:
                logger.exception("")
                # Suggest the closest existing dataset name on a miss.
                if matches := get_close_matches(self.args.name, self.repo.datasets):
                    ui.write(
                        "did you mean?",
                        ui.rich_text(matches[0], "cyan"),
                        stderr=True,
                        styled=True,
                    )
                return 1
            self.display(self.args.name, dataset, new)


def add_parser(subparsers, parent_parser):
    ds_parser = subparsers.add_parser(
        "dataset",
        aliases=["ds"],
        parents=[parent_parser],
        formatter_class=formatter.RawDescriptionHelpFormatter,
    )
    ds_subparsers = ds_parser.add_subparsers(
        dest="cmd",
        help="Use `dvc dataset CMD --help` to display command-specific help.",
        required=True,
    )

    dataset_add_help = "Add a dataset."
    ds_add_parser = ds_subparsers.add_parser(
        "add",
        parents=[parent_parser],
        description=append_doc_link(dataset_add_help, "dataset/add"),
        formatter_class=formatter.RawTextHelpFormatter,
        help=dataset_add_help,
    )
    ds_add_parser.add_argument(
        "--url",
        required=True,
        help="""\
Location of the data to download. Supported URLs:
    s3://bucket/key/path
    gs://bucket/path/to/file/or/dir
    azure://mycontainer/path
    remote://remote_name/path/to/file/or/dir (see `dvc remote`)
    dvcx://dataset_name
To import data from dvc/git repositories, \
add dvc:// schema to the repo url, e.g:
    dvc://[email protected]/iterative/example-get-started.git
    dvc+https://github.com/iterative/example-get-started.git""",
    )
    ds_add_parser.add_argument(
        "--name", help="Name of the dataset to add", required=True
    )
    ds_add_parser.add_argument(
        "--rev",
        help="Git revision, e.g. SHA, branch, tag "
        "(only applicable for dvc/git repository)",
    )
    ds_add_parser.add_argument(
        "--path", help="Path to a file or directory within the git repository"
    )
    ds_add_parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        help="Overwrite existing dataset",
    )
    ds_add_parser.set_defaults(func=CmdDatasetAdd)

    dataset_update_help = "Update a dataset."
    ds_update_parser = ds_subparsers.add_parser(
        "update",
        parents=[parent_parser],
        description=append_doc_link(dataset_update_help, "dataset/add"),
        formatter_class=formatter.RawDescriptionHelpFormatter,
        help=dataset_update_help,
    )
    ds_update_parser.add_argument("name", help="Name of the dataset to update")
    ds_update_parser.set_defaults(func=CmdDatasetUpdate)
```
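
For a quick illustration of what the `diff_files` helper at the top of this
file computes, here is a hypothetical run (assuming `diff_files` is in
scope; `FileInfo` here is a minimal stand-in for the real class in
`dvc/repo/datasets.py`):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FileInfo:  # minimal stand-in for dvc.repo.datasets.FileInfo
    relpath: str
    etag: str


old = [FileInfo("a.txt", "e1"), FileInfo("b.txt", "e2")]
new = [FileInfo("b.txt", "e3"), FileInfo("c.txt", "e4")]
print(diff_files(old, new))
# {'added': ['c.txt'], 'deleted': ['a.txt'], 'modified': ['b.txt']}
```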
8 changes: 4 additions & 4 deletions dvc/database.py

```diff
@@ -5,9 +5,9 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

-from sqlalchemy import create_engine
-from sqlalchemy.engine import make_url as _make_url
-from sqlalchemy.exc import NoSuchModuleError
+from sqlalchemy import create_engine  # type: ignore[import]
+from sqlalchemy.engine import make_url as _make_url  # type: ignore[import]
+from sqlalchemy.exc import NoSuchModuleError  # type: ignore[import]

 from dvc import env
 from dvc.exceptions import DvcException
@@ -17,7 +17,7 @@

 if TYPE_CHECKING:
     from sqlalchemy.engine import URL, Connectable, Engine
-    from sqlalchemy.sql.expression import Selectable
+    from sqlalchemy.sql.expression import Selectable  # type: ignore[import]


 logger = logger.getChild(__name__)
```
21 changes: 18 additions & 3 deletions dvc/dependency/dataset.py

```diff
@@ -3,6 +3,7 @@

 from funcy import compact, merge

+from dvc.exceptions import DvcException
 from dvc_data.hashfile.hash_info import HashInfo

 from .db import AbstractDependency
@@ -47,17 +48,31 @@ def fill_values(self, values=None):
         )

     def workspace_status(self):
-        registered = self.repo.index.datasets.get(self.name, {})
+        ds = self.repo.datasets[self.name]
+        if not ds.lock:
+            return {str(self): "not in sync"}
+
         info: dict[str, Any] = self.hash_info.value if self.hash_info else {}  # type: ignore[assignment]
-        if info != registered:
+        lock = self.repo.datasets._lock_from_info(info)
+        if not lock:
+            return {str(self): "new"}
+        if lock != ds.lock:
             return {str(self): "modified"}
         return {}

     def status(self):
         return self.workspace_status()

     def get_hash(self):
-        return HashInfo(self.PARAM_DATASET, self.repo.index.datasets.get(self.name, {}))
+        ds = self.repo.datasets[self.name]
+        if not ds.lock:
+            if ds._invalidated:
+                raise DvcException(
+                    "Dataset information is not in sync. "
+                    f"Run 'dvc ds update {self.name}' to sync."
+                )
+            raise DvcException("Dataset information missing from dvc.lock file")
+        return HashInfo(self.PARAM_DATASET, ds.lock.to_dict())  # type: ignore[arg-type]

     def save(self):
         self.hash_info = self.get_hash()
```
3 changes: 2 additions & 1 deletion dvc/dependency/db.py

```diff
@@ -11,6 +11,7 @@

 if TYPE_CHECKING:
     from dvc.output import Output
+    from dvc.repo import Repo
     from dvc.stage import Stage

 logger = logger.getChild(__name__)
@@ -33,7 +34,7 @@ class AbstractDependency(Dependency):
     """Dependency without workspace/fs/fs_path"""

     def __init__(self, stage: "Stage", info: dict[str, Any], *args, **kwargs):
-        self.repo = stage.repo
+        self.repo: "Repo" = stage.repo
         self.stage = stage
         self.fs = None
         self.fs_path = None
```
37 changes: 37 additions & 0 deletions dvc/dvcfile.py

```diff
@@ -172,6 +172,7 @@ class SingleStageFile(FileMixin):
     from dvc.stage.loader import SingleStageLoader as LOADER  # noqa: N814

     datasets: ClassVar[list[dict[str, Any]]] = []
+    datasets_lock: ClassVar[list[dict[str, Any]]] = []
     metrics: ClassVar[list[str]] = []
     plots: ClassVar[Any] = {}
     params: ClassVar[list[str]] = []
@@ -240,6 +241,20 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):
         if update_lock:
             self._dump_lockfile(stage, **kwargs)

+    def dump_dataset(self, dataset):
+        with modify_yaml(self.path, fs=self.repo.fs) as data:
+            datasets: list[dict] = data.setdefault("datasets", [])
+            loc = next(
+                (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]),
+                None,
+            )
+            if loc is not None:
+                apply_diff(dataset, datasets[loc])
+                datasets[loc] = dataset
+            else:
+                datasets.append(dataset)
+        self.repo.scm_context.track_file(self.relpath)
+
     def _dump_lockfile(self, stage, **kwargs):
         self._lockfile.dump(stage, **kwargs)
@@ -308,6 +323,10 @@ def params(self) -> list[str]:
     def datasets(self) -> list[dict[str, Any]]:
         return self.contents.get("datasets", [])

+    @property
+    def datasets_lock(self) -> list[dict[str, Any]]:
+        return self.lockfile_contents.get("datasets", [])
+
     @property
     def artifacts(self) -> dict[str, Optional[dict[str, Any]]]:
         return self.contents.get("artifacts", {})
@@ -357,6 +376,24 @@ def _load(self, **kwargs: Any):
             self._check_gitignored()
         return {}, ""

+    def dump_dataset(self, dataset: dict):
+        with modify_yaml(self.path, fs=self.repo.fs) as data:
+            data.update({"schema": "2.0"})
+            if not data:
+                logger.info("Generating lock file '%s'", self.relpath)
+
+            datasets: list[dict] = data.setdefault("datasets", [])
+            loc = next(
+                (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]),
+                None,
+            )
+            if loc is not None:
+                datasets[loc] = dataset
+            else:
+                datasets.append(dataset)
+            data.setdefault("stages", {})
+        self.repo.scm_context.track_file(self.relpath)
+
     def dump(self, stage, **kwargs):
         stage_data = serialize.to_lockfile(stage, **kwargs)
```
2 changes: 1 addition & 1 deletion dvc/output.py

```diff
@@ -814,7 +814,7 @@ def dumpd(self, **kwargs):  # noqa: C901, PLR0912

         ret: dict[str, Any] = {}
         with_files = (
-            (not self.IS_DEPENDENCY or self.stage.is_import)
+            (not self.IS_DEPENDENCY or kwargs.get("datasets") or self.stage.is_import)
             and self.hash_info.isdir
             and (kwargs.get("with_files") or self.files is not None)
         )
```
2 changes: 2 additions & 0 deletions dvc/repo/__init__.py

```diff
@@ -150,6 +150,7 @@ def __init__(  # noqa: PLR0915, PLR0913
         from dvc.fs import GitFileSystem, LocalFileSystem, localfs
         from dvc.lock import LockNoop, make_lock
         from dvc.repo.artifacts import Artifacts
+        from dvc.repo.datasets import Datasets
         from dvc.repo.metrics import Metrics
         from dvc.repo.params import Params
         from dvc.repo.plots import Plots
@@ -220,6 +221,7 @@ def __init__(  # noqa: PLR0915, PLR0913
         self.plots: Plots = Plots(self)
         self.params: Params = Params(self)
         self.artifacts: Artifacts = Artifacts(self)
+        self.datasets: Datasets = Datasets(self)

         self.stage_collection_error_handler: Optional[
             Callable[[str, Exception], None]
```