Skip to content

Commit

Permalink
limit reimport behavior to checkout
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum authored and dberenbaum committed Jul 24, 2024
1 parent 8a91c2b commit 248ba3b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 65 deletions.
76 changes: 11 additions & 65 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import errno
import os
from collections import defaultdict
from copy import copy, deepcopy
from copy import deepcopy
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

import voluptuous as vol

from dvc.exceptions import DvcException
from dvc.prompt import confirm
from dvc.utils import as_posix

Expand All @@ -16,9 +12,6 @@
from dvc.fs import DVCFileSystem
from dvc.output import Output
from dvc.stage import Stage
from dvc_data.hashfile.hash_info import HashInfo
from dvc_data.hashfile.obj import HashFile
from dvc_objects.db import ObjectDB


class RepoDependency(Dependency):
Expand Down Expand Up @@ -101,13 +94,18 @@ def dumpd(self, **kwargs) -> dict[str, Union[str, dict[str, str]]]:
}

def download(self, to: "Output", jobs: Optional[int] = None):
from dvc_data.hashfile.checkout import checkout
from dvc_data.hashfile.build import build
from dvc_data.hashfile.checkout import CheckoutError, checkout

try:
used, obj = self._get_used_and_obj()
for odb, objs in used.items():
self.repo.cloud.pull(objs, jobs=jobs, odb=odb)
repo = self._make_fs(locked=True).repo

_, _, obj = build(
repo.cache.local,
self.fs_path,
repo.dvcfs,
repo.cache.local.fs.PARAM_CHECKSUM,
)
checkout(
to.fs_path,
to.fs,
Expand All @@ -117,7 +115,7 @@ def download(self, to: "Output", jobs: Optional[int] = None):
state=self.repo.state,
prompt=confirm,
)
except DvcException:
except CheckoutError:
super().download(to=to, jobs=jobs)

def update(self, rev: Optional[str] = None):
Expand All @@ -132,58 +130,6 @@ def changed_checksum(self) -> bool:
# immutable, hence its impossible for checksum to change.
return False

def _get_used_and_obj(
self, **kwargs
) -> tuple[dict[Optional["ObjectDB"], set["HashInfo"]], "HashFile"]:
from dvc.config import NoRemoteError
from dvc.exceptions import NoOutputOrStageError
from dvc.utils import as_posix
from dvc_data.hashfile.build import build
from dvc_data.hashfile.tree import Tree, TreeError

local_odb = self.repo.cache.local
locked = kwargs.pop("locked", True)
repo = self._make_fs(locked=locked).repo
used_obj_ids = defaultdict(set)
rev = repo.get_rev()
if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = rev

try:
for odb, obj_ids in repo.used_objs(
[os.path.join(repo.root_dir, self.def_path)],
force=True,
jobs=kwargs.get("jobs"),
recursive=True,
).items():
if odb is None:
odb = repo.cloud.get_remote_odb()
odb.read_only = True
used_obj_ids[odb].update(obj_ids)
except (NoRemoteError, NoOutputOrStageError):
pass

try:
object_store, _, obj = build(
local_odb,
as_posix(self.def_path),
repo.dvcfs,
local_odb.fs.PARAM_CHECKSUM,
)
except (FileNotFoundError, TreeError) as exc:
raise FileNotFoundError(
errno.ENOENT,
os.strerror(errno.ENOENT) + f" in {self.def_repo[self.PARAM_URL]}",
self.def_path,
) from exc
object_store = copy(object_store)
object_store.read_only = True

used_obj_ids[object_store].add(obj.hash_info)
if isinstance(obj, Tree):
used_obj_ids[object_store].update(oid for _, _, oid in obj)
return used_obj_ids, obj

def _make_fs(
self, rev: Optional[str] = None, locked: bool = True
) -> "DVCFileSystem":
Expand Down
14 changes: 14 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dvc.cachemgr import CacheManager
from dvc.config import NoRemoteError
from dvc.dependency import base
from dvc.dvcfile import load_file
from dvc.fs import system
from dvc.scm import Git
Expand Down Expand Up @@ -722,3 +723,16 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
remote="myremote",
remote_config={"key": "value"},
)


def test_reimport(tmp_dir, scm, dvc, erepo_dir, mocker):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")

spy = mocker.spy(base, "fs_download")
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.called

spy.reset_mock()
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", force=True)
assert not spy.called

0 comments on commit 248ba3b

Please sign in to comment.