Skip to content

Commit

Permalink
save import hash info to state (#10531)
Browse files Browse the repository at this point in the history
* save import hash info to state

* clean up path handling

* avoid walking

* tests: do not use as_posix to check for output

---------

Co-authored-by: Saugat Pachhai (सौगात) <[email protected]>
  • Loading branch information
dberenbaum and skshetry authored Aug 20, 2024
1 parent 98dde6c commit 81db9b4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 38 deletions.
2 changes: 1 addition & 1 deletion dvc/dependency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def update(self, rev=None):
self.fs_path = self.fs.version_path(self.fs_path, self.meta.version_id)

def download(self, to, jobs=None):
fs_download(self.fs, self.fs_path, to.fs_path, jobs=jobs)
return fs_download(self.fs, self.fs_path, to.fs_path, jobs=jobs)

def save(self):
super().save()
Expand Down
44 changes: 20 additions & 24 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import voluptuous as vol

from dvc.prompt import confirm
from dvc.utils import as_posix

from .base import Dependency
Expand All @@ -12,6 +11,7 @@
from dvc.fs import DVCFileSystem
from dvc.output import Output
from dvc.stage import Stage
from dvc_data.hashfile.hash_info import HashInfo


class RepoDependency(Dependency):
Expand Down Expand Up @@ -94,29 +94,25 @@ def dumpd(self, **kwargs) -> dict[str, Union[str, dict[str, str]]]:
}

def download(self, to: "Output", jobs: Optional[int] = None):
from dvc_data.hashfile.build import build
from dvc_data.hashfile.checkout import CheckoutError, checkout

try:
repo = self._make_fs(locked=True).repo

_, _, obj = build(
repo.cache.local,
self.fs_path,
repo.dvcfs,
repo.cache.local.fs.PARAM_CHECKSUM,
)
checkout(
to.fs_path,
to.fs,
obj,
self.repo.cache.local,
ignore=None,
state=self.repo.state,
prompt=confirm,
)
except (CheckoutError, FileNotFoundError):
super().download(to=to, jobs=jobs)
from dvc.fs import LocalFileSystem

files = super().download(to=to, jobs=jobs)
if not isinstance(to.fs, LocalFileSystem):
return files

hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
for src_path, dest_path in files:
try:
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
dest_info = to.fs.info(dest_path)
except (OSError, KeyError, AttributeError):
# If no hash info found, just keep going and output will be hashed later
continue
if hash_info:
hashes.append((dest_path, hash_info, dest_info))
cache = to.cache if to.use_cache else to.local_cache
cache.state.save_many(hashes, to.fs)
return files

def update(self, rev: Optional[str] = None):
if rev:
Expand Down
6 changes: 3 additions & 3 deletions dvc/fs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

def download(
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
) -> int:
) -> list[tuple[str, str]]:
from dvc.scm import lfs_prefetch

from .callbacks import TqdmCallback
Expand All @@ -61,7 +61,7 @@ def download(
]
if not from_infos:
localfs.makedirs(to, exist_ok=True)
return 0
return []
to_infos = [
localfs.join(to, *fs.relparts(info, fs_path)) for info in from_infos
]
Expand All @@ -81,7 +81,7 @@ def download(
cb.set_size(len(from_infos))
jobs = jobs or fs.jobs
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
return len(to_infos)
return list(zip(from_infos, to_infos))


def parse_external_url(url, fs_config=None, config=None):
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def download(

out = resolve_output(path, out, force=force)
fs = self.repo.dvcfs
count = fs_download(fs, path, os.path.abspath(out), jobs=jobs)
count = len(fs_download(fs, path, os.path.abspath(out), jobs=jobs))
return count, out

@staticmethod
Expand Down
16 changes: 7 additions & 9 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from dvc.cachemgr import CacheManager
from dvc.config import NoRemoteError
from dvc.dependency import base
from dvc.dvcfile import load_file
from dvc.fs import system
from dvc.scm import Git
from dvc.stage.exceptions import StagePathNotFoundError
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils.fs import remove
from dvc_data.hashfile import hash
from dvc_data.index.index import DataIndexDirError


Expand Down Expand Up @@ -725,14 +725,12 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
)


def test_reimport(tmp_dir, scm, dvc, erepo_dir, mocker):
def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")

spy = mocker.spy(base, "fs_download")
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.called

spy.reset_mock()
dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", force=True)
assert not spy.called
spy = mocker.spy(hash, "file_md5")
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.call_count == 1
for call in spy.call_args_list:
assert stage.outs[0].fs_path != call.args[0]

0 comments on commit 81db9b4

Please sign in to comment.