Skip to content

Commit

Permalink
Add a slurm command runner
Browse files Browse the repository at this point in the history
  • Loading branch information
linsword13 committed Dec 11, 2024
1 parent 9b5c83b commit 8ed5c0c
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_slurm_workflow():
assert ".slurm_job" in content
with open(os.path.join(path, "slurm_execute_experiment")) as f:
content = f.read()
assert "scontrol show hostnames" in content
assert "#SBATCH --gpus-per-task=1" in content
with open(os.path.join(path, "query_job")) as f:
content = f.read()
Expand Down
6 changes: 4 additions & 2 deletions lib/ramble/ramble/util/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ class CommandRunner:
Can be inherited to construct custom command runners.
"""

def __init__(self, name=None, command=None, shell="bash", dry_run=False, path=None):
def __init__(
self, name=None, command=None, shell="bash", dry_run=False, path=None, required=True
):
"""
Ensure required command is found in the path
"""
self.name = name
self.dry_run = dry_run
self.shell = shell
required = not self.dry_run
required = required and not self.dry_run
try:
if path is None:
self.command = which(command, required=required)
Expand Down
5 changes: 5 additions & 0 deletions lib/ramble/ramble/wmkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@
from ramble.language.shared_language import *

from ramble.workflow_manager import WorkflowManagerBase

from ramble.util.command_runner import (
CommandRunner,
RunnerError,
)
1 change: 1 addition & 0 deletions lib/ramble/ramble/workflow_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, file_path):
ramble.util.directives.define_directive_methods(self)

self.app_inst = None
self.runner = None

def set_application(self, app_inst):
"""Set a reference to the associated app_inst"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class Slurm(WorkflowManagerBase):

tags("workflow", "slurm")

def __init__(self, file_path):
super().__init__(file_path)

self.runner = SlurmRunner()

workflow_manager_variable(
name="partition",
default="",
Expand All @@ -54,23 +59,23 @@ class Slurm(WorkflowManagerBase):
""".strip(),
)

render_content(
name="query_job",
content_tpl=rf"""
render_content(name="query_job", content_func="_slurm_query_script")

def _slurm_query_script(self):
return rf"""
#!/bin/bash
{_ensure_job_id_snippet}
squeue -j $job_id
""".strip(),
)
{self.runner.generate_query_command("$job_id")}
""".strip()

render_content(
name="cancel_job",
content_tpl=rf"""
render_content(name="cancel_job", content_func="_slurm_cancel_script")

def _slurm_cancel_script(self):
return rf"""
#!/bin/bash
{_ensure_job_id_snippet}
scancel $job_id
""".strip(),
)
{self.runner.generate_cancel_command("$job_id")}
""".strip()

render_content(
name="slurm_execute_experiment", content_func="_slurm_execute_script"
Expand Down Expand Up @@ -112,8 +117,26 @@ def _slurm_execute_script(self):
cd {{experiment_run_dir}}
scontrol show hostnames > {{experiment_run_dir}}/hostfile
{self.runner.generate_hostfile_command()} > {{experiment_run_dir}}/hostfile
{{command}}
""".strip()

return content


class SlurmRunner(CommandRunner):
    """Command runner that builds the slurm command lines used in rendered scripts.

    The ``generate_*`` methods only build and return command strings.
    """

    def __init__(self, dry_run=False):
        """Initialize the underlying ``CommandRunner`` for slurm.

        Args:
            dry_run: Forwarded unchanged to ``CommandRunner.__init__``.
        """
        # required=False is forwarded so the base class is not asked to
        # treat the "slurm" command as mandatory.
        super().__init__(
            name="slurm",
            command="slurm",
            dry_run=dry_run,
            required=False,
        )

    def generate_query_command(self, job_id):
        """Return the ``squeue`` command line that queries ``job_id``."""
        return "squeue -j {}".format(job_id)

    def generate_cancel_command(self, job_id):
        """Return the ``scancel`` command line that cancels ``job_id``."""
        return "scancel {}".format(job_id)

    def generate_hostfile_command(self):
        """Return the ``scontrol`` command line that lists hostnames."""
        hostfile_cmd = "scontrol show hostnames"
        return hostfile_cmd

0 comments on commit 8ed5c0c

Please sign in to comment.