From 0d67e5b14a35983a09a0954b0585957fcb1a19c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 3 Jul 2023 16:10:50 +0545 Subject: [PATCH] refactor repro: split logic for planning stages to run for repro The planning logic has now been moved to `plan_repro` api, that allows us to test without creating stages. --- dvc/repo/reproduce.py | 162 ++++++++++++++++-------------- tests/unit/repo/test_reproduce.py | 54 ++++++++++ 2 files changed, 139 insertions(+), 77 deletions(-) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 0a6265c2bd..47c254836a 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -1,11 +1,12 @@ import logging -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast from funcy import ldistinct from dvc.exceptions import ReproductionError from dvc.repo.scm_context import scm_context from dvc.stage.cache import RunCacheNotSupported +from dvc.utils.collections import ensure_list from . import locked @@ -32,75 +33,33 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]: return ret -@locked -@scm_context -def reproduce( # noqa: C901, PLR0912 - self: "Repo", - targets=None, - recursive=False, - pipeline=False, - all_pipelines=False, - **kwargs, -): - from .graph import get_pipeline, get_pipelines - - glob = kwargs.pop("glob", False) - - if isinstance(targets, str): - targets = [targets] - - if not all_pipelines and not targets: - from dvc.dvcfile import PROJECT_FILE - - targets = [PROJECT_FILE] +def collect_stages( + repo: "Repo", + targets: Iterable[str], + recursive: bool = False, + glob: bool = False, +) -> List["Stage"]: + stages: List["Stage"] = [] + for target in targets: + stages.extend(repo.stage.collect(target, recursive=recursive, glob=glob)) + return ldistinct(stages) - targets = targets or [] - interactive = kwargs.get("interactive", False) - if not interactive: - kwargs["interactive"] = self.config["core"].get("interactive", False) - stages = [] - if pipeline or all_pipelines: - pipelines = get_pipelines(self.index.graph) - if all_pipelines: - used_pipelines = pipelines - else: - used_pipelines = [] - for target in targets: - stage = self.stage.get_target(target) - used_pipelines.append(get_pipeline(pipelines, stage)) - - for pline in used_pipelines: - for stage in pline: - if pline.in_degree(stage) == 0: - stages.append(stage) - else: - for target in targets: - stages.extend( - self.stage.collect( - target, - recursive=recursive, - glob=glob, - ) - ) - - if kwargs.get("pull", False) and kwargs.get("run_cache", True): - logger.debug("Pulling run cache") - try: - self.stage_cache.pull(None) - except RunCacheNotSupported as e: - logger.warning("Failed to pull run cache: %s", e) - - return _reproduce_stages(self.index.graph, ldistinct(stages), **kwargs) +def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph": + g = cast("DiGraph", graph.copy()) + for stage in graph: + if stage.frozen: + # NOTE: disconnect frozen stage from its dependencies + g.remove_edges_from(graph.out_edges(stage)) + return g -def _reproduce_stages( # noqa: C901 +def plan_repro( graph: "DiGraph", - stages: List["Stage"], - force_downstream: bool = False, + stages: Optional[List["Stage"]] = None, + pipeline: bool = False, downstream: bool = False, - single_item: bool = False, - **kwargs, + all_pipelines: bool = False, ) -> List["Stage"]: r"""Derive the evaluation of the given node for the given graph. @@ -137,12 +96,70 @@ def _reproduce_stages( # noqa: C901 The derived evaluation of _downstream_ B would be: [B, D, E] """ - from .graph import get_steps + from .graph import get_pipeline, get_pipelines, get_steps + + if pipeline or all_pipelines: + pipelines = get_pipelines(graph) + if stages and pipeline: + pipelines = [get_pipeline(pipelines, stage) for stage in stages] + + leaves: List["Stage"] = [] + for pline in pipelines: + leaves.extend(node for node in pline if pline.in_degree(node) == 0) + stages = ldistinct(leaves) + + active = _remove_frozen_stages(graph) + return get_steps(active, stages, downstream=downstream) - if not single_item: - active = _remove_frozen_stages(graph) - stages = get_steps(active, stages, downstream=downstream) +@locked +@scm_context +def reproduce( # noqa: C901 + self: "Repo", + targets: Union[Iterable[str], str, None] = None, + recursive: bool = False, + pipeline: bool = False, + all_pipelines: bool = False, + downstream: bool = False, + single_item: bool = False, + glob: bool = False, + **kwargs, +): + from dvc.dvcfile import PROJECT_FILE + + if not kwargs.get("interactive", False): + kwargs["interactive"] = self.config["core"].get("interactive", False) + + stages: List["Stage"] = [] + if not all_pipelines: + targets_list = ensure_list(targets or PROJECT_FILE) + stages = collect_stages(self, targets_list, recursive=recursive, glob=glob) + + if kwargs.get("pull", False) and kwargs.get("run_cache", True): + logger.debug("Pulling run cache") + try: + self.stage_cache.pull(None) + except RunCacheNotSupported as e: + logger.warning("Failed to pull run cache: %s", e) + + steps = stages + if pipeline or all_pipelines or not single_item: + graph = self.index.graph + steps = plan_repro( + graph, + stages, + pipeline=pipeline, + downstream=downstream, + all_pipelines=all_pipelines, + ) + return _reproduce_stages(steps, **kwargs) + + +def _reproduce_stages( + stages: List["Stage"], + force_downstream: bool = False, + **kwargs, +) -> List["Stage"]: result: List["Stage"] = [] for i, stage in enumerate(stages): try: @@ -164,12 +181,3 @@ def _reproduce_stages( # noqa: C901 if i < len(stages) - 1: logger.info("") # add a newline return result - - -def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph": - g = cast("DiGraph", graph.copy()) - for stage in graph: - if stage.frozen: - # NOTE: disconnect frozen stage from its dependencies - g.remove_edges_from(graph.out_edges(stage)) - return g diff --git a/tests/unit/repo/test_reproduce.py b/tests/unit/repo/test_reproduce.py index f752bbf8ea..a360284dcc 100644 --- a/tests/unit/repo/test_reproduce.py +++ b/tests/unit/repo/test_reproduce.py @@ -1,3 +1,10 @@ +from itertools import chain + +from networkx import DiGraph + +from dvc.repo.reproduce import plan_repro + + def test_number_reproduces(tmp_dir, dvc, mocker): reproduce_stage_mock = mocker.patch( "dvc.repo.reproduce._reproduce_stage", returns=[] @@ -14,3 +21,50 @@ def test_number_reproduces(tmp_dir, dvc, mocker): dvc.reproduce(all_pipelines=True) assert reproduce_stage_mock.call_count == 5 + + +def test_repro_plan(mocker): + r""" + n1 + / \ + n2 m3 n8 + / \ / \ | + n4 n5 n6 n7 n9 + """ + + # note: downstream steps may not be stable, use AnyOf in such cases + class AnyOf: + def __init__(self, *items): + self.items = items + + def __eq__(self, other: object) -> bool: + return any(item == other for item in self.items) + + n = mocker.sentinel + n1, n2, n3, n4, n5, n6, n7, n8, n9 = (getattr(n, f"n{i}") for i in range(1, 10)) + edges = {n1: [n2, n3], n2: [n4, n5], n3: [n6, n7], n8: [n9]} + for node in chain.from_iterable([n, *v] for n, v in edges.items()): + node.frozen = False + + g = DiGraph(edges) + assert plan_repro(g) == [n4, n5, n2, n6, n7, n3, n1, n9, n8] + assert plan_repro(g, [n1]) == [n4, n5, n2, n6, n7, n3, n1] + assert plan_repro(g, [n4], downstream=True) == [n4, n2, n1] + assert plan_repro(g, [n8], True) == plan_repro(g, [n9], True) == [n9, n8] + assert plan_repro(g, [n2, n8], True) == [n4, n5, n2, n6, n7, n3, n1, n9, n8] + assert plan_repro(g, [n2, n3], downstream=True) == [ + AnyOf(n2, n3), + AnyOf(n2, n3), + n1, + ] + + n2.frozen = True + assert plan_repro(g) == [n2, n6, n7, n3, n1, n9, n8, n4, n5] + assert plan_repro(g, [n1]) == [n2, n6, n7, n3, n1] + assert plan_repro(g, [n4], downstream=True) == [n4] + assert plan_repro(g, [n2, n8], pipeline=True) == [n2, n6, n7, n3, n1, n9, n8] + assert plan_repro(g, [n2, n3], downstream=True) == [ + AnyOf(n2, n3), + AnyOf(n2, n3), + n1, + ]