-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement physical "Operation" abstraction (#108)
- Loading branch information
Showing
4 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import pathlib | ||
from typing import Optional, Sequence | ||
|
||
from conductor.context import Context | ||
from conductor.execution.ops.operation import Operation | ||
from conductor.execution.task_state import TaskState | ||
from conductor.task_identifier import TaskIdentifier | ||
from conductor.task_types.base import TaskExecutionHandle | ||
|
||
|
||
class CombineOutputs(Operation):
    """
    Operation that gathers the outputs of a set of dependencies under a single
    output directory. Each non-empty dependency output directory is exposed
    as a symlink (named after the dependency directory) inside `output_path`.
    """

    def __init__(
        self,
        *,
        initial_state: TaskState,
        identifier: TaskIdentifier,
        output_path: pathlib.Path,
        deps_output_paths: Sequence[pathlib.Path],
    ) -> None:
        """
        Args:
            initial_state: The state this operation starts in.
            identifier: The identifier associated with this operation.
            output_path: The directory in which the symlinks are created.
            deps_output_paths: The dependencies' output directories to link.
        """
        super().__init__(initial_state)
        self._identifier = identifier
        self._output_path = output_path
        self._deps_output_paths = deps_output_paths

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Creates the output directory (if needed) and symlinks every non-empty
        dependency output directory into it. `ctx` and `slot` are not used.

        Returns:
            The handle produced by `TaskExecutionHandle.from_sync_execution()`.

        Raises:
            FileExistsError: If an entry with the dependency directory's name
                already exists in the output directory and is not a symlink
                that resolves to that dependency directory.
        """
        self._output_path.mkdir(parents=True, exist_ok=True)

        for dep_dir in self._deps_output_paths:
            # Skip dependencies without output (a missing or empty directory).
            # `iterdir()` yields `Path` objects, which are always truthy, so
            # `any()` is true iff the directory has at least one entry.
            if not dep_dir.is_dir() or not any(dep_dir.iterdir()):
                continue

            copy_into = self._output_path / dep_dir.name

            # The output directory may survive from an earlier execution (it
            # is created with `exist_ok=True`). A link that already resolves
            # to this dependency is left alone so that re-running does not
            # fail; any other existing entry still makes `symlink_to()` raise.
            if copy_into.is_symlink() and copy_into.resolve() == dep_dir.resolve():
                continue

            # The base data may be large, so we use symlinks to avoid copying.
            copy_into.symlink_to(dep_dir, target_is_directory=True)

        return TaskExecutionHandle.from_sync_execution()

    def finish_execution(self, handle: TaskExecutionHandle, ctx: Context) -> None:
        """
        No-op: all of the work happens in `start_execution()`.
        """
        # Nothing special needs to be done here.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from typing import List, Optional | ||
|
||
from conductor.context import Context | ||
from conductor.errors.base import ConductorError | ||
from conductor.execution.task_state import TaskState | ||
from conductor.task_types.base import TaskType, TaskExecutionHandle | ||
|
||
|
||
class Operation:
    """
    A physical unit of work that can be run. Conductor task types are turned
    into operations when they need to be executed.
    """

    # NOTE: TaskState will be renamed to OperationState after the refactor.

    def __init__(self, initial_state: TaskState) -> None:
        """
        Creates an operation in `initial_state` with no stored error and no
        dependency edges.
        """
        # Operations that must execute successfully before this one can run.
        self._exe_deps: List["Operation"] = []
        # Operations that cannot run until this one has executed successfully
        # (i.e., the operations that list this one as a dependency).
        self._deps_of: List["Operation"] = []
        # How many dependencies still have to complete before this operation
        # can execute. This value is always less than or equal to
        # `len(self._exe_deps)`.
        self._waiting_on = 0

        self._state = initial_state
        self._stored_error: Optional[ConductorError] = None

    @property
    def associated_task(self) -> Optional[TaskType]:
        """
        The task responsible for creating this operation, if there is one.
        This base implementation always returns `None`.
        """
        return None

    # Execution-related methods.

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Starts running this operation. Subclasses must override this method;
        the base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError

    def finish_execution(self, handle: TaskExecutionHandle, ctx: Context) -> None:
        """
        Completes an execution previously started with `start_execution()`.
        Subclasses must override this method; the base implementation raises
        `NotImplementedError`.
        """
        raise NotImplementedError

    # Execution state methods.

    @property
    def state(self) -> TaskState:
        """The current execution state of this operation."""
        return self._state

    @property
    def exe_deps(self) -> List["Operation"]:
        """The operations that this operation depends on."""
        return self._exe_deps

    @property
    def deps_of(self) -> List["Operation"]:
        """The operations that depend on this operation."""
        return self._deps_of

    @property
    def stored_error(self) -> Optional[ConductorError]:
        """The error recorded with `store_error()`, or `None` if none was."""
        return self._stored_error

    @property
    def waiting_on(self) -> int:
        """The number of dependencies this operation is still waiting on."""
        return self._waiting_on

    def set_state(self, state: TaskState) -> None:
        """Replaces the current execution state with `state`."""
        self._state = state

    def add_exe_dep(self, exe_dep: "Operation") -> None:
        """Records `exe_dep` as a dependency of this operation."""
        self._exe_deps.append(exe_dep)

    def add_dep_of(self, task: "Operation") -> None:
        """Records `task` as an operation that depends on this one."""
        self._deps_of.append(task)

    def store_error(self, error: ConductorError) -> None:
        """Remembers `error` so it can be read back via `stored_error`."""
        self._stored_error = error

    def reset_waiting_on(self) -> None:
        """Sets the waiting counter to the current number of dependencies."""
        self._waiting_on = len(self._exe_deps)

    def decrement_deps_of_waiting_on(self) -> None:
        """
        Lowers, by one, the waiting counter of every operation that depends
        on this operation.
        """
        for dependent in self.deps_of:
            # pylint: disable=protected-access
            dependent._decrement_waiting_on()

    def succeeded(self) -> bool:
        """
        Returns true iff this operation's state is `SUCCEEDED` or
        `SUCCEEDED_CACHED`.
        """
        return self.state in (TaskState.SUCCEEDED, TaskState.SUCCEEDED_CACHED)

    def not_yet_executed(self) -> bool:
        """Returns true iff this operation's state is `QUEUED`."""
        return self.state == TaskState.QUEUED

    def exe_deps_succeeded(self) -> bool:
        """
        Returns true iff every operation this operation depends on has
        executed successfully.
        """
        return all(dep.succeeded() for dep in self.exe_deps)

    def _decrement_waiting_on(self) -> None:
        """Lowers the waiting counter by one; it must currently be positive."""
        assert self._waiting_on > 0
        self._waiting_on -= 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import os | ||
import pathlib | ||
import signal | ||
import subprocess | ||
import sys | ||
from typing import Optional, Sequence | ||
|
||
from conductor.config import ( | ||
OUTPUT_ENV_VARIABLE_NAME, | ||
DEPS_ENV_VARIABLE_NAME, | ||
DEPS_ENV_PATH_SEPARATOR, | ||
TASK_NAME_ENV_VARIABLE_NAME, | ||
STDOUT_LOG_FILE, | ||
STDERR_LOG_FILE, | ||
EXP_ARGS_JSON_FILE_NAME, | ||
EXP_OPTION_JSON_FILE_NAME, | ||
SLOT_ENV_VARIABLE_NAME, | ||
) | ||
from conductor.context import Context | ||
from conductor.errors import ( | ||
TaskFailed, | ||
TaskNonZeroExit, | ||
ConductorAbort, | ||
) | ||
from conductor.execution.ops.operation import Operation | ||
from conductor.execution.task_state import TaskState | ||
from conductor.execution.version_index import Version | ||
from conductor.task_types.base import TaskExecutionHandle | ||
from conductor.task_identifier import TaskIdentifier | ||
from conductor.utils.output_handler import RecordType, OutputHandler | ||
from conductor.utils.run_arguments import RunArguments | ||
from conductor.utils.run_options import RunOptions | ||
|
||
|
||
class RunTaskExecutable(Operation):
    """
    Operation that runs a task's command line in a bash subprocess, exposing
    the output path, dependency output paths, task name and (optionally) slot
    to the command through environment variables.
    """

    def __init__(
        self,
        *,
        initial_state: TaskState,
        identifier: TaskIdentifier,
        run: str,
        args: RunArguments,
        options: RunOptions,
        working_path: pathlib.Path,
        output_path: pathlib.Path,
        deps_output_paths: Sequence[pathlib.Path],
        record_output: bool,
        version_to_record: Optional[Version],
        serialize_args_options: bool,
        parallelizable: bool,
    ) -> None:
        """
        Args:
            initial_state: The state this operation starts in.
            identifier: The identifier of the task being run.
            run: The base command; the serialized arguments and options are
                appended to it (space separated) to form the command line.
            args: The arguments passed to the command.
            options: The options passed to the command.
            working_path: The working directory for the subprocess.
            output_path: The directory exposed to the command as its output
                location; the stdout/stderr log files are also placed here.
            deps_output_paths: The dependencies' output directories, exposed
                to the command through an environment variable.
            record_output: Whether the command's stdout/stderr are recorded.
            version_to_record: If not `None`, the output version inserted into
                the version index once the command succeeds.
            serialize_args_options: Whether non-empty args/options are written
                as JSON into the output directory once the command succeeds.
            parallelizable: Stored on the operation; not read by this class.
        """
        super().__init__(initial_state)
        self._identifier = identifier
        self._args = args
        self._options = options
        self._run = " ".join(
            [run, self._args.serialize_cmdline(), self._options.serialize_cmdline()]
        )
        self._working_path = working_path
        self._output_path = output_path
        self._deps_output_paths = deps_output_paths
        self._record_output = record_output
        self._version_to_record = version_to_record
        self._serialize_args_options = serialize_args_options
        self._parallelizable = parallelizable

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Launches the command in a new session and returns without waiting for
        it to exit.

        Args:
            ctx: The Conductor context (used for output teeing).
            slot: The slot assigned to this execution, if any. When given, it
                is exposed to the command through an environment variable.

        Returns:
            A handle for the launched process, carrying its stdout/stderr
            output handlers.

        Raises:
            TaskFailed: If an `OSError` occurs while setting up or launching
                the subprocess.
            ConductorAbort: Re-raised after the launched process group (if
                any) has been sent SIGTERM.
        """
        # Must be bound before the `try` block: the abort handler below checks
        # it, and the abort may arrive before the subprocess is launched.
        process: Optional[subprocess.Popen] = None

        try:
            self._output_path.mkdir(parents=True, exist_ok=True)

            env_vars = {
                **os.environ,
                OUTPUT_ENV_VARIABLE_NAME: str(self._output_path),
                DEPS_ENV_VARIABLE_NAME: DEPS_ENV_PATH_SEPARATOR.join(
                    map(str, self._deps_output_paths)
                ),
                TASK_NAME_ENV_VARIABLE_NAME: self._identifier.name,
            }
            if slot is not None:
                env_vars[SLOT_ENV_VARIABLE_NAME] = str(slot)

            # Recorded output is teed when no slot was assigned and only
            # logged when the execution has a slot.
            if self._record_output:
                if slot is None:
                    record_type = RecordType.Teed
                else:
                    record_type = RecordType.OnlyLogged
            else:
                record_type = RecordType.NotRecorded

            stdout_output = OutputHandler(
                self._output_path / STDOUT_LOG_FILE, record_type
            )
            stderr_output = OutputHandler(
                self._output_path / STDERR_LOG_FILE, record_type
            )

            # `start_new_session=True` places the subprocess in its own
            # session, so the abort handler can signal its process group.
            process = subprocess.Popen(
                [self._run],
                shell=True,
                cwd=self._working_path,
                executable="/bin/bash",
                stdout=stdout_output.popen_arg(),
                stderr=stderr_output.popen_arg(),
                env=env_vars,
                start_new_session=True,
            )

            stdout_output.maybe_tee(process.stdout, sys.stdout, ctx)
            stderr_output.maybe_tee(process.stderr, sys.stderr, ctx)

            handle = TaskExecutionHandle.from_async_process(pid=process.pid)
            handle.stdout = stdout_output
            handle.stderr = stderr_output
            return handle

        except ConductorAbort:
            # Send SIGTERM to the entire process group (i.e., the subprocess
            # and its child processes).
            if process is not None:
                group_id = os.getpgid(process.pid)
                if group_id >= 0:
                    os.killpg(group_id, signal.SIGTERM)
            if self._record_output:
                ctx.tee_processor.shutdown()
            raise

        except OSError as ex:
            # Chain the original error so its traceback is not lost.
            raise TaskFailed(task_identifier=self._identifier).add_extra_context(
                str(ex)
            ) from ex

    def finish_execution(self, handle: "TaskExecutionHandle", ctx: Context) -> None:
        """
        Finalizes a completed execution: finishes the output handlers, checks
        the exit code, and records the requested artifacts.

        Args:
            handle: The handle returned by `start_execution()`; its output
                handlers and return code must be set.
            ctx: The Conductor context (used to access the version index).

        Raises:
            TaskNonZeroExit: If the process exited with a non-zero code.
        """
        assert handle.stdout is not None
        assert handle.stderr is not None
        handle.stdout.finish()
        handle.stderr.finish()

        assert handle.returncode is not None
        if handle.returncode != 0:
            raise TaskNonZeroExit(
                task_identifier=self._identifier, code=handle.returncode
            )

        # The steps below only run when the command succeeded.
        if self._serialize_args_options:
            if not self._args.empty():
                self._args.serialize_json(self._output_path / EXP_ARGS_JSON_FILE_NAME)
            if not self._options.empty():
                self._options.serialize_json(
                    self._output_path / EXP_OPTION_JSON_FILE_NAME
                )

        if self._version_to_record is not None:
            ctx.version_index.insert_output_version(
                self._identifier, self._version_to_record
            )
            ctx.version_index.commit_changes()