Make it possible to restart a pipeline with only unsuccessful jobs
sapetnioc committed Dec 8, 2023
1 parent d084072 commit 3bbc920
Showing 6 changed files with 189 additions and 23 deletions.
2 changes: 1 addition & 1 deletion capsul/database/__init__.py
@@ -271,7 +271,7 @@ def job_parameters_from_values(self, job_dict, parameters_values):
             result[k] = parameters_values[i]
         return result

-    def successful_node_paths(self, engine_id, execution_id):
+    def failed_node_paths(self, engine_id, execution_id):
         raise NotImplementedError

     def print_execution_report(self, report, file=sys.stdout):
4 changes: 2 additions & 2 deletions capsul/database/redis.py
@@ -783,9 +783,9 @@ def execution_report_json(self, engine_id, execution_id):

         return result

-    def successful_node_paths(self, engine_id, execution_id):
+    def failed_node_paths(self, engine_id, execution_id):
         execution_key = f"capsul:{engine_id}:{execution_id}"
-        failed = json.loads(self.redis.hget(execution_key, "done"))
+        failed = json.loads(self.redis.hget(execution_key, "failed"))
         for job_uuid in failed:
             job = json.loads(
                 self.redis.hget(f"capsul:{engine_id}:{execution_id}", f"job:{job_uuid}")
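
For orientation: this hunk implies one Redis hash per execution, keyed "capsul:{engine_id}:{execution_id}", whose "failed" field holds a JSON list of job uuids and whose "job:{uuid}" fields hold per-job JSON dicts. A minimal read-only sketch of that layout, assuming redis-py, a reachable local server, and hypothetical ids:

import json

import redis

r = redis.Redis()
execution_key = "capsul:my_engine:my_execution"  # hypothetical ids
failed_uuids = json.loads(r.hget(execution_key, "failed") or "[]")
for job_uuid in failed_uuids:
    job = json.loads(r.hget(execution_key, f"job:{job_uuid}"))
    print(job_uuid, sorted(job))  # show which fields are stored for a failed job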
28 changes: 14 additions & 14 deletions capsul/engine/__init__.py
@@ -8,6 +8,7 @@
 from ..execution_context import CapsulWorkflow, ExecutionContext
 from ..config.configuration import ModuleConfiguration
 from ..database import engine_database
+from ..api import Pipeline


 def execution_context(engine_label, engine_config, executable):

@@ -413,21 +414,20 @@ def run(self, executable, timeout=None, print_report=False, debug=False, **kwargs):
     def prepare_pipeline_for_retry(self, pipeline, execution_id):
         """Modify a pipeline given a previous execution to select only the nodes that
         weren't successful. Running the pipeline after this step will retry the
-        execution of failed jobs. This method adds (or modifies if it exists) an
-        unselected pipeline step called "successfully_executed" containing all nodes
-        that were successfully executed.
+        execution of failed jobs. This method sets a `self._enabled_nodes` attribute
+        containing the set of nodes to re-run. If this attribute exists and is not
+        empty, no job is created for any node outside this set.
         """
-        successful_nodes = []
-        for path in self.database.successful_node_paths(self.engine_id, execution_id):
-            successful_nodes.append(pipeline.node_from_path(path).name)
-        step_field = None
-        if pipeline.field("pipeline_steps"):
-            step_field = pipeline.pipeline_steps.fields("successfully_executed")
-        if step_field is None:
-            pipeline.add_pipeline_step("successfully_executed", successful_nodes, False)
-        else:
-            step_field.nodes = successful_nodes
-            setattr(pipeline.pipeline_steps, "successfully_executed", False)
+        # Walk the failed node paths of the previous execution and enable only
+        # the corresponding nodes for the next run.
+        enabled_nodes = set()
+        for path in self.database.failed_node_paths(self.engine_id, execution_id):
+            node = pipeline
+            for i in path[:-1]:
+                node = node.nodes[i]
+            failed_node = node.nodes[path[-1]]
+            enabled_nodes.add(failed_node)
+        pipeline._enabled_nodes = enabled_nodes or None


 class Workers(Controller):
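
Taken with the docstring above, the intended call pattern is the one exercised by test_retry.py below. A condensed sketch, assuming `executable` was built with Capsul.executable(...) and that its failing steps can succeed on a second attempt:

from capsul.api import Capsul

with Capsul().engine() as engine:
    engine.assess_ready_to_start(executable)
    execution_id = engine.start(executable)
    engine.wait(execution_id, timeout=30)
    if engine.database.error(engine.engine_id, execution_id) is not None:
        # Restrict the pipeline to the nodes whose jobs did not succeed,
        # then run it again.
        engine.prepare_pipeline_for_retry(executable, execution_id)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
    engine.raise_for_status(execution_id)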
17 changes: 12 additions & 5 deletions capsul/execution_context.py
@@ -115,8 +115,11 @@ def __init__(self, executable, create_output_dirs=True, priority=None, debug=False):
         jobs_per_process = {}
         process_chronology = {}
         processes_proxies = {}
+        enabled_nodes = None
         if isinstance(executable, (Pipeline, ProcessIteration)):
             pipeline_tools.propagate_meta(executable)
+            if isinstance(executable, Pipeline):
+                enabled_nodes = executable.enabled_pipeline_nodes()
         job_parameters = self._create_jobs(
             top_parameters=top_parameters,
             executable=executable,
@@ -128,6 +131,7 @@ def __init__(self, executable, create_output_dirs=True, priority=None, debug=False):
             parameters_location=[],
             process_iterations={},
             disabled=False,
+            enabled_nodes=enabled_nodes,
             priority=priority,
             debug=debug,
         )
@@ -145,7 +149,7 @@ def __init__(self, executable, create_output_dirs=True, priority=None, debug=False):
                 aj["wait_for"].add(before_job)
                 bj = self.jobs[before_job]
                 if bj["disabled"]:
-                    bj["waited_by"].add(after_job)
+                    bj.setdefault("waited_by", set()).add(after_job)

         # Resolve disabled jobs
         disabled_jobs = [
@@ -161,14 +165,14 @@ def __init__(self, executable, create_output_dirs=True, priority=None, debug=False):
                 else:
                     wait_for.add(job)
             waited_by = set()
-            stack = list(disabled_job[1]["waited_by"])
+            stack = list(disabled_job[1].get("waited_by", ()))
             while stack:
                 job = stack.pop(0)
                 if self.jobs[job]["disabled"]:
-                    stack.extend(self.jobs[job]["waited_by"])
+                    stack.extend(self.jobs[job].get("waited_by", ()))
                 else:
                     waited_by.add(job)
-            for job in disabled_job[1]["waited_by"]:
+            for job in disabled_job[1].get("waited_by", ()):
                 self.jobs[job]["wait_for"].remove(disabled_job[0])
             del self.jobs[disabled_job[0]]
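
This hunk shows only part of the disabled-job resolution. A simplified, self-contained sketch of the overall splice it performs, using a toy job table rather than Capsul's actual structures (the real code additionally walks chains of disabled jobs, as shown above):

# Splice disabled jobs out of a dependency graph: connect each non-disabled
# predecessor of a disabled job to its non-disabled successors, then delete it.
jobs = {
    "a": {"disabled": False, "wait_for": set(), "waited_by": {"b"}},
    "b": {"disabled": True, "wait_for": {"a"}, "waited_by": {"c"}},
    "c": {"disabled": False, "wait_for": {"b"}, "waited_by": set()},
}
for uuid in [u for u, j in jobs.items() if j["disabled"]]:
    job = jobs[uuid]
    for before in job["wait_for"]:
        jobs[before]["waited_by"].discard(uuid)
        jobs[before]["waited_by"] |= job["waited_by"]
    for after in job["waited_by"]:
        jobs[after]["wait_for"].discard(uuid)
        jobs[after]["wait_for"] |= job["wait_for"]
    del jobs[uuid]
print(jobs)  # "c" now waits directly for "a"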

Expand Down Expand Up @@ -276,6 +280,7 @@ def _create_jobs(
parameters_location,
process_iterations,
disabled,
enabled_nodes,
priority=None,
debug=None,
):
@@ -311,6 +316,7 @@ def _create_jobs(
                 parameters_location=parameters_location + ["nodes", node_name],
                 process_iterations=process_iterations,
                 disabled=disabled or node in disabled_nodes,
+                enabled_nodes=enabled_nodes,
                 priority=priority,
                 debug=debug,
             )
@@ -460,6 +466,7 @@ def _create_jobs(
                 + ["_iterations", str(iteration_index)],
                 process_iterations=process_iterations,
                 disabled=disabled,
+                enabled_nodes=enabled_nodes,
                 priority=new_priority,
                 debug=debug,
             )
@@ -504,7 +511,7 @@ def _create_jobs(
         else:
             job = {
                 "uuid": job_uuid,
-                "disabled": False,
+                "disabled": enabled_nodes and process not in enabled_nodes,
                 "wait_for": set(),
                 "process": process.json(include_parameters=False),
                 "parameters_location": parameters_location,
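
One subtlety in the last hunk: the new "disabled" value is `enabled_nodes and process not in enabled_nodes`, kept as-is rather than cast to bool. When no retry restriction is set, enabled_nodes is None and the expression evaluates to None, which the downstream `if bj["disabled"]:` check treats as falsy. A tiny illustration, with strings standing in for process objects:

def disabled_flag(enabled_nodes, process):
    # Mirrors the expression used when creating a job above.
    return enabled_nodes and process not in enabled_nodes

print(disabled_flag(None, "any_node"))                  # None  -> falsy, job runs
print(disabled_flag({"must_restart"}, "must_restart"))  # False -> job runs
print(disabled_flag({"must_restart"}, "successful"))    # True  -> job disabled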
6 changes: 5 additions & 1 deletion capsul/pipeline/pipeline.py
@@ -933,7 +933,6 @@ def add_link(self, link, weak_link=False, allow_export=False):
                 source_node is self.pipeline_node
                 and source_plug_name not in source_node.plugs
             ):
-                print(dest_node_name, dest_node, dest_node.plugs.keys())
                 self.export_parameter(dest_node_name, dest_plug_name, source_plug_name)
                 return
             elif (
@@ -2164,6 +2163,11 @@ def disabled_pipeline_steps_nodes(self):
             disabled_nodes.update(self.nodes[node] for node in nodes)
         return disabled_nodes

+    def enabled_pipeline_nodes(self):
+        """Restrict the nodes to execute to the returned set. This method
+        returns either None or a non-empty set of nodes."""
+        return getattr(self, "_enabled_nodes", None)
+
     def get_pipeline_step_nodes(self, step_name):
         """Get the nodes in the given pipeline step"""
         return self.pipeline_steps.field(step_name).nodes
155 changes: 155 additions & 0 deletions capsul/test/test_retry.py
@@ -0,0 +1,155 @@
import os
import tempfile

from capsul.api import Capsul, Process, Pipeline
from soma.controller import field, File, undefined


class ControlledFailure(Process):
    file: field(type_=File, read=True, write=True, optional=True)
    fail_count: int = 0
    value: field(type_=str, optional=True)
    input: field(type_=File, optional=True)
    output: field(type_=File, write=True)

    def execute(self, context):
        if self.file is not undefined and self.fail_count:
            if os.path.exists(self.file):
                with open(self.file) as f:
                    count = f.read()
                count = int(count) if count else 0
            else:
                count = 0
            count += 1
            with open(self.file, "w") as f:
                f.write(str(count))
            if self.fail_count >= count:
                raise Exception(
                    f"Process run count = {count} but failure count = {self.fail_count}"
                )
        result = []
        if self.input is not undefined:
            with open(self.input) as input:
                result.append(input.read())
        if self.value is not undefined:
            result.append(self.value)
        with open(self.output, "w") as output:
            output.write("\n".join(result))
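
The helper above keeps a run counter in `file` and raises as long as fail_count >= count, i.e. it fails on exactly its first fail_count runs. A minimal simulation for fail_count=1, the setting used on the "must_restart" node below:

fail_count = 1
for count in (1, 2):  # run counter as stored in the side file
    print(f"run {count}:", "raises" if fail_count >= count else "succeeds")
# run 1: raises
# run 2: succeeds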


class PipelineToRestart(Pipeline):
    def pipeline_definition(self):
        self.add_process(
            "initial_value",
            ControlledFailure,
            do_not_export=["file", "fail_count", "value"],
        )
        self.add_process(
            "successful",
            ControlledFailure,
            do_not_export=["file", "fail_count", "value"],
        )
        self["successful.value"] = "successful"
        self.add_process(
            "must_restart", ControlledFailure, do_not_export=["fail_count", "value"]
        )
        self["must_restart.value"] = "must_restart"
        self["must_restart.fail_count"] = 1
        self.add_process(
            "final_value",
            ControlledFailure,
            do_not_export=["file", "fail_count", "value"],
        )
        self["final_value.value"] = "final_value"

        self.export_parameter("initial_value", "value", "initial_value")
        self.add_link("initial_value.output->successful.input")
        self.add_link("successful.output->must_restart.input")
        self.add_link("must_restart.output->final_value.input")
        self.export_parameter("initial_value", "output", allow_existing_plug=True)
        self.export_parameter("successful", "output", allow_existing_plug=True)
        self.export_parameter("must_restart", "output", allow_existing_plug=True)
        self.export_parameter("final_value", "output", allow_existing_plug=True)


class SubPipelineToRestart(Pipeline):
    def pipeline_definition(self):
        self.add_process("sub1", PipelineToRestart, do_not_export=["initial_value"])
        self["sub1.initial_value"] = "initial_value_1"
        self.add_process("sub2", PipelineToRestart, do_not_export=["initial_value"])
        self["sub2.initial_value"] = "initial_value_2"
        self.add_link("sub1.output->sub2.input")
        self.export_parameter("sub1", "file", "file1")
        self.export_parameter("sub2", "file", "file2")
        self.export_parameter("sub1", "output", allow_existing_plug=True)
        self.export_parameter("sub2", "output", allow_existing_plug=True)


def test_retry_pipeline():
    executable = Capsul.executable(PipelineToRestart)
    tmp_failure = tempfile.NamedTemporaryFile()
    tmp_result = tempfile.NamedTemporaryFile()
    executable.initial_value = "initial_value"
    executable.file = tmp_failure.name
    executable.output = tmp_result.name

    with Capsul().engine() as engine:
        engine.assess_ready_to_start(executable)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
        error = engine.database.error(engine.engine_id, execution_id)
        with open(executable.output) as f:
            result = f.read()
        assert error == "Some jobs failed"
        assert result == "initial_value\nsuccessful"
        engine.prepare_pipeline_for_retry(executable, execution_id)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
        error = engine.database.error(engine.engine_id, execution_id)
        with open(executable.output) as f:
            result = f.read()
        assert error is None
        assert result == "initial_value\nsuccessful\nmust_restart\nfinal_value"
        engine.raise_for_status(execution_id)


def test_retry_sub_pipeline():
    executable = Capsul.executable(SubPipelineToRestart)
    tmp_failure1 = tempfile.NamedTemporaryFile()
    tmp_failure2 = tempfile.NamedTemporaryFile()
    tmp_result = tempfile.NamedTemporaryFile()
    executable.file1 = tmp_failure1.name
    executable.file2 = tmp_failure2.name
    executable.output = tmp_result.name

    with Capsul().engine() as engine:
        engine.assess_ready_to_start(executable)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
        error = engine.database.error(engine.engine_id, execution_id)
        with open(executable.output) as f:
            result = f.read()
        assert error == "Some jobs failed"
        assert result == "initial_value_1\nsuccessful"
        engine.prepare_pipeline_for_retry(executable, execution_id)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
        error = engine.database.error(engine.engine_id, execution_id)
        with open(executable.output) as f:
            result = f.read()
        assert error == "Some jobs failed"
        assert (
            result
            == "initial_value_1\nsuccessful\nmust_restart\nfinal_value\ninitial_value_2\nsuccessful"
        )
        engine.prepare_pipeline_for_retry(executable, execution_id)
        execution_id = engine.start(executable)
        engine.wait(execution_id, timeout=30)
        error = engine.database.error(engine.engine_id, execution_id)
        with open(executable.output) as f:
            result = f.read()
        assert error is None
        assert (
            result
            == "initial_value_1\nsuccessful\nmust_restart\nfinal_value\ninitial_value_2\nsuccessful\nmust_restart\nfinal_value"
        )
        engine.raise_for_status(execution_id)
