Skip to content

Commit

Permalink
refactor repro: split logic for planning stages to run for repro (#9690)
Browse files Browse the repository at this point in the history
The planning logic has now been moved to `plan_repro` api, that allows us
to test without creating stages.
  • Loading branch information
skshetry authored Jul 3, 2023
1 parent b7b6642 commit 2becfd9
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 77 deletions.
162 changes: 85 additions & 77 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from typing import TYPE_CHECKING, List, Optional, cast
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast

from funcy import ldistinct

from dvc.exceptions import ReproductionError
from dvc.repo.scm_context import scm_context
from dvc.stage.cache import RunCacheNotSupported
from dvc.utils.collections import ensure_list

from . import locked

Expand All @@ -32,75 +33,33 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]:
return ret


@locked
@scm_context
def reproduce( # noqa: C901, PLR0912
self: "Repo",
targets=None,
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs,
):
from .graph import get_pipeline, get_pipelines

glob = kwargs.pop("glob", False)

if isinstance(targets, str):
targets = [targets]

if not all_pipelines and not targets:
from dvc.dvcfile import PROJECT_FILE

targets = [PROJECT_FILE]
def collect_stages(
repo: "Repo",
targets: Iterable[str],
recursive: bool = False,
glob: bool = False,
) -> List["Stage"]:
stages: List["Stage"] = []
for target in targets:
stages.extend(repo.stage.collect(target, recursive=recursive, glob=glob))
return ldistinct(stages)

targets = targets or []
interactive = kwargs.get("interactive", False)
if not interactive:
kwargs["interactive"] = self.config["core"].get("interactive", False)

stages = []
if pipeline or all_pipelines:
pipelines = get_pipelines(self.index.graph)
if all_pipelines:
used_pipelines = pipelines
else:
used_pipelines = []
for target in targets:
stage = self.stage.get_target(target)
used_pipelines.append(get_pipeline(pipelines, stage))

for pline in used_pipelines:
for stage in pline:
if pline.in_degree(stage) == 0:
stages.append(stage)
else:
for target in targets:
stages.extend(
self.stage.collect(
target,
recursive=recursive,
glob=glob,
)
)

if kwargs.get("pull", False) and kwargs.get("run_cache", True):
logger.debug("Pulling run cache")
try:
self.stage_cache.pull(None)
except RunCacheNotSupported as e:
logger.warning("Failed to pull run cache: %s", e)

return _reproduce_stages(self.index.graph, ldistinct(stages), **kwargs)
def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph":
g = cast("DiGraph", graph.copy())
for stage in graph:
if stage.frozen:
# NOTE: disconnect frozen stage from its dependencies
g.remove_edges_from(graph.out_edges(stage))
return g


def _reproduce_stages( # noqa: C901
def plan_repro(
graph: "DiGraph",
stages: List["Stage"],
force_downstream: bool = False,
stages: Optional[List["Stage"]] = None,
pipeline: bool = False,
downstream: bool = False,
single_item: bool = False,
**kwargs,
all_pipelines: bool = False,
) -> List["Stage"]:
r"""Derive the evaluation of the given node for the given graph.
Expand Down Expand Up @@ -137,12 +96,70 @@ def _reproduce_stages( # noqa: C901
The derived evaluation of _downstream_ B would be: [B, D, E]
"""
from .graph import get_steps
from .graph import get_pipeline, get_pipelines, get_steps

if pipeline or all_pipelines:
pipelines = get_pipelines(graph)
if stages and pipeline:
pipelines = [get_pipeline(pipelines, stage) for stage in stages]

leaves: List["Stage"] = []
for pline in pipelines:
leaves.extend(node for node in pline if pline.in_degree(node) == 0)
stages = ldistinct(leaves)

active = _remove_frozen_stages(graph)
return get_steps(active, stages, downstream=downstream)

if not single_item:
active = _remove_frozen_stages(graph)
stages = get_steps(active, stages, downstream=downstream)

@locked
@scm_context
def reproduce( # noqa: C901
self: "Repo",
targets: Union[Iterable[str], str, None] = None,
recursive: bool = False,
pipeline: bool = False,
all_pipelines: bool = False,
downstream: bool = False,
single_item: bool = False,
glob: bool = False,
**kwargs,
):
from dvc.dvcfile import PROJECT_FILE

if not kwargs.get("interactive", False):
kwargs["interactive"] = self.config["core"].get("interactive", False)

stages: List["Stage"] = []
if not all_pipelines:
targets_list = ensure_list(targets or PROJECT_FILE)
stages = collect_stages(self, targets_list, recursive=recursive, glob=glob)

if kwargs.get("pull", False) and kwargs.get("run_cache", True):
logger.debug("Pulling run cache")
try:
self.stage_cache.pull(None)
except RunCacheNotSupported as e:
logger.warning("Failed to pull run cache: %s", e)

steps = stages
if pipeline or all_pipelines or not single_item:
graph = self.index.graph
steps = plan_repro(
graph,
stages,
pipeline=pipeline,
downstream=downstream,
all_pipelines=all_pipelines,
)
return _reproduce_stages(steps, **kwargs)


def _reproduce_stages(
stages: List["Stage"],
force_downstream: bool = False,
**kwargs,
) -> List["Stage"]:
result: List["Stage"] = []
for i, stage in enumerate(stages):
try:
Expand All @@ -164,12 +181,3 @@ def _reproduce_stages( # noqa: C901
if i < len(stages) - 1:
logger.info("") # add a newline
return result


def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph":
g = cast("DiGraph", graph.copy())
for stage in graph:
if stage.frozen:
# NOTE: disconnect frozen stage from its dependencies
g.remove_edges_from(graph.out_edges(stage))
return g
54 changes: 54 additions & 0 deletions tests/unit/repo/test_reproduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from itertools import chain

from networkx import DiGraph

from dvc.repo.reproduce import plan_repro


def test_number_reproduces(tmp_dir, dvc, mocker):
reproduce_stage_mock = mocker.patch(
"dvc.repo.reproduce._reproduce_stage", returns=[]
Expand All @@ -14,3 +21,50 @@ def test_number_reproduces(tmp_dir, dvc, mocker):
dvc.reproduce(all_pipelines=True)

assert reproduce_stage_mock.call_count == 5


def test_repro_plan(mocker):
r"""
n1
/ \
n2 m3 n8
/ \ / \ |
n4 n5 n6 n7 n9
"""

# note: downstream steps may not be stable, use AnyOf in such cases
class AnyOf:
def __init__(self, *items):
self.items = items

def __eq__(self, other: object) -> bool:
return any(item == other for item in self.items)

n = mocker.sentinel
n1, n2, n3, n4, n5, n6, n7, n8, n9 = (getattr(n, f"n{i}") for i in range(1, 10))
edges = {n1: [n2, n3], n2: [n4, n5], n3: [n6, n7], n8: [n9]}
for node in chain.from_iterable([n, *v] for n, v in edges.items()):
node.frozen = False

g = DiGraph(edges)
assert plan_repro(g) == [n4, n5, n2, n6, n7, n3, n1, n9, n8]
assert plan_repro(g, [n1]) == [n4, n5, n2, n6, n7, n3, n1]
assert plan_repro(g, [n4], downstream=True) == [n4, n2, n1]
assert plan_repro(g, [n8], True) == plan_repro(g, [n9], True) == [n9, n8]
assert plan_repro(g, [n2, n8], True) == [n4, n5, n2, n6, n7, n3, n1, n9, n8]
assert plan_repro(g, [n2, n3], downstream=True) == [
AnyOf(n2, n3),
AnyOf(n2, n3),
n1,
]

n2.frozen = True
assert plan_repro(g) == [n2, n6, n7, n3, n1, n9, n8, n4, n5]
assert plan_repro(g, [n1]) == [n2, n6, n7, n3, n1]
assert plan_repro(g, [n4], downstream=True) == [n4]
assert plan_repro(g, [n2, n8], pipeline=True) == [n2, n6, n7, n3, n1, n9, n8]
assert plan_repro(g, [n2, n3], downstream=True) == [
AnyOf(n2, n3),
AnyOf(n2, n3),
n1,
]

0 comments on commit 2becfd9

Please sign in to comment.