diff --git a/dvc/repo/add.py b/dvc/repo/add.py index b7f09658bd..f27978f09e 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -98,8 +98,11 @@ def add( # noqa: C901 **kwargs, ) + # remove existing stages that are to-be replaced with these + # new stages for the graph checks. + old_stages = set(repo.stages) - set(stages) try: - repo.check_modified_graph(stages) + repo.check_modified_graph(stages, list(old_stages)) except OverlappingOutputPathsError as exc: msg = ( "Cannot add '{out}', because it is overlapping with other " @@ -233,7 +236,6 @@ def _create_stages( transfer=False, **kwargs, ): - from dvc.dvcfile import Dvcfile from dvc.stage import Stage, create_stage, restore_meta expanded_targets = glob_targets(targets, glob=glob) @@ -259,12 +261,9 @@ def _create_stages( external=external, ) restore_meta(stage) - Dvcfile(repo, stage.path).remove() if desc: stage.outs[0].desc = desc - repo._reset() # pylint: disable=protected-access - if not stage: if pbar is not None: pbar.total -= 1 diff --git a/tests/func/test_add.py b/tests/func/test_add.py index b5559ddc3f..071a2f7531 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -24,7 +24,11 @@ from dvc.hash_info import HashInfo from dvc.main import main from dvc.objects.db import ODBManager -from dvc.output import OutputAlreadyTrackedError, OutputIsStageFileError +from dvc.output import ( + OutputAlreadyTrackedError, + OutputDoesNotExistError, + OutputIsStageFileError, +) from dvc.stage import Stage from dvc.stage.exceptions import ( StageExternalOutputsError, @@ -1190,3 +1194,42 @@ def test_add_ignored(tmp_dir, scm, dvc): assert str(exc.value) == ("bad DVC file name '{}' is git-ignored.").format( os.path.join("dir", "subdir.dvc") ) + + +def test_add_on_not_existing_file_should_not_remove_stage_file(tmp_dir, dvc): + (stage,) = tmp_dir.dvc_gen("foo", "foo") + (tmp_dir / "foo").unlink() + dvcfile_contents = (tmp_dir / stage.path).read_text() + + with pytest.raises(OutputDoesNotExistError): + dvc.add("foo") + assert (tmp_dir / "foo.dvc").exists() + assert (tmp_dir / stage.path).read_text() == dvcfile_contents + + +@pytest.mark.parametrize( + "target", + [ + "dvc.repo.Repo.check_modified_graph", + "dvc.stage.Stage.save", + "dvc.stage.Stage.commit", + ], +) +def test_add_does_not_remove_stage_file_on_failure( + tmp_dir, dvc, mocker, target +): + (stage,) = tmp_dir.dvc_gen("foo", "foo") + tmp_dir.gen("foo", "foobar") # update file + dvcfile_contents = (tmp_dir / stage.path).read_text() + + exc_msg = f"raising error from mocked '{target}'" + mocker.patch( + target, + side_effect=DvcException(exc_msg), + ) + + with pytest.raises(DvcException) as exc_info: + dvc.add("foo") + assert str(exc_info.value) == exc_msg + assert (tmp_dir / "foo.dvc").exists() + assert (tmp_dir / stage.path).read_text() == dvcfile_contents