Skip to content

Commit

Permalink
ds add: use --dvc/--dvcx/--url flags to accept url (#10326)
Browse files Browse the repository at this point in the history
* ds add: use --dvc/--dvcx/--url flags to accept url

.. instead of relying on `scheme://`

I think this is simpler from the user's perspective, easier to
teach, and easier to maintain, as we won't have to do URL parsing.

I am taking inspiration from `cargo install`, which has `--git`/`--rev`, etc.
to install from alternate sources.

```console
Options:
      --version <VERSION>     Specify a version to install
      --index <INDEX>         Registry index to install from
      --registry <REGISTRY>   Registry to use
      --git <URL>             Git URL to install the specified crate from
      --branch <BRANCH>       Branch to use when installing from git
      --tag <TAG>             Tag to use when installing from git
      --rev <SHA>             Specific commit to use when installing from git
      --path <PATH>           Filesystem path to local crate to install
      --root <DIR>            Directory to install packages into
```

The `name` is also changed from an option to an argument.

### Examples

```console
dvc ds add example-get-started --dvc [email protected]:iterative/example-get-started.git
dvc ds add dogs --dvcx dogs
dvc ds add dogs --url s3://bucket/key/path
```

* add --version alias
  • Loading branch information
skshetry authored Feb 27, 2024
1 parent 4cb658f commit 9bb4501
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 100 deletions.
48 changes: 32 additions & 16 deletions dvc/commands/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def display(cls, name: str, dataset: "Dataset", action: str = "Adding"):
ui.write(action, ui.rich_text(name, "cyan"), text, styled=True)

def run(self):
if not self.args.dvc and self.args.rev:
raise DvcException("--rev can't be used without --dvc")
if not self.args.dvc and self.args.path:
raise DvcException("--path can't be used without --dvc")

d = vars(self.args)
for key in ["dvc", "dvcx", "url"]:
if url := d.pop(key, None):
d.update({"type": key, "url": url})
break

existing = self.repo.datasets.get(self.args.name)
with self.repo.scm_context:
if not self.args.force and existing:
Expand All @@ -57,7 +68,7 @@ def run(self):
f"{self.args.name} already exists in {path}, "
"use the --force to overwrite"
)
dataset = self.repo.datasets.add(**vars(self.args))
dataset = self.repo.datasets.add(**d)
self.display(self.args.name, dataset)
return 0

Expand Down Expand Up @@ -154,33 +165,37 @@ def add_parser(subparsers, parent_parser):
formatter_class=formatter.RawTextHelpFormatter,
help=dataset_add_help,
)
ds_add_parser.add_argument(

url_exclusive_group = ds_add_parser.add_mutually_exclusive_group(required=True)
url_exclusive_group.add_argument(
"--dvcx", metavar="name", help="Name of the dvcx dataset to track"
)
url_exclusive_group.add_argument(
"--dvc",
help="Path or URL to a Git/DVC repository to track",
metavar="url",
)
url_exclusive_group.add_argument(
"--url",
required=True,
help="""\
Location of the data to download. Supported URLs:
URL of a cloud-versioned remote to track. Supported URLs:
s3://bucket/key/path
gs://bucket/path/to/file/or/dir
azure://mycontainer/path
remote://remote_name/path/to/file/or/dir (see `dvc remote`)
dvcx://dataset_name
To import data from dvc/git repositories, \
add dvc:// schema to the repo url, e.g:
dvc://[email protected]:iterative/example-get-started.git
dvc+https://github.com/iterative/example-get-started.git""",
)
ds_add_parser.add_argument(
"--name", help="Name of the dataset to add", required=True
""",
)
ds_add_parser.add_argument("name", help="Name of the dataset to add")
ds_add_parser.add_argument(
"--rev",
help="Git revision, e.g. SHA, branch, tag "
"(only applicable for dvc/git repository)",
help="Git revision, e.g. SHA, branch, tag (only applicable with --dvc)",
metavar="<commit>",
)
ds_add_parser.add_argument(
"--path", help="Path to a file or directory within the git repository"
"--path",
help="Path to a file or a directory within a git repository "
"(only applicable with --dvc)",
)
ds_add_parser.add_argument(
"-f",
Expand All @@ -202,6 +217,7 @@ def add_parser(subparsers, parent_parser):
ds_update_parser.add_argument("name", help="Name of the dataset to update")
ds_update_parser.add_argument(
"--rev",
"--version",
nargs="?",
help="DVCX dataset version or Git revision (e.g. SHA, branch, tag)",
metavar="<version>",
Expand Down
35 changes: 11 additions & 24 deletions dvc/repo/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,6 @@
logger = logger.getChild(__name__)


def parse_url_and_type(url: str):
    """Infer the dataset type for *url* and normalize the url.

    Returns a ``{"type": ..., "url": ...}`` dict where ``type`` is one of
    ``"dvc"``, ``"dvcx"`` or ``"url"``.  For dvc-type urls the scheme prefix
    is stripped (``dvc://host/path`` -> ``host/path``) or rewritten
    (``dvc+https://...`` -> ``https://...``).
    """
    from urllib.parse import urlsplit

    # An existing local path is always treated as a DVC/Git repository.
    if os.path.exists(url):
        return {"type": "dvc", "url": url}

    url_obj = urlsplit(url)
    if url_obj.scheme == "dvcx":
        return {"type": "dvcx", "url": url}
    if url_obj.scheme and not url_obj.scheme.startswith("dvc"):
        # Cloud-versioned remotes such as s3://, gs://, azure://, remote://.
        return {"type": "url", "url": url}

    # Remaining schemes: "" (scheme-less), "dvc", or "dvc+<proto>".
    protos = tuple(url_obj.scheme.split("+"))
    if len(protos) < 2:
        # "dvc://host/path" or a scheme-less, non-existing path: drop the
        # (possibly empty) "dvc" scheme and keep the rest verbatim.
        # NOTE: the previous check (`not protos or protos == ("dvc",)`) let a
        # scheme-less url fall through to `protos[1]`, raising IndexError,
        # since str.split always returns at least one element.
        url = url_obj.netloc + url_obj.path
    else:
        # "dvc+https://..." -> "https://..."
        url = url_obj._replace(scheme=protos[1]).geturl()
    return {"type": "dvc", "url": url}


def _get_dataset_record(name: str) -> "DatasetRecord":
from dvc.exceptions import DvcException

Expand Down Expand Up @@ -189,10 +169,15 @@ class DVCXDataset:

type: ClassVar[Literal["dvcx"]] = "dvcx"

@property
def pinned(self) -> bool:
    """Whether this dataset is pinned to an explicit version."""
    _, version = self.name_version
    return version is not None

@property
def name_version(self) -> tuple[str, Optional[int]]:
url = urlparse(self.spec.url)
parts = url.netloc.split("@v")
path = url.netloc + url.path
parts = path.split("@v")
assert parts

name = parts[0]
Expand Down Expand Up @@ -384,13 +369,15 @@ def _build_dataset(

def add(
self,
url: str,
name: str,
url: str,
type: str, # noqa: A002
manifest_path: StrPath = "dvc.yaml",
**kwargs: Any,
) -> Dataset:
spec = kwargs | parse_url_and_type(url) | {"name": name}
dataset = self._build_dataset(os.path.abspath(manifest_path), spec)
assert type in {"dvc", "dvcx", "url"}
kwargs.update({"name": name, "url": url, "type": type})
dataset = self._build_dataset(os.path.abspath(manifest_path), kwargs)
dataset = dataset.update(self.repo)

self.dump(dataset)
Expand Down
26 changes: 13 additions & 13 deletions tests/func/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_dvc(tmp_dir, scm, dvc: "Repo"):
datasets = dvc.datasets

tmp_dir.scm_gen("file", "file", commit="add file")
dataset = datasets.add(tmp_dir.fs_path, name="mydataset", path="file")
dataset = datasets.add("mydataset", tmp_dir.fs_path, "dvc", path="file")
expected = DVCDataset(
manifest_path=(tmp_dir / "dvc.yaml").fs_path,
spec=DVCDatasetSpec(
Expand Down Expand Up @@ -84,13 +84,13 @@ def test_dvcx(tmp_dir, dvc, mocker):
version_info.append(version_info[1])
mocker.patch("dvc.repo.datasets._get_dataset_info", side_effect=version_info)

dataset = datasets.add(url="dvcx://dataset", name="mydataset")
dataset = datasets.add("mydataset", "dataset", "dvcx")
expected = DVCXDataset(
manifest_path=(tmp_dir / "dvc.yaml").fs_path,
spec=DatasetSpec(name="mydataset", url="dvcx://dataset", type="dvcx"),
spec=DatasetSpec(name="mydataset", url="dataset", type="dvcx"),
lock=DVCXDatasetLock(
name="mydataset",
url="dvcx://dataset",
url="dataset",
type="dvcx",
version=1,
created_at=version_info[0].created_at,
Expand Down Expand Up @@ -125,7 +125,7 @@ def mocked_save(d):

mocker.patch.object(Dependency, "save", mocked_save)

dataset = datasets.add(url="s3://dataset", name="mydataset")
dataset = datasets.add("mydataset", "s3://dataset", "url")
expected = URLDataset(
manifest_path=(tmp_dir / "dvc.yaml").fs_path,
spec=DatasetSpec(name="mydataset", url="s3://dataset", type="url"),
Expand Down Expand Up @@ -191,14 +191,14 @@ def test_dvc_dump(tmp_dir, dvc):

def test_dvcx_dump(tmp_dir, dvc):
manifest_path = os.path.join(tmp_dir, "dvc.yaml")
spec = DatasetSpec(name="mydataset", url="dvcx://dataset", type="dvcx")
spec = DatasetSpec(name="mydataset", url="dataset", type="dvcx")
dt = datetime.now(tz=timezone.utc)
lock = DVCXDatasetLock(version=1, created_at=dt, **spec.to_dict())
dataset = DVCXDataset(manifest_path=manifest_path, spec=spec, lock=lock)

dvc.datasets.dump(dataset)

spec_d = {"name": "mydataset", "type": "dvcx", "url": "dvcx://dataset"}
spec_d = {"name": "mydataset", "type": "dvcx", "url": "dataset"}
assert (tmp_dir / "dvc.yaml").parse() == {"datasets": [spec_d]}
assert (tmp_dir / "dvc.lock").parse() == {
"schema": "2.0",
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_invalidation(tmp_dir, dvc):
spec = DatasetSpec(name="mydataset", url="url1", type="url")
lock = DVCXDatasetLock(
name="mydataset",
url="dvcx://dataset",
url="dataset",
type="dvcx",
version=1,
created_at=datetime.now(tz=timezone.utc),
Expand All @@ -262,7 +262,7 @@ def test_invalidation(tmp_dir, dvc):


def test_dvc_dataset_pipeline(tmp_dir, dvc, scm):
dvc.datasets.add(name="mydataset", url=tmp_dir.fs_path)
dvc.datasets.add("mydataset", tmp_dir.fs_path, "dvc")

stage = dvc.stage.add(cmd="echo", name="train", deps=["ds://mydataset"])
assert (tmp_dir / "dvc.yaml").parse() == {
Expand Down Expand Up @@ -293,11 +293,11 @@ def test_dvcx_dataset_pipeline(mocker, tmp_dir, dvc):
version_info = [MockedDVCXVersionInfo(1), MockedDVCXVersionInfo(2)]
mocker.patch("dvc.repo.datasets._get_dataset_info", side_effect=version_info)

dvc.datasets.add(name="mydataset", url="dvcx://mydataset")
dvc.datasets.add("mydataset", "dataset", "dvcx")

stage = dvc.stage.add(cmd="echo", name="train", deps=["ds://mydataset"])
assert (tmp_dir / "dvc.yaml").parse() == {
"datasets": [{"name": "mydataset", "url": "dvcx://mydataset", "type": "dvcx"}],
"datasets": [{"name": "mydataset", "url": "dataset", "type": "dvcx"}],
"stages": {"train": {"cmd": "echo", "deps": ["ds://mydataset"]}},
}

Expand Down Expand Up @@ -330,7 +330,7 @@ def mocked_save(d):

mocker.patch.object(Dependency, "save", mocked_save)

dvc.datasets.add(name="mydataset", url="s3://mydataset")
dvc.datasets.add("mydataset", "s3://mydataset", "url")

stage = dvc.stage.add(cmd="echo", name="train", deps=["ds://mydataset"])
assert (tmp_dir / "dvc.yaml").parse() == {
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_pipeline_when_not_in_sync(tmp_dir, dvc):
spec = DatasetSpec(name="mydataset", url="url1", type="url")
lock = DVCXDatasetLock(
name="mydataset",
url="dvcx://dataset",
url="dataset",
type="dvcx",
version=1,
created_at=datetime.now(tz=timezone.utc),
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/command/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_add(dvc, capsys, mocker, spec, lock, expected_output):

m = mocker.patch("dvc.repo.datasets.Datasets.add", return_value=dataset)

assert main(["dataset", "add", "--url", spec["url"], "--name", spec["name"]]) == 0
assert main(["dataset", "add", spec["name"], f"--{spec['type']}", spec["url"]]) == 0
out, err = capsys.readouterr()
assert out == expected_output
assert not err
Expand All @@ -45,7 +45,7 @@ def test_add_already_exists(dvc, caplog, mocker):
dataset = dvc.datasets._build_dataset("dvc.yaml", spec, None)
mocker.patch("dvc.repo.datasets.Datasets.get", return_value=dataset)

assert main(["dataset", "add", "--url", "url", "--name", "ds"]) == 255
assert main(["dataset", "add", "ds", "--dvcx", "dataset"]) == 255
assert "ds already exists in dvc.yaml, use the --force to overwrite" in caplog.text


Expand Down
45 changes: 0 additions & 45 deletions tests/unit/test_dataset.py

This file was deleted.

0 comments on commit 9bb4501

Please sign in to comment.