Commit

It works!!!
maouw committed Sep 21, 2023
1 parent 2584d39 commit 6ee5c5d
Showing 4 changed files with 88 additions and 78 deletions.
63 changes: 39 additions & 24 deletions hyakvnc/__main__.py
@@ -9,15 +9,17 @@
import re
import shlex
import subprocess
import time
from dataclasses import asdict
from pathlib import Path
from typing import Optional, Union

import pprint
from .config import HyakVncConfig
from .slurmutil import wait_for_job_status, get_job, SlurmJob, get_job_status
from .util import check_remote_pid_exists_and_port_open
from .util import check_remote_pid_exists_and_port_open, wait_for_file
from .version import VERSION


app_config = HyakVncConfig()


@@ -32,7 +34,7 @@ def get_apptainer_vnc_instances(read_apptainer_config: bool = False):
all_instance_json_files = app_dir.rglob(app_config.apptainer_instance_prefix + '*.json')

running_hyakvnc_json_files = {p: r.groupdict() for p in all_instance_json_files if (
r := re.match(rf'(?P<prefix>{app_config.apptainer_instance_prefix})(?P<jobid>\d+)-(?P<appinstance>.*)\.json',
r := re.match(rf'(?P<prefix>{app_config.apptainer_instance_prefix})-(?P<jobid>\d+)-(?P<appinstance>.*)\.json',
p.name))}
outs = []
# frr := re.search(r'\s+-rfbport\s+(?P<rfbport>\d+\b', fr)
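
The tightened pattern requires an explicit dash between the prefix, the job id, and the instance name, matching the prefix defaults in hyakvnc/config.py that now omit the trailing dash. A minimal sketch of the corrected match (the filename is illustrative):

import re

prefix = "hyakvnc"  # apptainer_instance_prefix, now without a trailing dash
name = "hyakvnc-14673571-ubuntu22.04_xubuntu.json"  # hypothetical instance file
m = re.match(rf'(?P<prefix>{prefix})-(?P<jobid>\d+)-(?P<appinstance>.*)\.json', name)
assert m is not None
assert m.group('jobid') == '14673571'
assert m.group('appinstance') == 'ubuntu22.04_xubuntu'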
@@ -121,34 +123,33 @@ def cmd_create(container_path: Union[str, Path], dry_run=False) -> SlurmJob:
assert container_path.exists(), f"Container path {container_path} does not exist"
assert container_path.is_file(), f"Container path {container_path} is not a file"

cmds = ["sbatch", "--parsable", "--job-name", app_config.job_prefix + container_name]
cmds = ["sbatch", "--parsable", "--job-name", app_config.job_prefix + '-' + container_name]

sbatch_optinfo = {"account": "-A", "partition": "-p", "gpus": "-G", "timelimit": "--time", "mem": "--mem",
"cpus": "-c"}
sbatch_options = [item for pair in [(sbatch_optinfo[k], v) for k, v in asdict(app_config).items() if
sbatch_options = [str(item )for pair in [(sbatch_optinfo[k], v) for k, v in asdict(app_config).items() if
k in sbatch_optinfo.keys() and v is not None] for item in pair]

cmds += sbatch_options

apptainer_env_vars_quoted = [f"{k}={shlex.quote(v)}" for k, v in app_config.apptainer_env_vars.items()]
apptainer_env_vars_string = "" if apptainer_env_vars_quoted else (" ".join(apptainer_env_vars_quoted) + " ")
apptainer_env_vars_string = "" if not apptainer_env_vars_quoted else (" ".join(apptainer_env_vars_quoted) + " ")

# needs to match rf'(?P<prefix>{app_config.apptainer_instance_prefix})(?P<jobid>\d+)-(?P<appinstance>.*)'):
apptainer_instance_name = rf"{app_config.apptainer_instance_prefix}\$SLURM_JOB_ID-{container_name}"

apptainer_cmd = apptainer_env_vars_string + rf"apptainer instance start {container_path} {apptainer_instance_name}"
apptainer_cmd_with_rest = rf"{apptainer_cmd} && while true; do sleep 10; done"
apptainer_instance_name = f"{app_config.apptainer_instance_prefix}-$SLURM_JOB_ID-{container_name}"

cmds += ["--wrap", apptainer_cmd_with_rest]
apptainer_cmd = f"apptainer instance start {container_path} {apptainer_instance_name}"
apptainer_cmd_with_rest = apptainer_env_vars_string + f"{apptainer_cmd} && while true; do sleep 10; done"
cmds += ["--wrap",apptainer_cmd_with_rest]

# Launch sbatch process:
logging.info("Launching sbatch process with command:\n" + " ".join(cmds))
logging.info("Launching sbatch process with command:\n" + repr(cmds))

if dry_run:
print(f"Woud have run: {' '.join(cmds)}")
print(f"Would have run: {' '.join(cmds)}")
return

res = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
res = subprocess.run(cmds, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
if res.returncode != 0:
raise RuntimeError(f"Could not launch sbatch job:\n{res.stderr}")

@@ -174,7 +175,20 @@ def cmd_create(container_path: Union[str, Path], dry_run=False) -> SlurmJob:
raise RuntimeError(f"Could not get job {job_id} after it started running")

logging.info(f"Job {job_id} is now running")
return job

instance_file = '/mmfs1/home/altan/.apptainer/instances/a/g3071/altan/hyakvnc-14673571-ubuntu22.04_xubuntu.err'
real_instance_name = f"{app_config.apptainer_instance_prefix}-{job.job_id}-{container_name}"
instance_file = (Path(app_config.apptainer_config_dir) / 'instances' / 'app' / job.node_list[0] / job.user_name / real_instance_name / f"{real_instance_name}.json").expanduser()

if wait_for_file(str(instance_file), timeout=app_config.sbatch_post_timeout):
time.sleep(10) # sleep to wait for apptainer to actually start vncserver <FIXME>
instances = { instance["name"]: instance for instance in get_apptainer_vnc_instances() }
if real_instance_name not in instances:
raise TimeoutError(f"Could not find VNC session for job {job_id}")
instance = instances[real_instance_name]
print(get_openssh_connection_string_for_instance(instance, app_config.ssh_host))
else:
logging.info(f"Could not find VNC session for job {job_id}")



@@ -192,7 +206,7 @@ def cmd_stop(job_id: Optional[int] = None, stop_all: bool = False):

def cmd_status():
vnc_instances = get_apptainer_vnc_instances(read_apptainer_config=True)
print(json.dumps(vnc_instances, indent=2))
pprint.pp(json.dumps(vnc_instances, indent=2))


def create_arg_parser():
@@ -235,23 +249,21 @@ def create_arg_parser():
help='Kill specified VNC session, cancel its VNC job, and exit', type=int)

parser_stop_all = subparsers.add_parser('stop_all', help='Stop all VNC sessions and exit')
parser_print_config = subparsers.add_parser('print_config', help='Print app configuration and exit')

return parser


arg_parser = create_arg_parser()
args = arg_parser.parse_args()

os.environ.setdefault("HYAKVNC_LOG_LEVEL", "INFO")

if args.debug:
os.environ["HYAKVNC_LOG_LEVEL"] = "DEBUG"

log_level = logging.__dict__.get(os.environ.setdefault("HYAKVNC_LOG_LEVEL", "INFO").upper(), logging.INFO)

log_format = '%(asctime)s - %(levelname)s - %(funcName)s() - %(message)s'

if log_level == logging.DEBUG:
log_format += " - %(pathname)s:%(lineno)d"

logging.basicConfig(level=log_level, format=log_format)
log_level = logging.__dict__.get(os.getenv("HYAKVNC_LOG_LEVEL").upper(), logging.INFO)
logging.getLogger().setLevel(log_level)

if args.print_version:
print(VERSION)
@@ -270,4 +282,7 @@ def create_arg_parser():
if args.command == 'stop_all':
cmd_stop(stop_all=True)

if args.command == 'print_config':
pprint.pp(asdict(app_config), indent=2, width=79)

exit(0)
29 changes: 13 additions & 16 deletions hyakvnc/config.py
@@ -7,7 +7,7 @@

from .slurmutil import get_default_cluster, get_default_account, get_default_partition

def get_first_env(env_vars: Iterable[str], default: Optional[str] = None, allow_blank: bool = True) -> str:
def get_first_env(env_vars: Iterable[str], default: Optional[str] = None, allow_blank: bool = False) -> str:
"""
Gets the first environment variable that is set, or the default value if none are set.
:param env_vars: list of environment variables to check
@@ -16,7 +16,7 @@ def get_first_env(env_vars: Iterable[str], default: Optional[str] = None, allow_
:return: the first environment variable that is set, or the default value if none are set
"""

no_match = [None] if allow_blank else ["None", ""]
no_match = [None] if allow_blank else [None, ""]
for x in env_vars:
if (res := os.environ.get(x)) not in no_match:
logging.debug(rf"Using environment variable {x}={res}")
@@ -31,13 +31,13 @@ class HyakVncConfig:
Configuration for hyakvnc.
"""
# script attributes
job_prefix: str = "hyakvnc-" # prefix for job names
job_prefix: str = "hyakvnc" # prefix for job names
# apptainer config
apptainer_bin: str = "apptainer" # path to apptainer binary
apptainer_config_dir: str = "~/.apptainer" # directory where apptainer config files are stored
apptainer_instance_prefix: str = "hyakvnc-" # prefix for apptainer instance names
apptainer_use_writable_tmpfs: Optional[bool] = None # whether to mount a writable tmpfs for apptainer instances
apptainer_cleanenv: Optional[bool] = None # whether to use clean environment for apptainer instances
apptainer_instance_prefix: str = "hyakvnc" # prefix for apptainer instance names
apptainer_use_writable_tmpfs: Optional[bool] = True # whether to mount a writable tmpfs for apptainer instances
apptainer_cleanenv: Optional[bool] = True # whether to use clean environment for apptainer instances
apptainer_set_bind_paths: Optional[
str] = None # comma-separated list of paths to bind mount for apptainer instances
apptainer_env_vars: Optional[dict[str]] = None # environment variables to set for apptainer instances
@@ -53,27 +53,24 @@ class HyakVncConfig:
cluster: Optional[str] = None # cluster to use for sbatch jobs | --clusters, SBATCH_CLUSTERS
gpus: Optional[str] = None # number of gpus to use for sbatch jobs | -G, --gpus, SBATCH_GPUS
timelimit: Optional[str] = None # time limit for sbatch jobs | --time, SBATCH_TIMELIMIT
mem: Optional[str] = None # memory limit for sbatch jobs | --mem, SBATCH_MEM
cpus: Optional[int] = None # number of cpus to use for sbatch jobs | -c, --cpus-per-task (not settable by env var)
mem: Optional[str] = "8G" # memory limit for sbatch jobs | --mem, SBATCH_MEM
cpus: Optional[int] = 4 # number of cpus to use for sbatch jobs | -c, --cpus-per-task (not settable by env var)

def __post_init__(self) -> None:
def __post_init__(self):
"""
Post-initialization hook for HyakVncConfig. Sets default values for unset attributes.
:return: None
"""
self.cluster = self.cluster or get_first_env(["HYAKVNC_SLURM_CLUSTER", "SBATCH_CLUSTER"],
get_default_cluster(), allow_blank=False)
self.cluster = self.cluster or get_first_env(["HYAKVNC_SLURM_CLUSTER", "SBATCH_CLUSTER"], default=get_default_cluster())
self.account = self.account or get_first_env(["HYAKVNC_SLURM_ACCOUNT", "SBATCH_ACCOUNT"],
get_default_account(cluster=self.cluster),
allow_blank=False)
get_default_account(cluster=self.cluster))
self.partition = self.partition or get_first_env(["HYAKVNC_SLURM_PARTITION", "SBATCH_PARTITION"],
get_default_partition(cluster=self.cluster,
account=self.account),
allow_blank=False)
account=self.account))
self.gpus = self.gpus or get_first_env(["HYAKVNC_SLURM_GPUS", "SBATCH_GPUS"], None)
self.timelimit = self.timelimit or get_first_env(["HYAKVNC_SLURM_TIMELIMIT", "SBATCH_TIMELIMIT"], None)
self.mem = self.mem or get_first_env(["HYAKVNC_SLURM_MEM", "SBATCH_MEM"], None)
self.cpus = self.cpus or get_first_env(["HYAKVNC_SLURM_CPUS", "SBATCH_CPUS_PER_TASK"], None)
self.cpus = int(self.cpus or get_first_env(["HYAKVNC_SLURM_CPUS", "SBATCH_CPUS_PER_TASK"]))

self.apptainer_env_vars = self.apptainer_env_vars or dict()
all_apptainer_env_vars = {x: os.environ.get(x, "") for x in os.environ.keys() if
72 changes: 35 additions & 37 deletions hyakvnc/slurmutil.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import logging
import os
import subprocess
import time
@@ -90,28 +91,28 @@ def node_range_to_list(s: str) -> list[str]:
:return: list of SLURM nodes
:raises ValueError: if the node range could not be converted to a list of nodes
"""
output = subprocess.run(f"scontrol show hostnames {s}", stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
cmds = ["scontrol", "show", "hostnames", s]
output = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if output.returncode != 0:
raise ValueError(f"Could not convert node range '{s}' to list of nodes:\n{output.stderr}")
return output.stdout.rstrip().splitlines()
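
This also fixes the invocation itself: passing a single string with shell=False makes Python look for an executable literally named "scontrol show hostnames ...", so the old call could not succeed. A minimal sketch of the corrected call (the node range is illustrative; assumes scontrol is on PATH):

import subprocess

out = subprocess.run(["scontrol", "show", "hostnames", "g[3071-3072]"],
                     stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print(out.stdout.rstrip().splitlines())  # e.g. ['g3071', 'g3072']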


@dataclass
class SlurmJob:
job_id: int = field(metadata={"squeue_field": "JobID"})
job_name: str = field(metadata={"squeue_field": "JobName"})
account: str = field(metadata={"squeue_field": "Account"})
partition: str = field(metadata={"squeue_field": "Partition"})
user_name: str = field(metadata={"squeue_field": "UserName"})
state: str = field(metadata={"squeue_field": "State"})
time_used: str = field(metadata={"squeue_field": "TimeUsed"})
time_limit: str = field(metadata={"squeue_field": "TimeLimit"})
num_nodes: int = field(metadata={"squeue_field": "NumNodes"})
node_list: str = field(metadata={"squeue_field": "NodeList"})
command: str = field(metadata={"squeue_field": "Command"})
cpus_per_task: int = field(metadata={"squeue_field": "cpus-per-task"})
num_cpus: int = field(metadata={"squeue_field": "NumCPUs"})
min_memory: str = field(metadata={"squeue_field": "MinMemory"})
job_id: int = field(metadata={"squeue_field": "%i"})
job_name: str = field(metadata={"squeue_field": "%j"})
account: str = field(metadata={"squeue_field": "%a"})
partition: str = field(metadata={"squeue_field": "%P"})
user_name: str = field(metadata={"squeue_field": "%u"})
state: str = field(metadata={"squeue_field": "%T"})
time_used: str = field(metadata={"squeue_field": "%M"})
time_limit: str = field(metadata={"squeue_field": "%l"})
cpus: int = field(metadata={"squeue_field": "%C"})
min_memory: str = field(metadata={"squeue_field": "%m"})
num_nodes: int = field(metadata={"squeue_field": "%D"})
node_list: str = field(metadata={"squeue_field": "%N"})
command: str = field(metadata={"squeue_field": "%o"})

@staticmethod
def from_squeue_line(line: str, field_order=None) -> "SlurmJob":
@@ -125,43 +126,43 @@ def from_squeue_line(line: str, field_order=None) -> "SlurmJob":
valid_field_names = [x.name for x in fields(SlurmJob)]
if field_order is None:
field_order = valid_field_names

#
all_fields_dict = {field_order[i]: x for i, x in enumerate(line.split())}
field_dict = {k: v for k, v in all_fields_dict.items() if k in valid_field_names}

try:
field_dict["num_nodes"] = int(field_dict["num_nodes"])
except (ValueError, TypeError, KeyError):
field_dict["num_nodes"] = None
try:
field_dict["cpus_per_task"] = int(field_dict["cpus_per_task"])
except (ValueError, TypeError, KeyError):
field_dict["cpus_per_task"] = None
try:
field_dict["num_cpus"] = int(field_dict["num_cpus"])
except (ValueError, TypeError, KeyError):
field_dict["num_cpus"] = None

try:
field_dict["node_list"] = node_range_to_list(field_dict["node_list"])
field_dict["cpus"] = int(field_dict["cpus"])
except (ValueError, TypeError, KeyError):
field_dict["cpus"] = None

if field_dict.get("node_list") == "(null)":
field_dict["node_list"] = None
else:
try:
field_dict["node_list"] = node_range_to_list(field_dict["node_list"])
except (ValueError, TypeError, KeyError, FileNotFoundError):
logging.debug(f"Could not convert node range '{field_dict['node_list']}' to list of nodes")
field_dict["node_list"] = None

if field_dict.get("command") == "(null)":
field_dict["command"] = None

return SlurmJob(**field_dict)
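
With the %-code metadata above, one squeue output line maps positionally onto the dataclass fields. A hypothetical parse (values are illustrative; node-range expansion shells out to scontrol, so node_list falls back to None on hosts without SLURM):

line = ("14673571 hyakvnc-ubuntu22.04_xubuntu mylab gpu-a40 altan RUNNING "
        "5:02 1:00:00 4 8G 1 g3071 (null)")
job = SlurmJob.from_squeue_line(line)
print(job.job_id, job.state, job.cpus)  # 14673571 RUNNING 4
print(job.command)                      # None -- "(null)" is normalized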


def get_job(jobs: Optional[Union[int, list[int]]] = None,
user: Optional[str] = os.getlogin(),
cluster: Optional[str] = None,
field_names: Optional[Container[str]] = None
cluster: Optional[str] = None
) -> Union[SlurmJob, list[SlurmJob], None]:
"""
Gets the specified slurm job(s).
:param user: User to get jobs for
:param jobs: Job(s) to get
:param cluster: Cluster to get jobs for
:param field_names: Fields to get for jobs (defaults to all fields in SlurmJob)
:return: the specified slurm job(s) as a SlurmJob object or list of SlurmJobs, or None if no jobs were found
"""
cmds: list[str] = ['squeue', '--noheader']
@@ -175,16 +176,13 @@ def get_job(jobs: Optional[Union[int, list[int]]] = None,
if jobs:
if job_is_int:
jobs = [jobs]
else:
jobs = ','.join([str(x) for x in jobs])
cmds += ['--jobs', jobs]

slurm_job_fields = [f for f in fields(SlurmJob) if f.name in field_names]
assert len(slurm_job_fields) > 0, "Must specify at least one field to get for slurm jobs"
squeue_format_fields = ",".join([f.metadata.get("squeue_field", "") for f in slurm_job_fields])
jobs = ','.join([str(x) for x in jobs])
cmds += ['--jobs', jobs]

cmds += ['--Format', squeue_format_fields]
res = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
squeue_format_fields = "\t".join([f.metadata.get("squeue_field", "") for f in fields(SlurmJob)])
cmds += ['--format', squeue_format_fields]
res = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=False)
if res.returncode != 0:
raise ValueError(f"Could not get slurm jobs:\n{res.stderr}")

2 changes: 1 addition & 1 deletion hyakvnc/util.py
@@ -28,5 +28,5 @@ def wait_for_file(path: Union[Path, str], timeout: Optional[float] = None,

def check_remote_pid_exists_and_port_open(host: str, pid: int, port: int) -> bool:
cmd = f"ssh {host} ps -p {pid} && nc -z localhost {port}".split()
res = subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
res = subprocess.run(cmd, shell=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return res.returncode == 0
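
With shell=False, every token after the hostname becomes part of ssh's remote command, so the && is evaluated by the remote shell rather than locally (under the old shell=True with a list argument, POSIX Python would have run only ssh itself, with no arguments). An equivalent, more explicit sketch:

import subprocess

def check_remote_pid_exists_and_port_open(host: str, pid: int, port: int) -> bool:
    # Passing the remote command as one argument makes it explicit that both
    # checks run on `host`, not on the local machine.
    cmd = ["ssh", host, f"ps -p {pid} && nc -z localhost {port}"]
    res = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return res.returncode == 0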
