Skip to content

Commit

Permalink
Implement physical "Operation" abstraction (#108)
Browse files Browse the repository at this point in the history
This is part of #99. Also, once used, this will address #105.
  • Loading branch information
geoffxy authored Nov 2, 2024
1 parent e7b30a9 commit cf5096d
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 0 deletions.
Empty file.
43 changes: 43 additions & 0 deletions src/conductor/execution/ops/combine_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pathlib
from typing import Optional, Sequence

from conductor.context import Context
from conductor.execution.ops.operation import Operation
from conductor.execution.task_state import TaskState
from conductor.task_identifier import TaskIdentifier
from conductor.task_types.base import TaskExecutionHandle


class CombineOutputs(Operation):
    """
    Operation that exposes the output directories of its dependencies under
    a single output directory by creating one symlink per dependency.
    """

    def __init__(
        self,
        *,
        initial_state: TaskState,
        identifier: TaskIdentifier,
        output_path: pathlib.Path,
        deps_output_paths: Sequence[pathlib.Path],
    ) -> None:
        """
        Stores the identifier, the directory to populate (`output_path`), and
        the dependency output directories to link to (`deps_output_paths`).
        """
        super().__init__(initial_state)
        self._identifier = identifier
        self._output_path = output_path
        self._deps_output_paths = deps_output_paths

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Creates the output directory (if needed) and symlinks each non-empty
        dependency output directory into it, using the dependency directory's
        own name as the link name. Runs synchronously.
        """
        destination = self._output_path
        destination.mkdir(parents=True, exist_ok=True)

        for source_dir in self._deps_output_paths:
            # Dependencies that produced no output directory are skipped.
            if not source_dir.is_dir():
                continue
            # So are dependencies whose output directory has no entries.
            if next(source_dir.iterdir(), None) is None:
                continue
            # The base data may be large, so we use symlinks to avoid copying.
            link = destination / source_dir.name
            link.symlink_to(source_dir, target_is_directory=True)

        return TaskExecutionHandle.from_sync_execution()

    def finish_execution(self, handle: TaskExecutionHandle, ctx: Context) -> None:
        """
        No-op: all of the work happens in `start_execution()`.
        """
        # Nothing special needs to be done here.
        pass
108 changes: 108 additions & 0 deletions src/conductor/execution/ops/operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List, Optional

from conductor.context import Context
from conductor.errors.base import ConductorError
from conductor.execution.task_state import TaskState
from conductor.task_types.base import TaskType, TaskExecutionHandle


class Operation:
    """
    A physical unit of work that can be executed. Conductor task types are
    lowered into operations when it is time to run them.

    Besides the execution hooks (`start_execution()` / `finish_execution()`,
    which subclasses must override), an operation tracks its execution state,
    any stored error, and its edges in the dependency graph.
    """

    # NOTE: TaskState will be renamed to OperationState after the refactor.

    def __init__(self, initial_state: TaskState) -> None:
        self._state = initial_state
        self._stored_error: Optional[ConductorError] = None

        # Operations that must execute successfully before this one can run.
        self._exe_deps: List["Operation"] = []
        # Operations that cannot run until this one executes successfully
        # (i.e., the reverse edges of `_exe_deps`).
        self._deps_of: List["Operation"] = []
        # How many dependencies still have to complete before this operation
        # can run. Never exceeds `len(self._exe_deps)`.
        self._waiting_on = 0

    @property
    def associated_task(self) -> Optional[TaskType]:
        """
        The task that is responsible for creating this operation, if any.
        The base implementation always reports `None`.
        """
        return None

    # Execution-related methods.

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Begins running this operation. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def finish_execution(self, handle: TaskExecutionHandle, ctx: Context) -> None:
        """
        Completes this operation given the handle produced by
        `start_execution()`. Must be overridden by subclasses.
        """
        raise NotImplementedError

    # Execution state accessors.

    @property
    def state(self) -> TaskState:
        return self._state

    @property
    def stored_error(self) -> Optional[ConductorError]:
        return self._stored_error

    @property
    def exe_deps(self) -> List["Operation"]:
        return self._exe_deps

    @property
    def deps_of(self) -> List["Operation"]:
        return self._deps_of

    @property
    def waiting_on(self) -> int:
        return self._waiting_on

    # Execution state mutators.

    def set_state(self, state: TaskState) -> None:
        self._state = state

    def store_error(self, error: ConductorError) -> None:
        self._stored_error = error

    def add_exe_dep(self, exe_dep: "Operation") -> None:
        self._exe_deps.append(exe_dep)

    def add_dep_of(self, task: "Operation") -> None:
        self._deps_of.append(task)

    def reset_waiting_on(self) -> None:
        """
        Sets the waiting count back to the total number of dependencies.
        """
        self._waiting_on = len(self._exe_deps)

    def decrement_deps_of_waiting_on(self) -> None:
        """
        Lowers the waiting count of every operation that depends on this one.
        """
        for dependent in self.deps_of:
            # pylint: disable=protected-access
            dependent._decrement_waiting_on()

    # Execution state queries.

    def succeeded(self) -> bool:
        current = self.state
        if current == TaskState.SUCCEEDED:
            return True
        return current == TaskState.SUCCEEDED_CACHED

    def not_yet_executed(self) -> bool:
        return self.state == TaskState.QUEUED

    def exe_deps_succeeded(self) -> bool:
        """
        Returns true iff every operation this one depends on has executed
        successfully (vacuously true when there are no dependencies).
        """
        return all(dep.succeeded() for dep in self.exe_deps)

    def _decrement_waiting_on(self) -> None:
        remaining = self._waiting_on
        # The count must not drop below zero.
        assert remaining > 0
        self._waiting_on = remaining - 1
156 changes: 156 additions & 0 deletions src/conductor/execution/ops/run_task_executable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import pathlib
import signal
import subprocess
import sys
from typing import Optional, Sequence

from conductor.config import (
OUTPUT_ENV_VARIABLE_NAME,
DEPS_ENV_VARIABLE_NAME,
DEPS_ENV_PATH_SEPARATOR,
TASK_NAME_ENV_VARIABLE_NAME,
STDOUT_LOG_FILE,
STDERR_LOG_FILE,
EXP_ARGS_JSON_FILE_NAME,
EXP_OPTION_JSON_FILE_NAME,
SLOT_ENV_VARIABLE_NAME,
)
from conductor.context import Context
from conductor.errors import (
TaskFailed,
TaskNonZeroExit,
ConductorAbort,
)
from conductor.execution.ops.operation import Operation
from conductor.execution.task_state import TaskState
from conductor.execution.version_index import Version
from conductor.task_types.base import TaskExecutionHandle
from conductor.task_identifier import TaskIdentifier
from conductor.utils.output_handler import RecordType, OutputHandler
from conductor.utils.run_arguments import RunArguments
from conductor.utils.run_options import RunOptions


class RunTaskExecutable(Operation):
    """
    Operation that runs a task's command line in a bash subprocess.

    The command is `run` followed by the serialized arguments and options.
    The subprocess is launched from `working_path` with environment
    variables describing the output directory, the dependency output
    directories, the task name, and (when provided) the execution slot.
    """

    def __init__(
        self,
        *,
        initial_state: TaskState,
        identifier: TaskIdentifier,
        run: str,
        args: RunArguments,
        options: RunOptions,
        working_path: pathlib.Path,
        output_path: pathlib.Path,
        deps_output_paths: Sequence[pathlib.Path],
        record_output: bool,
        version_to_record: Optional[Version],
        serialize_args_options: bool,
        parallelizable: bool,
    ) -> None:
        """
        Args:
            initial_state: The state this operation starts in.
            identifier: Identifies the task being run; its name is exported
                to the subprocess and it is attached to raised errors.
            run: The base command string to execute.
            args: Arguments appended to `run` on the command line.
            options: Options appended to `run` after `args`.
            working_path: The subprocess's working directory.
            output_path: Directory created before launch and exported to the
                subprocess; log and JSON files are written under it.
            deps_output_paths: Dependency output directories exported to the
                subprocess as a single joined string.
            record_output: Whether stdout/stderr should be recorded.
            version_to_record: If not `None`, the output version inserted
                into the version index after a successful run.
            serialize_args_options: Whether non-empty args/options are
                written out as JSON after a successful run.
            parallelizable: Stored on the instance; not read by the methods
                defined in this class.
        """
        super().__init__(initial_state)
        self._identifier = identifier
        self._args = args
        self._options = options
        self._run = " ".join(
            [run, self._args.serialize_cmdline(), self._options.serialize_cmdline()]
        )
        self._working_path = working_path
        self._output_path = output_path
        self._deps_output_paths = deps_output_paths
        self._record_output = record_output
        self._version_to_record = version_to_record
        self._serialize_args_options = serialize_args_options
        self._parallelizable = parallelizable

    def start_execution(self, ctx: Context, slot: Optional[int]) -> TaskExecutionHandle:
        """
        Launches the task's command in a new bash subprocess and returns a
        handle for the still-running process.

        Raises:
            ConductorAbort: Re-raised after sending SIGTERM to the
                subprocess's process group (if it was started) and shutting
                down the tee processor (if output is being recorded).
            TaskFailed: If an `OSError` occurs while setting up or launching
                the subprocess.
        """
        # Bound before the `try` so that the abort handler below can safely
        # test it even when the abort arrives before `Popen` has returned.
        process: Optional[subprocess.Popen] = None
        try:
            self._output_path.mkdir(parents=True, exist_ok=True)

            env_vars = {
                **os.environ,
                OUTPUT_ENV_VARIABLE_NAME: str(self._output_path),
                DEPS_ENV_VARIABLE_NAME: DEPS_ENV_PATH_SEPARATOR.join(
                    map(str, self._deps_output_paths)
                ),
                TASK_NAME_ENV_VARIABLE_NAME: self._identifier.name,
            }
            if slot is not None:
                env_vars[SLOT_ENV_VARIABLE_NAME] = str(slot)

            # Recorded output is teed only when no slot was assigned; with a
            # slot it is only logged.
            if self._record_output:
                if slot is None:
                    record_type = RecordType.Teed
                else:
                    record_type = RecordType.OnlyLogged
            else:
                record_type = RecordType.NotRecorded

            stdout_output = OutputHandler(
                self._output_path / STDOUT_LOG_FILE, record_type
            )
            stderr_output = OutputHandler(
                self._output_path / STDERR_LOG_FILE, record_type
            )

            # `start_new_session=True` puts the subprocess in its own
            # session, which lets the abort handler signal its whole process
            # group without affecting this process.
            process = subprocess.Popen(
                [self._run],
                shell=True,
                cwd=self._working_path,
                executable="/bin/bash",
                stdout=stdout_output.popen_arg(),
                stderr=stderr_output.popen_arg(),
                env=env_vars,
                start_new_session=True,
            )

            stdout_output.maybe_tee(process.stdout, sys.stdout, ctx)
            stderr_output.maybe_tee(process.stderr, sys.stderr, ctx)

            # The output handlers travel with the handle so that
            # `finish_execution()` can finalize them.
            handle = TaskExecutionHandle.from_async_process(pid=process.pid)
            handle.stdout = stdout_output
            handle.stderr = stderr_output
            return handle

        except ConductorAbort:
            # Send SIGTERM to the entire process group (i.e., the subprocess
            # and its child processes).
            if process is not None:
                group_id = os.getpgid(process.pid)
                if group_id >= 0:
                    os.killpg(group_id, signal.SIGTERM)
            if self._record_output:
                ctx.tee_processor.shutdown()
            raise

        except OSError as ex:
            # Chain the original error so its traceback is not lost.
            raise TaskFailed(task_identifier=self._identifier).add_extra_context(
                str(ex)
            ) from ex

    def finish_execution(self, handle: "TaskExecutionHandle", ctx: Context) -> None:
        """
        Finalizes a run started by `start_execution()`.

        Finishes both output handlers, checks the exit code, optionally
        writes the args/options JSON files into the output directory, and
        optionally records the output version in the version index.

        Raises:
            TaskNonZeroExit: If the subprocess exited with a non-zero code.
        """
        assert handle.stdout is not None
        assert handle.stderr is not None
        handle.stdout.finish()
        handle.stderr.finish()

        assert handle.returncode is not None
        if handle.returncode != 0:
            raise TaskNonZeroExit(
                task_identifier=self._identifier, code=handle.returncode
            )

        # Everything below only happens for successful runs.
        if self._serialize_args_options:
            if not self._args.empty():
                self._args.serialize_json(self._output_path / EXP_ARGS_JSON_FILE_NAME)
            if not self._options.empty():
                self._options.serialize_json(
                    self._output_path / EXP_OPTION_JSON_FILE_NAME
                )

        if self._version_to_record is not None:
            ctx.version_index.insert_output_version(
                self._identifier, self._version_to_record
            )
            ctx.version_index.commit_changes()

0 comments on commit cf5096d

Please sign in to comment.