From 248ba3b6b64f9e49dedab6d2aedab92e977bd0d0 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 12 Apr 2024 11:34:17 -0400 Subject: [PATCH] limit reimport behavior to checkout --- dvc/dependency/repo.py | 76 ++++++--------------------------------- tests/func/test_import.py | 14 ++++++++ 2 files changed, 25 insertions(+), 65 deletions(-) diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 17d2b3f371..811107b6fc 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -1,12 +1,8 @@ -import errno -import os -from collections import defaultdict -from copy import copy, deepcopy +from copy import deepcopy from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union import voluptuous as vol -from dvc.exceptions import DvcException from dvc.prompt import confirm from dvc.utils import as_posix @@ -16,9 +12,6 @@ from dvc.fs import DVCFileSystem from dvc.output import Output from dvc.stage import Stage - from dvc_data.hashfile.hash_info import HashInfo - from dvc_data.hashfile.obj import HashFile - from dvc_objects.db import ObjectDB class RepoDependency(Dependency): @@ -101,13 +94,18 @@ def dumpd(self, **kwargs) -> dict[str, Union[str, dict[str, str]]]: } def download(self, to: "Output", jobs: Optional[int] = None): - from dvc_data.hashfile.checkout import checkout + from dvc_data.hashfile.build import build + from dvc_data.hashfile.checkout import CheckoutError, checkout try: - used, obj = self._get_used_and_obj() - for odb, objs in used.items(): - self.repo.cloud.pull(objs, jobs=jobs, odb=odb) + repo = self._make_fs(locked=True).repo + _, _, obj = build( + repo.cache.local, + self.fs_path, + repo.dvcfs, + repo.cache.local.fs.PARAM_CHECKSUM, + ) checkout( to.fs_path, to.fs, @@ -117,7 +115,7 @@ def download(self, to: "Output", jobs: Optional[int] = None): state=self.repo.state, prompt=confirm, ) - except DvcException: + except CheckoutError: super().download(to=to, jobs=jobs) def update(self, rev: Optional[str] = None): @@ -132,58 +130,6 @@ def changed_checksum(self) -> bool: # immutable, hence its impossible for checksum to change. return False - def _get_used_and_obj( - self, **kwargs - ) -> tuple[dict[Optional["ObjectDB"], set["HashInfo"]], "HashFile"]: - from dvc.config import NoRemoteError - from dvc.exceptions import NoOutputOrStageError - from dvc.utils import as_posix - from dvc_data.hashfile.build import build - from dvc_data.hashfile.tree import Tree, TreeError - - local_odb = self.repo.cache.local - locked = kwargs.pop("locked", True) - repo = self._make_fs(locked=locked).repo - used_obj_ids = defaultdict(set) - rev = repo.get_rev() - if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None: - self.def_repo[self.PARAM_REV_LOCK] = rev - - try: - for odb, obj_ids in repo.used_objs( - [os.path.join(repo.root_dir, self.def_path)], - force=True, - jobs=kwargs.get("jobs"), - recursive=True, - ).items(): - if odb is None: - odb = repo.cloud.get_remote_odb() - odb.read_only = True - used_obj_ids[odb].update(obj_ids) - except (NoRemoteError, NoOutputOrStageError): - pass - - try: - object_store, _, obj = build( - local_odb, - as_posix(self.def_path), - repo.dvcfs, - local_odb.fs.PARAM_CHECKSUM, - ) - except (FileNotFoundError, TreeError) as exc: - raise FileNotFoundError( - errno.ENOENT, - os.strerror(errno.ENOENT) + f" in {self.def_repo[self.PARAM_URL]}", - self.def_path, - ) from exc - object_store = copy(object_store) - object_store.read_only = True - - used_obj_ids[object_store].add(obj.hash_info) - if isinstance(obj, Tree): - used_obj_ids[object_store].update(oid for _, _, oid in obj) - return used_obj_ids, obj - def _make_fs( self, rev: Optional[str] = None, locked: bool = True ) -> "DVCFileSystem": diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 125337811e..86e8b3b967 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -6,6 +6,7 @@ from dvc.cachemgr import CacheManager from dvc.config import NoRemoteError +from dvc.dependency import base from dvc.dvcfile import load_file from dvc.fs import system from dvc.scm import Git @@ -722,3 +723,16 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir): remote="myremote", remote_config={"key": "value"}, ) + + +def test_reimport(tmp_dir, scm, dvc, erepo_dir, mocker): + with erepo_dir.chdir(): + erepo_dir.dvc_gen("foo", "foo content", commit="create foo") + + spy = mocker.spy(base, "fs_download") + dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported") + assert spy.called + + spy.reset_mock() + dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", force=True) + assert not spy.called