diff --git a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py index b89a5ffae..30af0d334 100644 --- a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py +++ b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py @@ -63,6 +63,7 @@ def test_slurm_workflow(): assert ".slurm_job" in content with open(os.path.join(path, "slurm_execute_experiment")) as f: content = f.read() + assert "scontrol show hostnames" in content assert "#SBATCH --gpus-per-task=1" in content with open(os.path.join(path, "query_job")) as f: content = f.read() diff --git a/lib/ramble/ramble/util/command_runner.py b/lib/ramble/ramble/util/command_runner.py index c290c6d03..f479e9499 100644 --- a/lib/ramble/ramble/util/command_runner.py +++ b/lib/ramble/ramble/util/command_runner.py @@ -20,14 +20,16 @@ class CommandRunner: Can be inherited to construct custom command runners. """ - def __init__(self, name=None, command=None, shell="bash", dry_run=False, path=None): + def __init__( + self, name=None, command=None, shell="bash", dry_run=False, path=None, required=True + ): """ Ensure required command is found in the path """ self.name = name self.dry_run = dry_run self.shell = shell - required = not self.dry_run + required = required and not self.dry_run try: if path is None: self.command = which(command, required=required) diff --git a/lib/ramble/ramble/wmkit.py b/lib/ramble/ramble/wmkit.py index 0d0ab3bc2..845442827 100644 --- a/lib/ramble/ramble/wmkit.py +++ b/lib/ramble/ramble/wmkit.py @@ -14,3 +14,8 @@ from ramble.language.shared_language import * from ramble.workflow_manager import WorkflowManagerBase + +from ramble.util.command_runner import ( + CommandRunner, + RunnerError, +) diff --git a/lib/ramble/ramble/workflow_manager.py b/lib/ramble/ramble/workflow_manager.py index ffbc1238f..b02dbe876 100644 --- a/lib/ramble/ramble/workflow_manager.py +++ b/lib/ramble/ramble/workflow_manager.py @@ -38,6 +38,7 @@ def __init__(self, file_path): ramble.util.directives.define_directive_methods(self) self.app_inst = None + self.runner = None def set_application(self, app_inst): """Set a reference to the associated app_inst""" diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index b8758eaa8..6be44ed4a 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -28,6 +28,11 @@ class Slurm(WorkflowManagerBase): tags("workflow", "slurm") + def __init__(self, file_path): + super().__init__(file_path) + + self.runner = SlurmRunner() + workflow_manager_variable( name="partition", default="", @@ -54,23 +59,23 @@ class Slurm(WorkflowManagerBase): """.strip(), ) - render_content( - name="query_job", - content_tpl=rf""" + render_content(name="query_job", content_func="_slurm_query_script") + + def _slurm_query_script(self): + return rf""" #!/bin/bash {_ensure_job_id_snippet} -squeue -j $job_id - """.strip(), - ) +{self.runner.generate_query_command("$job_id")} + """.strip() - render_content( - name="cancel_job", - content_tpl=rf""" + render_content(name="cancel_job", content_func="_slurm_cancel_script") + + def _slurm_cancel_script(self): + return rf""" #!/bin/bash {_ensure_job_id_snippet} -scancel $job_id - """.strip(), - ) +{self.runner.generate_cancel_command("$job_id")} + """.strip() render_content( name="slurm_execute_experiment", content_func="_slurm_execute_script" @@ -112,8 +117,26 @@ def _slurm_execute_script(self): cd {{experiment_run_dir}} -scontrol show hostnames > {{experiment_run_dir}}/hostfile +{self.runner.generate_hostfile_command()} > {{experiment_run_dir}}/hostfile {{command}} """.strip() return content + + +class SlurmRunner(CommandRunner): + """Runner for executing slurm commands""" + + def __init__(self, dry_run=False): + super().__init__( + name="slurm", command="slurm", dry_run=dry_run, required=False + ) + + def generate_query_command(self, job_id): + return f"squeue -j {job_id}" + + def generate_cancel_command(self, job_id): + return f"scancel {job_id}" + + def generate_hostfile_command(self): + return "scontrol show hostnames"