Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor repro: split logic for planning stages to run for repro #9690

Merged
merged 1 commit into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 85 additions & 77 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from typing import TYPE_CHECKING, List, Optional, cast
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast

from funcy import ldistinct

from dvc.exceptions import ReproductionError
from dvc.repo.scm_context import scm_context
from dvc.stage.cache import RunCacheNotSupported
from dvc.utils.collections import ensure_list

from . import locked

Expand All @@ -32,75 +33,33 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]:
return ret


@locked
@scm_context
def reproduce( # noqa: C901, PLR0912
self: "Repo",
targets=None,
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs,
):
from .graph import get_pipeline, get_pipelines

glob = kwargs.pop("glob", False)

if isinstance(targets, str):
targets = [targets]

if not all_pipelines and not targets:
from dvc.dvcfile import PROJECT_FILE

targets = [PROJECT_FILE]
def collect_stages(
    repo: "Repo",
    targets: Iterable[str],
    recursive: bool = False,
    glob: bool = False,
) -> List["Stage"]:
    """Gather the stages for every target and drop duplicates.

    Each entry of ``targets`` is resolved through ``repo.stage.collect``
    (forwarding ``recursive`` and ``glob``); the per-target results are
    concatenated in target order and then passed through ``ldistinct``.
    """
    collected: List["Stage"] = [
        stage
        for target in targets
        for stage in repo.stage.collect(target, recursive=recursive, glob=glob)
    ]
    return ldistinct(collected)

targets = targets or []
interactive = kwargs.get("interactive", False)
if not interactive:
kwargs["interactive"] = self.config["core"].get("interactive", False)

stages = []
if pipeline or all_pipelines:
pipelines = get_pipelines(self.index.graph)
if all_pipelines:
used_pipelines = pipelines
else:
used_pipelines = []
for target in targets:
stage = self.stage.get_target(target)
used_pipelines.append(get_pipeline(pipelines, stage))

for pline in used_pipelines:
for stage in pline:
if pline.in_degree(stage) == 0:
stages.append(stage)
else:
for target in targets:
stages.extend(
self.stage.collect(
target,
recursive=recursive,
glob=glob,
)
)

if kwargs.get("pull", False) and kwargs.get("run_cache", True):
logger.debug("Pulling run cache")
try:
self.stage_cache.pull(None)
except RunCacheNotSupported as e:
logger.warning("Failed to pull run cache: %s", e)

return _reproduce_stages(self.index.graph, ldistinct(stages), **kwargs)
def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph":
    """Return a copy of ``graph`` with frozen stages cut off from their deps.

    The input graph is left untouched; only the copy loses the out-edges of
    every node whose ``frozen`` attribute is truthy.
    """
    pruned = cast("DiGraph", graph.copy())
    frozen_nodes = (node for node in graph if node.frozen)
    for node in frozen_nodes:
        # NOTE: disconnect frozen stage from its dependencies
        pruned.remove_edges_from(graph.out_edges(node))
    return pruned


def _reproduce_stages( # noqa: C901
def plan_repro(
graph: "DiGraph",
stages: List["Stage"],
force_downstream: bool = False,
stages: Optional[List["Stage"]] = None,
pipeline: bool = False,
downstream: bool = False,
single_item: bool = False,
**kwargs,
all_pipelines: bool = False,
) -> List["Stage"]:
r"""Derive the evaluation of the given node for the given graph.
Expand Down Expand Up @@ -137,12 +96,70 @@ def _reproduce_stages( # noqa: C901
The derived evaluation of _downstream_ B would be: [B, D, E]
"""
from .graph import get_steps
from .graph import get_pipeline, get_pipelines, get_steps

if pipeline or all_pipelines:
pipelines = get_pipelines(graph)
if stages and pipeline:
pipelines = [get_pipeline(pipelines, stage) for stage in stages]

leaves: List["Stage"] = []
for pline in pipelines:
leaves.extend(node for node in pline if pline.in_degree(node) == 0)
stages = ldistinct(leaves)

active = _remove_frozen_stages(graph)
return get_steps(active, stages, downstream=downstream)

if not single_item:
active = _remove_frozen_stages(graph)
stages = get_steps(active, stages, downstream=downstream)

@locked
@scm_context
def reproduce(  # noqa: C901
    self: "Repo",
    targets: Union[Iterable[str], str, None] = None,
    recursive: bool = False,
    pipeline: bool = False,
    all_pipelines: bool = False,
    downstream: bool = False,
    single_item: bool = False,
    glob: bool = False,
    **kwargs,
):
    """Collect the requested stages, plan their order and reproduce them.

    Stages are collected from ``targets`` (defaulting to ``PROJECT_FILE``)
    unless ``all_pipelines`` is set, in which case collection is skipped and
    planning starts from an empty stage list. When ``pipeline``,
    ``all_pipelines`` or a non-``single_item`` run is requested, the steps
    are derived with ``plan_repro``; otherwise the collected stages are
    used as-is. Remaining ``kwargs`` are forwarded to ``_reproduce_stages``,
    whose result is returned.
    """
    from dvc.dvcfile import PROJECT_FILE

    # Fall back to the repo config unless the caller passed a truthy value.
    if not kwargs.get("interactive", False):
        kwargs["interactive"] = self.config["core"].get("interactive", False)

    stages: List["Stage"] = (
        []
        if all_pipelines
        else collect_stages(
            self,
            ensure_list(targets or PROJECT_FILE),
            recursive=recursive,
            glob=glob,
        )
    )

    if kwargs.get("pull", False) and kwargs.get("run_cache", True):
        logger.debug("Pulling run cache")
        try:
            self.stage_cache.pull(None)
        except RunCacheNotSupported as e:
            # Best effort: a remote without run-cache support is not fatal.
            logger.warning("Failed to pull run cache: %s", e)

    needs_planning = pipeline or all_pipelines or not single_item
    if not needs_planning:
        # single_item without any pipeline flag: run exactly what was collected
        return _reproduce_stages(stages, **kwargs)

    steps = plan_repro(
        self.index.graph,
        stages,
        pipeline=pipeline,
        downstream=downstream,
        all_pipelines=all_pipelines,
    )
    return _reproduce_stages(steps, **kwargs)


def _reproduce_stages(
stages: List["Stage"],
force_downstream: bool = False,
**kwargs,
) -> List["Stage"]:
result: List["Stage"] = []
for i, stage in enumerate(stages):
try:
Expand All @@ -164,12 +181,3 @@ def _reproduce_stages( # noqa: C901
if i < len(stages) - 1:
logger.info("") # add a newline
return result


def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph":
    """Return a copy of ``graph`` with frozen stages cut off from their deps.

    The input graph is left untouched; only the copy loses the out-edges of
    every node whose ``frozen`` attribute is truthy.
    """
    pruned = cast("DiGraph", graph.copy())
    frozen_nodes = (node for node in graph if node.frozen)
    for node in frozen_nodes:
        # NOTE: disconnect frozen stage from its dependencies
        pruned.remove_edges_from(graph.out_edges(node))
    return pruned
54 changes: 54 additions & 0 deletions tests/unit/repo/test_reproduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from itertools import chain

from networkx import DiGraph

from dvc.repo.reproduce import plan_repro


def test_number_reproduces(tmp_dir, dvc, mocker):
reproduce_stage_mock = mocker.patch(
"dvc.repo.reproduce._reproduce_stage", returns=[]
Expand All @@ -14,3 +21,50 @@ def test_number_reproduces(tmp_dir, dvc, mocker):
dvc.reproduce(all_pipelines=True)

assert reproduce_stage_mock.call_count == 5


def test_repro_plan(mocker):
    r"""Check the steps ``plan_repro`` derives for the graph below.

            n1
           /  \
         n2    n3    n8
        /  \  /  \    |
      n4   n5 n6  n7  n9
    """

    # note: downstream steps may not be stable, use AnyOf in such cases
    class AnyOf:
        """Compare equal to any one of the wrapped items."""

        def __init__(self, *items):
            self.items = items

        def __eq__(self, other: object) -> bool:
            return any(item == other for item in self.items)

    sentinel = mocker.sentinel
    n1, n2, n3, n4, n5, n6, n7, n8, n9 = (
        getattr(sentinel, f"n{i}") for i in range(1, 10)
    )
    edges = {n1: [n2, n3], n2: [n4, n5], n3: [n6, n7], n8: [n9]}
    # every node starts out unfrozen
    for node in chain.from_iterable([src, *dsts] for src, dsts in edges.items()):
        node.frozen = False

    g = DiGraph(edges)
    n2_or_n3 = AnyOf(n2, n3)

    assert plan_repro(g) == [n4, n5, n2, n6, n7, n3, n1, n9, n8]
    assert plan_repro(g, [n1]) == [n4, n5, n2, n6, n7, n3, n1]
    assert plan_repro(g, [n4], downstream=True) == [n4, n2, n1]
    assert plan_repro(g, [n8], True) == plan_repro(g, [n9], True) == [n9, n8]
    assert plan_repro(g, [n2, n8], True) == [n4, n5, n2, n6, n7, n3, n1, n9, n8]
    assert plan_repro(g, [n2, n3], downstream=True) == [n2_or_n3, n2_or_n3, n1]

    # freezing n2 disconnects it from n4 and n5
    n2.frozen = True
    assert plan_repro(g) == [n2, n6, n7, n3, n1, n9, n8, n4, n5]
    assert plan_repro(g, [n1]) == [n2, n6, n7, n3, n1]
    assert plan_repro(g, [n4], downstream=True) == [n4]
    assert plan_repro(g, [n2, n8], pipeline=True) == [n2, n6, n7, n3, n1, n9, n8]
    assert plan_repro(g, [n2, n3], downstream=True) == [n2_or_n3, n2_or_n3, n1]