diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index ab0f17e2b3..bbabd56a8f 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -44,7 +44,7 @@ from dvc.repo import Repo from dvc.repo.experiments.stash import ExpStashEntry from dvc.scm import Git - from dvc.stage import PipelineStage + from dvc.stage import PipelineStage, Stage logger = logging.getLogger(__name__) @@ -255,6 +255,15 @@ def _from_stash_entry( **kwargs, ) + @classmethod + def _get_stage_files(cls, stages: List["Stage"]) -> List[str]: + from dvc.stage.utils import _get_stage_files + + ret: List[str] = [] + for stage in stages: + ret.extend(_get_stage_files(stage)) + return ret + @classmethod def _get_top_level_paths(cls, repo: "Repo") -> List["str"]: return list( @@ -511,6 +520,9 @@ def reproduce( ) stages = dvc.reproduce(*args, **kwargs) + if paths := cls._get_stage_files(stages): + logger.debug("Staging stage-related files: %s", paths) + dvc.scm_context.add(paths) if paths := cls._get_top_level_paths(dvc): logger.debug("Staging top-level files: %s", paths) dvc.scm_context.add(paths) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 808fa93291..0a6265c2bd 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Iterator, List, Optional, cast +from typing import TYPE_CHECKING, List, Optional, cast from funcy import ldistinct @@ -29,33 +29,9 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]: ret = stage.reproduce(**kwargs) if ret and not kwargs.get("dry", False): stage.dump(update_pipeline=False) - _track_stage(stage) return ret -def _get_stage_files(stage: "Stage") -> Iterator[str]: - yield stage.dvcfile.relpath - for dep in stage.deps: - if ( - not dep.use_scm_ignore - and dep.is_in_repo - and not stage.repo.dvcfs.isdvc(stage.repo.dvcfs.from_os_path(str(dep))) - ): - yield dep.fs_path - for out in stage.outs: - if not out.use_scm_ignore and out.is_in_repo: - yield out.fs_path - - -def _track_stage(stage: "Stage") -> None: - from dvc.utils import relpath - - context = stage.repo.scm_context - for path in _get_stage_files(stage): - context.track_file(relpath(path)) - return context.track_changed_files() - - @locked @scm_context def reproduce( # noqa: C901, PLR0912 diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py index bb3d2e575f..430b850797 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -280,3 +280,27 @@ def validate_kwargs( kwargs.pop("name", None) return kwargs + + +def _get_stage_files(stage: "Stage") -> List[str]: + from dvc.dvcfile import ProjectFile + from dvc.utils import relpath + + ret: List[str] = [] + file = stage.dvcfile + ret.append(file.relpath) + if isinstance(file, ProjectFile): + ret.append(file._lockfile.relpath) # pylint: disable=protected-access + + for dep in stage.deps: + if ( + not dep.use_scm_ignore + and dep.is_in_repo + and not stage.repo.dvcfs.isdvc(stage.repo.dvcfs.from_os_path(str(dep))) + ): + ret.append(relpath(dep.fs_path)) + + for out in stage.outs: + if not out.use_scm_ignore and out.is_in_repo: + ret.append(relpath(out.fs_path)) + return ret diff --git a/tests/unit/repo/test_reproduce.py b/tests/unit/repo/test_reproduce.py index 3206d9d2cc..f752bbf8ea 100644 --- a/tests/unit/repo/test_reproduce.py +++ b/tests/unit/repo/test_reproduce.py @@ -1,8 +1,3 @@ -import os - -from dvc.repo.reproduce import _get_stage_files - - def test_number_reproduces(tmp_dir, dvc, mocker): reproduce_stage_mock = mocker.patch( "dvc.repo.reproduce._reproduce_stage", returns=[] @@ -19,42 +14,3 @@ def test_number_reproduces(tmp_dir, dvc, mocker): dvc.reproduce(all_pipelines=True) assert reproduce_stage_mock.call_count == 5 - - -def test_get_stage_files(tmp_dir, dvc): - tmp_dir.dvc_gen("dvc-dep", "dvc-dep") - tmp_dir.gen("other-dep", "other-dep") - - stage = dvc.stage.add( - name="stage", - cmd="foo", - deps=["dvc-dep", "other-dep"], - outs=["dvc-out"], - outs_no_cache=["other-out"], - ) - result = set(_get_stage_files(stage)) - assert result == { - stage.dvcfile.relpath, - str(tmp_dir / "other-dep"), - str(tmp_dir / "other-out"), - } - - -def test_get_stage_files_wdir(tmp_dir, dvc): - tmp_dir.gen({"dir": {"dvc-dep": "dvc-dep", "other-dep": "other-dep"}}) - dvc.add(os.path.join("dir", "dvc-dep")) - - stage = dvc.stage.add( - name="stage", - cmd="foo", - wdir="dir", - deps=["dvc-dep", "other-dep"], - outs=["dvc-out"], - outs_no_cache=["other-out"], - ) - result = set(_get_stage_files(stage)) - assert result == { - stage.dvcfile.relpath, - str(tmp_dir / "dir" / "other-dep"), - str(tmp_dir / "dir" / "other-out"), - } diff --git a/tests/unit/stage/test_utils.py b/tests/unit/stage/test_utils.py index 363efffbe4..dffedc2990 100644 --- a/tests/unit/stage/test_utils.py +++ b/tests/unit/stage/test_utils.py @@ -1,7 +1,7 @@ import os from dvc.fs import localfs -from dvc.stage.utils import resolve_paths +from dvc.stage.utils import _get_stage_files, resolve_paths def test_resolve_paths(): @@ -19,3 +19,40 @@ def test_resolve_paths(): path, wdir = resolve_paths(fs=localfs, path=file_path, wdir="../../some-dir") assert path == os.path.abspath(file_path) assert wdir == os.path.abspath("some-dir") + + +def test_get_stage_files(tmp_dir, dvc): + tmp_dir.dvc_gen("dvc-dep", "dvc-dep") + tmp_dir.gen("other-dep", "other-dep") + stage = dvc.stage.create( + name="stage", + cmd="foo", + deps=["dvc-dep", "other-dep"], + outs=["dvc-out"], + outs_no_cache=["other-out"], + ) + assert _get_stage_files(stage) == [ + "dvc.yaml", + "dvc.lock", + "other-dep", + "other-out", + ] + + +def test_get_stage_files_wdir(tmp_dir, dvc): + tmp_dir.gen({"dir": {"dvc-dep": "dvc-dep", "other-dep": "other-dep"}}) + dvc.add(os.path.join("dir", "dvc-dep")) + stage = dvc.stage.create( + name="stage", + cmd="foo", + wdir="dir", + deps=["dvc-dep", "other-dep"], + outs=["dvc-out"], + outs_no_cache=["other-out"], + ) + assert _get_stage_files(stage) == [ + "dvc.yaml", + "dvc.lock", + os.path.join("dir", "other-dep"), + os.path.join("dir", "other-out"), + ]