Skip to content

Commit

Permalink
repro: make stages run in order of the graph.
Browse files Browse the repository at this point in the history
In case of --single-item, this means it'll run in the order of arguments passed.
For --downstream, this fixes running of stage out of order when multiple arguments are passed.
For normal case, this enforces the same logic, so that they are never run out-of-order.
  • Loading branch information
skshetry committed Jun 30, 2023
1 parent 36ec25f commit 9c6b5e7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 45 deletions.
24 changes: 23 additions & 1 deletion dvc/repo/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterator, List, Set
from typing import TYPE_CHECKING, Any, Iterator, List, Set, TypeVar

from dvc.fs import localfs
from dvc.utils.fs import path_isin
Expand All @@ -8,6 +8,8 @@

from dvc.stage import Stage

T = TypeVar("T")


def check_acyclic(graph: "DiGraph") -> None:
import networkx as nx
Expand Down Expand Up @@ -39,6 +41,26 @@ def get_pipelines(graph: "DiGraph"):
return [graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph)]


def get_subgraph_of_nodes(
graph: "DiGraph", sources: List[Any], downstream: bool = False
) -> "DiGraph":
from networkx import dfs_postorder_nodes, reverse_view

assert sources
g = reverse_view(graph) if downstream else graph
nodes = []
for source in sources:
nodes.extend(dfs_postorder_nodes(g, source))
return graph.subgraph(nodes)


def get_steps(graph: "DiGraph", sources: List[T], downstream: bool = False) -> List[T]:
from networkx import dfs_postorder_nodes

sub = get_subgraph_of_nodes(graph, sources, downstream=downstream)
return list(dfs_postorder_nodes(sub))


def collect_pipeline(stage: "Stage", graph: "DiGraph") -> Iterator["Stage"]:
import networkx as nx

Expand Down
67 changes: 23 additions & 44 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from typing import TYPE_CHECKING, Iterator, List
from typing import TYPE_CHECKING, Iterator, List, cast

from funcy import ldistinct

from dvc.exceptions import ReproductionError
from dvc.repo.scm_context import scm_context
Expand All @@ -8,6 +10,8 @@
from . import locked

if TYPE_CHECKING:
from networkx import DiGraph

from dvc.stage import Stage

from . import Repo
Expand Down Expand Up @@ -78,11 +82,12 @@ def reproduce( # noqa: C901, PLR0912

targets = [PROJECT_FILE]

targets = targets or []
interactive = kwargs.get("interactive", False)
if not interactive:
kwargs["interactive"] = self.config["core"].get("interactive", False)

stages = set()
stages = []
if pipeline or all_pipelines:
pipelines = get_pipelines(self.index.graph)
if all_pipelines:
Expand All @@ -96,10 +101,10 @@ def reproduce( # noqa: C901, PLR0912
for pline in used_pipelines:
for stage in pline:
if pline.in_degree(stage) == 0:
stages.add(stage)
stages.append(stage)
else:
for target in targets:
stages.update(
stages.extend(
self.stage.collect(
target,
recursive=recursive,
Expand All @@ -114,7 +119,7 @@ def reproduce( # noqa: C901, PLR0912
except RunCacheNotSupported as e:
logger.warning("Failed to pull run cache: %s", e)

return _reproduce_stages(self.index.graph, list(stages), **kwargs)
return _reproduce_stages(self.index.graph, ldistinct(stages), **kwargs)


def _reproduce_stages( # noqa: C901
Expand Down Expand Up @@ -159,15 +164,19 @@ def _reproduce_stages( # noqa: C901
The derived evaluation of _downstream_ B would be: [B, D, E]
"""
steps = _get_steps(graph, stages, downstream, single_item)
from .graph import get_steps

if not single_item:
active = _remove_frozen_stages(graph)
stages = get_steps(active, stages, downstream=downstream)

force_downstream = kwargs.pop("force_downstream", False)
result: List["Stage"] = []
unchanged: List["Stage"] = []
# `ret` is used to add a cosmetic newline.
ret: List["Stage"] = []

for stage in steps:
for stage in stages:
if ret:
logger.info("")

Expand All @@ -191,40 +200,10 @@ def _reproduce_stages( # noqa: C901
return result


def _get_steps(graph, stages, downstream, single_item):
import networkx as nx

active = graph.copy()
if not single_item:
# NOTE: frozen stages don't matter for single_item
for stage in graph:
if stage.frozen:
# NOTE: disconnect frozen stage from its dependencies
active.remove_edges_from(graph.out_edges(stage))

all_pipelines: List["Stage"] = []
for stage in stages:
if downstream:
# NOTE (py3 only):
# Python's `deepcopy` defaults to pickle/unpickle the object.
# Stages are complex objects (with references to `repo`,
# `outs`, and `deps`) that cause struggles when you try
# to serialize them. We need to create a copy of the graph
# itself, and then reverse it, instead of using
# graph.reverse() directly because it calls `deepcopy`
# underneath -- unless copy=False is specified.
nodes = nx.dfs_postorder_nodes(active.reverse(copy=False), stage)
all_pipelines += reversed(list(nodes))
else:
all_pipelines += nx.dfs_postorder_nodes(active, stage)

steps = []
for stage in all_pipelines:
if stage not in steps:
# NOTE: order of steps still matters for single_item
if single_item and stage not in stages:
continue

steps.append(stage)

return steps
def _remove_frozen_stages(graph: "DiGraph") -> "DiGraph":
g = cast("DiGraph", graph.copy())
for stage in graph:
if stage.frozen:
# NOTE: disconnect frozen stage from its dependencies
g.remove_edges_from(graph.out_edges(stage))
return g

0 comments on commit 9c6b5e7

Please sign in to comment.