diff --git a/dvc/cli/parser.py b/dvc/cli/parser.py index aaeca86bc7..95ca31ce63 100644 --- a/dvc/cli/parser.py +++ b/dvc/cli/parser.py @@ -17,6 +17,7 @@ dag, data, data_sync, + dataset, destroy, diff, du, @@ -67,6 +68,7 @@ dag, data, data_sync, + dataset, destroy, diff, du, diff --git a/dvc/commands/dataset.py b/dvc/commands/dataset.py new file mode 100644 index 0000000000..9cbc95c339 --- /dev/null +++ b/dvc/commands/dataset.py @@ -0,0 +1,205 @@ +from typing import TYPE_CHECKING, Optional + +from dvc.cli import formatter +from dvc.cli.command import CmdBase +from dvc.cli.utils import append_doc_link +from dvc.exceptions import DvcException +from dvc.log import logger + +if TYPE_CHECKING: + from rich.text import Text + + from dvc.repo.datasets import Dataset, FileInfo + +logger = logger.getChild(__name__) + + +def diff_files(old: list["FileInfo"], new: list["FileInfo"]) -> dict[str, list[str]]: + old_files = {d.relpath: d for d in old} + new_files = {d.relpath: d for d in new} + rest = old_files.keys() & new_files.keys() + return { + "added": list(new_files.keys() - old_files.keys()), + "deleted": list(old_files.keys() - new_files.keys()), + "modified": [p for p in rest if new_files[p] != old_files[p]], + } + + +class CmdDatasetAdd(CmdBase): + @classmethod + def display(cls, name: str, dataset: "Dataset", action: str = "Adding"): + from dvc.ui import ui + + assert dataset.lock + + url = dataset.spec.url + ver: str = "" + if dataset.type == "dvcx": + ver = f"v{dataset.lock.version}" + if dataset.type == "dvc": + if dataset.lock.path: + url = f"{url}:/{dataset.lock.path.lstrip('/')}" + if rev := dataset.lock.rev: + ver = rev + + ver_part: Optional["Text"] = None + if ver: + ver_part = ui.rich_text.assemble(" @ ", (ver, "repr.number")) + text = ui.rich_text.assemble("(", (url, "repr.url"), ver_part or "", ")") + ui.write(action, ui.rich_text(name, "cyan"), text, styled=True) + + def run(self): + from urllib.parse import urlsplit + + d = vars(self.args) + url_obj = urlsplit(self.args.url) + if url_obj.scheme == "dvcx": + d["type"] = "dvcx" + elif url_obj.scheme.startswith("dvc"): + d["type"] = "dvc" + protos = tuple(url_obj.scheme.split("+")) + if not protos or protos == ("dvc",) or protos == ("dvc", "ssh"): + d["url"] = url_obj.netloc + url_obj.path + else: + d["url"] = url_obj._replace(scheme=protos[1]).geturl() + else: + d["type"] = "url" + + existing = self.repo.datasets.get(self.args.name) + with self.repo.scm_context: + if not self.args.force and existing: + path = self.repo.fs.relpath(existing.manifest_path) + raise DvcException( + f"{self.args.name} already exists in {path}, " + "use the --force to overwrite" + ) + dataset = self.repo.datasets.add(**d) + self.display(self.args.name, dataset) + + +class CmdDatasetUpdate(CmdBase): + def display(self, name: str, dataset: "Dataset", new: "Dataset"): + from dvc.commands.checkout import log_changes + from dvc.ui import ui + + if not dataset.lock: + return CmdDatasetAdd.display(name, new, "Updating") + if dataset == new: + ui.write("[yellow]Nothing to update[/]", styled=True) + return + + assert new.lock + + v: Optional[tuple[str, str]] = None + if dataset.type == "dvcx": + assert new.type == "dvcx" + v = (f"v{dataset.lock.version}", f"v{new.lock.version}") + if dataset.type == "dvc": + assert new.type == "dvc" + v = (f"{dataset.lock.rev_lock[:9]}", f"{new.lock.rev_lock[:9]}") + + if v: + part = ui.rich_text.assemble( + (v[0], "repr.number"), + " -> ", + (v[1], "repr.number"), + ) + else: + part = ui.rich_text(dataset.spec.url, "repr.url") + 
changes = ui.rich_text.assemble("(", part, ")") + ui.write("Updating", ui.rich_text(name, "cyan"), changes, styled=True) + if dataset.type == "url": + assert new.type == "url" + stats = diff_files(dataset.lock.files, new.lock.files) + log_changes(stats) + + def run(self): + from difflib import get_close_matches + + from dvc.repo.datasets import DatasetNotFoundError + from dvc.ui import ui + + with self.repo.scm_context: + try: + dataset, new = self.repo.datasets.update(**vars(self.args)) + except DatasetNotFoundError: + logger.exception("") + if matches := get_close_matches(self.args.name, self.repo.datasets): + ui.write( + "did you mean?", + ui.rich_text(matches[0], "cyan"), + stderr=True, + styled=True, + ) + return 1 + self.display(self.args.name, dataset, new) + + +def add_parser(subparsers, parent_parser): + ds_parser = subparsers.add_parser( + "dataset", + aliases=["ds"], + parents=[parent_parser], + formatter_class=formatter.RawDescriptionHelpFormatter, + ) + ds_subparsers = ds_parser.add_subparsers( + dest="cmd", + help="Use `dvc dataset CMD --help` to display command-specific help.", + required=True, + ) + + dataset_add_help = "Add a dataset." + ds_add_parser = ds_subparsers.add_parser( + "add", + parents=[parent_parser], + description=append_doc_link(dataset_add_help, "dataset/add"), + formatter_class=formatter.RawTextHelpFormatter, + help=dataset_add_help, + ) + ds_add_parser.add_argument( + "--url", + required=True, + help="""\ +Location of the data to download. Supported URLs: + +s3://bucket/key/path +gs://bucket/path/to/file/or/dir +azure://mycontainer/path +remote://remote_name/path/to/file/or/dir (see `dvc remote`) +dvcx://dataset_name + +To import data from dvc/git repositories, \ +add dvc:// schema to the repo url, e.g: +dvc://git@github.com/iterative/example-get-started.git +dvc+https://github.com/iterative/example-get-started.git""", + ) + ds_add_parser.add_argument( + "--name", help="Name of the dataset to add", required=True + ) + ds_add_parser.add_argument( + "--rev", + help="Git revision, e.g. SHA, branch, tag " + "(only applicable for dvc/git repository)", + ) + ds_add_parser.add_argument( + "--path", help="Path to a file or directory within the git repository" + ) + ds_add_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Overwrite existing dataset", + ) + ds_add_parser.set_defaults(func=CmdDatasetAdd) + + dataset_update_help = "Update a dataset." 
+ ds_update_parser = ds_subparsers.add_parser( + "update", + parents=[parent_parser], + description=append_doc_link(dataset_update_help, "dataset/add"), + formatter_class=formatter.RawDescriptionHelpFormatter, + help=dataset_update_help, + ) + ds_update_parser.add_argument("name", help="Name of the dataset to update") + ds_update_parser.set_defaults(func=CmdDatasetUpdate) diff --git a/dvc/database.py b/dvc/database.py index 4289779b3e..01e038e4e0 100644 --- a/dvc/database.py +++ b/dvc/database.py @@ -5,9 +5,9 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from sqlalchemy import create_engine -from sqlalchemy.engine import make_url as _make_url -from sqlalchemy.exc import NoSuchModuleError +from sqlalchemy import create_engine # type: ignore[import] +from sqlalchemy.engine import make_url as _make_url # type: ignore[import] +from sqlalchemy.exc import NoSuchModuleError # type: ignore[import] from dvc import env from dvc.exceptions import DvcException @@ -17,7 +17,7 @@ if TYPE_CHECKING: from sqlalchemy.engine import URL, Connectable, Engine - from sqlalchemy.sql.expression import Selectable + from sqlalchemy.sql.expression import Selectable # type: ignore[import] logger = logger.getChild(__name__) diff --git a/dvc/dependency/dataset.py b/dvc/dependency/dataset.py index 37dcea30ee..e9f3128b6e 100644 --- a/dvc/dependency/dataset.py +++ b/dvc/dependency/dataset.py @@ -3,6 +3,7 @@ from funcy import compact, merge +from dvc.exceptions import DvcException from dvc_data.hashfile.hash_info import HashInfo from .db import AbstractDependency @@ -47,9 +48,15 @@ def fill_values(self, values=None): ) def workspace_status(self): - registered = self.repo.index.datasets.get(self.name, {}) + ds = self.repo.datasets[self.name] + if not ds.lock: + return {str(self): "not in sync"} + info: dict[str, Any] = self.hash_info.value if self.hash_info else {} # type: ignore[assignment] - if info != registered: + lock = self.repo.datasets._lock_from_info(info) + if not lock: + return {str(self): "new"} + if lock != ds.lock: return {str(self): "modified"} return {} @@ -57,7 +64,15 @@ def status(self): return self.workspace_status() def get_hash(self): - return HashInfo(self.PARAM_DATASET, self.repo.index.datasets.get(self.name, {})) + ds = self.repo.datasets[self.name] + if not ds.lock: + if ds._invalidated: + raise DvcException( + "Dataset information is not in sync. " + f"Run 'dvc ds update {self.name}' to sync." 
+ ) + raise DvcException("Dataset information missing from dvc.lock file") + return HashInfo(self.PARAM_DATASET, ds.lock.to_dict()) # type: ignore[arg-type] def save(self): self.hash_info = self.get_hash() diff --git a/dvc/dependency/db.py b/dvc/dependency/db.py index 2d663150c4..6cdbd169ec 100644 --- a/dvc/dependency/db.py +++ b/dvc/dependency/db.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from dvc.output import Output + from dvc.repo import Repo from dvc.stage import Stage logger = logger.getChild(__name__) @@ -33,7 +34,7 @@ class AbstractDependency(Dependency): """Dependency without workspace/fs/fs_path""" def __init__(self, stage: "Stage", info: dict[str, Any], *args, **kwargs): - self.repo = stage.repo + self.repo: "Repo" = stage.repo self.stage = stage self.fs = None self.fs_path = None diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 9837549ce4..134378f3be 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -172,6 +172,7 @@ class SingleStageFile(FileMixin): from dvc.stage.loader import SingleStageLoader as LOADER # noqa: N814 datasets: ClassVar[list[dict[str, Any]]] = [] + datasets_lock: ClassVar[list[dict[str, Any]]] = [] metrics: ClassVar[list[str]] = [] plots: ClassVar[Any] = {} params: ClassVar[list[str]] = [] @@ -240,6 +241,20 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs): if update_lock: self._dump_lockfile(stage, **kwargs) + def dump_dataset(self, dataset): + with modify_yaml(self.path, fs=self.repo.fs) as data: + datasets: list[dict] = data.setdefault("datasets", []) + loc = next( + (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]), + None, + ) + if loc is not None: + apply_diff(dataset, datasets[loc]) + datasets[loc] = dataset + else: + datasets.append(dataset) + self.repo.scm_context.track_file(self.relpath) + def _dump_lockfile(self, stage, **kwargs): self._lockfile.dump(stage, **kwargs) @@ -308,6 +323,10 @@ def params(self) -> list[str]: def datasets(self) -> list[dict[str, Any]]: return self.contents.get("datasets", []) + @property + def datasets_lock(self) -> list[dict[str, Any]]: + return self.lockfile_contents.get("datasets", []) + @property def artifacts(self) -> dict[str, Optional[dict[str, Any]]]: return self.contents.get("artifacts", {}) @@ -357,6 +376,24 @@ def _load(self, **kwargs: Any): self._check_gitignored() return {}, "" + def dump_dataset(self, dataset: dict): + with modify_yaml(self.path, fs=self.repo.fs) as data: + data.update({"schema": "2.0"}) + if not data: + logger.info("Generating lock file '%s'", self.relpath) + + datasets: list[dict] = data.setdefault("datasets", []) + loc = next( + (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]), + None, + ) + if loc is not None: + datasets[loc] = dataset + else: + datasets.append(dataset) + data.setdefault("stages", {}) + self.repo.scm_context.track_file(self.relpath) + def dump(self, stage, **kwargs): stage_data = serialize.to_lockfile(stage, **kwargs) diff --git a/dvc/output.py b/dvc/output.py index e94dd00e83..92a982cc25 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -814,7 +814,7 @@ def dumpd(self, **kwargs): # noqa: C901, PLR0912 ret: dict[str, Any] = {} with_files = ( - (not self.IS_DEPENDENCY or self.stage.is_import) + (not self.IS_DEPENDENCY or kwargs.get("datasets") or self.stage.is_import) and self.hash_info.isdir and (kwargs.get("with_files") or self.files is not None) ) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index c10965f432..3e14069d43 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -150,6 
+150,7 @@ def __init__( # noqa: PLR0915, PLR0913 from dvc.fs import GitFileSystem, LocalFileSystem, localfs from dvc.lock import LockNoop, make_lock from dvc.repo.artifacts import Artifacts + from dvc.repo.datasets import Datasets from dvc.repo.metrics import Metrics from dvc.repo.params import Params from dvc.repo.plots import Plots @@ -220,6 +221,7 @@ def __init__( # noqa: PLR0915, PLR0913 self.plots: Plots = Plots(self) self.params: Params = Params(self) self.artifacts: Artifacts = Artifacts(self) + self.datasets: Datasets = Datasets(self) self.stage_collection_error_handler: Optional[ Callable[[str, Exception], None] diff --git a/dvc/repo/datasets.py b/dvc/repo/datasets.py new file mode 100644 index 0000000000..4188d9f5af --- /dev/null +++ b/dvc/repo/datasets.py @@ -0,0 +1,385 @@ +from collections.abc import Iterator, Mapping +from datetime import datetime +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union, cast +from urllib.parse import urlparse + +from attrs import Attribute, AttrsInstance, asdict, evolve, field, fields, frozen +from attrs.converters import default_if_none + +from dvc.dvcfile import Lockfile, ProjectFile +from dvc.exceptions import DvcException +from dvc.log import logger +from dvc_data.hashfile.meta import Meta + +if TYPE_CHECKING: + from dql.dataset import DatasetRecord # type: ignore[import] + from typing_extensions import Self + + from dvc.repo import Repo + + +logger = logger.getChild(__name__) + + +def default_str(v) -> str: + return default_if_none("")(v) + + +def to_datetime(d: Union[str, datetime]) -> datetime: + return datetime.fromisoformat(d) if isinstance(d, str) else d + + +def ensure(cls): + def inner(v): + return cls.from_dict(v) if isinstance(v, dict) else v + + return inner + + +class SerDe: + def to_dict(self: AttrsInstance) -> dict[str, Any]: + def filter_defaults(attr: Attribute, v: Any): + if attr.metadata.get("exclude_falsy", False) and not v: + return False + return attr.default != v + + def value_serializer(_inst, _field, v): + return v.isoformat() if isinstance(v, datetime) else v + + return asdict(self, filter=filter_defaults, value_serializer=value_serializer) + + @classmethod + def from_dict(cls: type["Self"], d: dict[str, Any]) -> "Self": + _fields = fields(cast("type[AttrsInstance]", cls)) + kwargs = {f.name: d[f.name] for f in _fields if f.name in d} + return cls(**kwargs) + + +@frozen(kw_only=True) +class DatasetSpec(SerDe): + name: str + url: str + type: Literal["dvc", "dvcx", "url"] + + +@frozen(kw_only=True) +class DVCDatasetSpec(DatasetSpec): + type: Literal["dvc"] + path: str = field(default="", converter=default_str) + rev: Optional[str] = None + + +@frozen(kw_only=True, order=True) +class FileInfo(SerDe): + relpath: str + meta: Meta = field(order=False, converter=ensure(Meta)) # type: ignore[misc] + + +@frozen(kw_only=True) +class DVCDatasetLock(DVCDatasetSpec): + rev_lock: str + + +@frozen(kw_only=True) +class DVCXDatasetLock(DatasetSpec): + version: int + created_at: datetime = field(converter=to_datetime) + + +@frozen(kw_only=True) +class URLDatasetLock(DatasetSpec): + meta: Meta = field(converter=ensure(Meta)) # type: ignore[misc] + files: list[FileInfo] = field( + factory=list, + converter=lambda f: sorted(map(ensure(FileInfo), f)), + metadata={"exclude_falsy": True}, + ) + + +def to_spec(lock: "Lock") -> "Spec": + cls = DVCDatasetSpec if lock.type == "dvc" else DatasetSpec + return cls(**{f.name: getattr(lock, f.name) for f in 
fields(cls)}) + + +@frozen(kw_only=True) +class DVCDataset: + manifest_path: str + spec: DVCDatasetSpec + lock: Optional[DVCDatasetLock] = None + _invalidated: bool = field(default=False, eq=False, repr=False) + + type: ClassVar[Literal["dvc"]] = "dvc" + + def update(self, repo, rev: Optional[str] = None, **kwargs) -> "Self": + from dvc.dependency import RepoDependency + + spec = self.spec + if rev: + spec = evolve(self.spec, rev=rev) + + def_repo = { + RepoDependency.PARAM_REV: spec.rev, + RepoDependency.PARAM_URL: spec.url, + } + dep = RepoDependency(def_repo, None, spec.path, repo=repo) # type: ignore[arg-type] + dep.save() + d = dep.dumpd() + + repo_info = d[RepoDependency.PARAM_REPO] + assert isinstance(repo_info, dict) + rev_lock = repo_info[RepoDependency.PARAM_REV_LOCK] + lock = DVCDatasetLock(**spec.to_dict(), rev_lock=rev_lock) + return evolve(self, spec=spec, lock=lock) + + +@frozen(kw_only=True) +class DVCXDataset: + manifest_path: str + spec: "DatasetSpec" + lock: "Optional[DVCXDatasetLock]" = field(default=None) + _invalidated: bool = field(default=False, eq=False, repr=False) + + type: ClassVar[Literal["dvcx"]] = "dvcx" + + @property + def name_version(self) -> tuple[str, Optional[int]]: + url = urlparse(self.spec.url) + parts = url.netloc.split("@v") + assert parts + + name = parts[0] + version = int(parts[1]) if len(parts) > 1 else None + return name, version + + def update( + self, + repo, # noqa: ARG002 + record: Optional["DatasetRecord"] = None, + version: Optional[int] = None, + **kwargs, + ) -> "Self": + if not record: + try: + from dvcx.catalog import get_catalog # type: ignore[import] + + except ImportError as exc: + raise DvcException("dvcx is not installed") from exc + + name, _version = self.name_version + version = _version or version + catalog = get_catalog() + record = catalog.get_remote_dataset(name) + + assert record is not None + ver = version or record.latest_version + assert ver + version_info = record.get_version(ver) + lock = DVCXDatasetLock( + **self.spec.to_dict(), + version=version_info.version, + created_at=version_info.created_at, + ) + return evolve(self, lock=lock) + + +@frozen(kw_only=True) +class URLDataset: + manifest_path: str + spec: "DatasetSpec" + lock: "Optional[URLDatasetLock]" = None + _invalidated: bool = field(default=False, eq=False, repr=False) + + type: ClassVar[Literal["url"]] = "url" + + def update(self, repo, **kwargs): + from dvc.dependency import Dependency + + dep = Dependency( + None, self.spec.url, repo=repo, fs_config={"version_aware": True} + ) + dep.save() + d = dep.dumpd(datasets=True) + files = [ + FileInfo(relpath=info["relpath"], meta=Meta.from_dict(info)) + for info in d.get("files", []) + ] + lock = URLDatasetLock(**self.spec.to_dict(), meta=dep.meta, files=files) + return evolve(self, lock=lock) + + +Lock = Union[DVCDatasetLock, DVCXDatasetLock, URLDatasetLock] +Spec = Union[DatasetSpec, DVCDatasetSpec] +Dataset = Union[DVCDataset, DVCXDataset, URLDataset] + + +class DatasetNotFoundError(DvcException, KeyError): + def __init__(self, name, *args): + self.name = name + super().__init__("dataset not found", *args) + + def __str__(self) -> str: + return self.msg + + +class Datasets(Mapping[str, Dataset]): + def __init__(self, repo: "Repo") -> None: + self.repo: "Repo" = repo + + def __repr__(self): + return repr(dict(self)) + + def __rich_repr__(self): + yield dict(self) + + def __getitem__(self, name: str) -> Dataset: + try: + return self._datasets[name] + except KeyError as exc: + raise DatasetNotFoundError(name) 
from exc + + def __setitem__(self, name: str, dataset: Dataset) -> None: + self._datasets[name] = dataset + + def __contains__(self, name: object) -> bool: + return name in self._datasets + + def __iter__(self) -> Iterator[str]: + return iter(self._datasets) + + def __len__(self) -> int: + return len(self._datasets) + + @cached_property + def _spec(self) -> dict[str, tuple[str, dict[str, Any]]]: + return { + dataset["name"]: (path, dataset) + for path, datasets in self.repo.index._datasets.items() + for dataset in datasets + } + + @cached_property + def _lock(self) -> dict[str, Optional[dict[str, Any]]]: + datasets_lock = self.repo.index._datasets_lock + + def find(path, name) -> Optional[dict[str, Any]]: + lock = datasets_lock.get(path, []) + return next((dataset for dataset in lock if dataset["name"] == name), None) + + return {ds["name"]: find(path, name) for name, (path, ds) in self._spec.items()} + + @cached_property + def _datasets(self) -> dict[str, Dataset]: + return { + name: self._build_dataset(path, spec, self._lock[name]) + for name, (path, spec) in self._spec.items() + } + + def _reset(self) -> None: + self.__dict__.pop("_spec", None) + self.__dict__.pop("_lock", None) + self.__dict__.pop("_datasets", None) + + @staticmethod + def _spec_from_info(spec: dict[str, Any]) -> Spec: + typ = spec.get("type") + if not typ: + raise ValueError("type should be present in spec") + if typ == "dvc": + return DVCDatasetSpec.from_dict(spec) + if typ in {"dvcx", "url"}: + return DatasetSpec.from_dict(spec) + raise ValueError(f"unknown dataset type: {spec.get('type', '')}") + + @staticmethod + def _lock_from_info(lock: Optional[dict[str, Any]]) -> Optional[Lock]: + kl = {"dvc": DVCDatasetLock, "dvcx": DVCXDatasetLock, "url": URLDatasetLock} + if lock and (cls := kl.get(lock.get("type", ""))): # type: ignore[assignment] + return cls.from_dict(lock) # type: ignore[attr-defined] + return None + + @classmethod + def _build_dataset( + cls, + manifest_path: str, + spec_data: dict[str, Any], + lock_data: Optional[dict[str, Any]] = None, + ) -> Dataset: + _invalidated = False + spec = cls._spec_from_info(spec_data) + lock = cls._lock_from_info(lock_data) + # if dvc.lock and dvc.yaml file are not in sync, we invalidate the lock. + if lock is not None and to_spec(lock) != spec: + logger.debug( + "invalidated lock data for %s in %s", + spec.name, + manifest_path, + ) + _invalidated = True # signal is used during `dvc repro`/`dvc status`. 
+ lock = None + + assert isinstance(spec, DatasetSpec) + if spec.type == "dvc": + assert lock is None or isinstance(lock, DVCDatasetLock) + assert isinstance(spec, DVCDatasetSpec) + return DVCDataset( + manifest_path=manifest_path, + spec=spec, + lock=lock, + invalidated=_invalidated, + ) + if spec.type == "url": + assert lock is None or isinstance(lock, URLDatasetLock) + return URLDataset( + manifest_path=manifest_path, + spec=spec, + lock=lock, + invalidated=_invalidated, + ) + if spec.type == "dvcx": + assert lock is None or isinstance(lock, DVCXDatasetLock) + return DVCXDataset( + manifest_path=manifest_path, + spec=spec, + lock=lock, + invalidated=_invalidated, + ) + raise ValueError(f"unknown dataset type: {spec.type!r}") + + def add( + self, url: str, name: str, manifest_path: str = "dvc.yaml", **kwargs: Any + ) -> Dataset: + kwargs.update({"url": url, "name": name}) + dataset = self._build_dataset(manifest_path, kwargs) + dataset = dataset.update(self.repo) + + self.dump(dataset) + self[name] = dataset + return dataset + + def update(self, name, **kwargs) -> tuple[Dataset, Dataset]: + dataset = self[name] + new = dataset.update(self.repo, **kwargs) + + self.dump(new, old=dataset) + self[name] = new + return dataset, new + + def _dump_spec(self, manifest_path: str, spec: Spec) -> None: + spec_data = spec.to_dict() + assert spec_data.keys() & {"type", "name", "url"} + project_file = ProjectFile(self.repo, manifest_path) + project_file.dump_dataset(spec_data) + + def _dump_lock(self, manifest_path: str, lock: Lock) -> None: + lock_data = lock.to_dict() + assert lock_data.keys() & {"type", "name", "url"} + lockfile = Lockfile(self.repo, Path(manifest_path).with_suffix(".lock")) + lockfile.dump_dataset(lock_data) + + def dump(self, dataset: Dataset, old: Optional[Dataset] = None) -> None: + if not old or old.spec != dataset.spec: + self._dump_spec(dataset.manifest_path, dataset.spec) + if dataset.lock and (not old or old.lock != dataset.lock): + self._dump_lock(dataset.manifest_path, dataset.lock) diff --git a/dvc/repo/index.py b/dvc/repo/index.py index ffa624df94..2f764a786d 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -290,6 +290,7 @@ def __init__( params: Optional[dict[str, Any]] = None, artifacts: Optional[dict[str, Any]] = None, datasets: Optional[dict[str, list[dict[str, Any]]]] = None, + datasets_lock: Optional[dict[str, list[dict[str, Any]]]] = None, ) -> None: self.repo = repo self.stages = stages or [] @@ -298,6 +299,7 @@ def __init__( self._params = params or {} self._artifacts = artifacts or {} self._datasets: dict[str, list[dict[str, Any]]] = datasets or {} + self._datasets_lock: dict[str, list[dict[str, Any]]] = datasets_lock or {} self._collected_targets: dict[int, list["StageInfo"]] = {} @cached_property @@ -322,6 +324,7 @@ def from_repo( params = {} artifacts = {} datasets = {} + datasets_lock = {} onerror = onerror or repo.stage_collection_error_handler for _, idx in collect_files(repo, onerror=onerror): @@ -331,6 +334,7 @@ def from_repo( params.update(idx._params) artifacts.update(idx._artifacts) datasets.update(idx._datasets) + datasets_lock.update(idx._datasets_lock) return cls( repo, stages=stages, @@ -339,6 +343,7 @@ def from_repo( params=params, artifacts=artifacts, datasets=datasets, + datasets_lock=datasets_lock, ) @classmethod @@ -354,6 +359,9 @@ def from_file(cls, repo: "Repo", path: str) -> "Index": params={path: dvcfile.params} if dvcfile.params else {}, artifacts={path: dvcfile.artifacts} if dvcfile.artifacts else {}, datasets={path: 
dvcfile.datasets} if dvcfile.datasets else {}, + datasets_lock={path: dvcfile.datasets_lock} + if dvcfile.datasets_lock + else {}, ) def update(self, stages: Iterable["Stage"]) -> "Index": diff --git a/dvc/schema.py b/dvc/schema.py index 37710ad3b0..85001016fc 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -47,6 +47,7 @@ LOCKFILE_STAGES_SCHEMA = {str: LOCK_FILE_STAGE_SCHEMA} LOCKFILE_SCHEMA = { vol.Required("schema"): vol.Equal("2.0", "invalid schema version"), + "datasets": object, STAGES: LOCKFILE_STAGES_SCHEMA, } @@ -128,7 +129,7 @@ def validator(data): {vol.Required("type"): str, vol.Required("name"): str}, extra=vol.ALLOW_EXTRA ) MULTI_STAGE_SCHEMA = { - "datasets": [DATASET_SCHEMA], + "datasets": object, PLOTS: [vol.Any(str, SINGLE_PLOT_SCHEMA)], STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, VARS_KWD: VARS_SCHEMA, diff --git a/tests/unit/command/test_help.py b/tests/unit/command/test_help.py index 0350c11e29..488cb63f4d 100644 --- a/tests/unit/command/test_help.py +++ b/tests/unit/command/test_help.py @@ -29,7 +29,7 @@ def recurse_parser(parser: ArgumentParser, parents: tuple[str, ...] = root) -> N # the no. of commands will usually go up, # but if we ever remove commands and drop below, adjust the magic number accordingly - assert len(commands) >= 112 + assert len(commands) >= 116 return sorted(commands)
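
Usage sketch (not part of the patch): a minimal, illustrative example of how the new Repo.datasets API introduced in dvc/repo/datasets.py is expected to be driven from Python. It assumes an existing DVC repository in the working directory; the S3 URL and the dataset name "example" are placeholders, not values taken from this change.

    from dvc.repo import Repo

    repo = Repo()  # locate the DVC repository from the current directory

    # Register a versioned-URL dataset. Datasets.add() builds the spec, resolves
    # it via URLDataset.update(), and persists the spec to dvc.yaml and the lock
    # entry to dvc.lock (ProjectFile.dump_dataset / Lockfile.dump_dataset).
    dataset = repo.datasets.add(url="s3://bucket/path", name="example", type="url")

    # Re-resolve the dataset later. Datasets.update() returns the (old, new) pair
    # so callers such as CmdDatasetUpdate can report what changed.
    old, new = repo.datasets.update("example")

On the command line, the same code paths are exercised by the new `dvc dataset add --url s3://bucket/path --name example` and `dvc ds update example` commands (the `ds` alias is registered in add_parser above).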