Skip to content

Commit

Permalink
Add clear error type; Format the code
Browse files Browse the repository at this point in the history
  • Loading branch information
yuema137 committed Apr 8, 2024
1 parent 8b0195b commit db6af90
Showing 1 changed file with 49 additions and 15 deletions.
64 changes: 49 additions & 15 deletions utilix/batchq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
"You may need to set SCRATCH_DIR manually in your .bashrc or .bash_profile."
)

PARTITIONS: List[str] = ["dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"]
PARTITIONS: List[str] = [
"dali",
"lgrandi",
"xenon1t",
"broadwl",
"kicp",
"caslake",
"build",
]
TMPDIR: Dict[str, str] = {
"dali": f"/dali/lgrandi/{USER}/tmp",
"lgrandi": os.path.join(SCRATCH_DIR, "tmp"),
Expand Down Expand Up @@ -56,6 +64,15 @@
"/dali/lgrandi/grid_proxy/xenon_service_proxy:/project2/lgrandi/grid_proxy/xenon_service_proxy",
]

class QOSNotFoundError(Exception):
"""
Provided qos is not found in the qos list
"""

class FormatError(Exception):
"""
Format of file is not correct
"""

def _make_executable(path: str) -> None:
"""
Expand Down Expand Up @@ -83,26 +100,29 @@ def _get_qos_list() -> List[str]:
"""
cmd = "sacctmgr show qos format=name -p"
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True, shell=True)
result = subprocess.run(
cmd, capture_output=True, text=True, check=True, shell=True
)
qos_list: List[str] = result.stdout.strip().split("\n")
qos_list = [qos[:-1] for qos in qos_list]
return qos_list
except subprocess.CalledProcessError as e:
print(f"An error occurred while executing sacctmgr: {e}")
return []


class JobSubmission(BaseModel):
"""
Class to generate and submit a job to the SLURM queue.
"""

jobstring: str = Field(..., description="The command to execute")
exclude_lc_nodes: bool = Field(True, description="Exclude the loosely coupled nodes")
log: str = Field("job.log", description="Where to store the log file of the job")
partition: Literal["dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"] = Field(
"xenon1t", description="Partition to submit the job to"
exclude_lc_nodes: bool = Field(
True, description="Exclude the loosely coupled nodes"
)
log: str = Field("job.log", description="Where to store the log file of the job")
partition: Literal[
"dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"
] = Field("xenon1t", description="Partition to submit the job to")
qos: str = Field("xenon1t", description="QOS to submit the job to")
account: str = Field("pi-lgrandi", description="Account to submit the job to")
jobname: str = Field("somejob", description="How to name this job")
Expand All @@ -120,14 +140,22 @@ class JobSubmission(BaseModel):
)
cpus_per_task: int = Field(1, description="CPUs requested for job")
hours: Optional[float] = Field(None, description="Max hours of a job")
node: Optional[str] = Field(None, description="Define a certain node to submit your job")
node: Optional[str] = Field(
None, description="Define a certain node to submit your job"
)
exclude_nodes: Optional[str] = Field(
None, description="Define a list of nodes which should be excluded from submission"
None,
description="Define a list of nodes which should be excluded from submission",
)
dependency: Optional[str] = Field(
None, description="Provide list of job ids to wait for before running this job"
)
verbose: bool = Field(False, description="Print the sbatch command before submitting")
verbose: bool = Field(
False, description="Print the sbatch command before submitting"
)
bypass_validation: List[str] = Field(
default_factory=list, description="List of parameters to bypass validation for"
)

# Check if there is any positional argument which is not allowed
def __new__(cls, *args, **kwargs):
Expand All @@ -140,7 +168,7 @@ def __new__(cls, *args, **kwargs):

def __init__(self, **kwargs):
super().__init__(**kwargs)

@validator("bind", pre=True, each_item=True)
def check_bind(cls, v: str) -> str:
"""
Expand Down Expand Up @@ -198,7 +226,9 @@ def check_qos(cls, v: str) -> str:
qos_list = _get_qos_list()
if v not in qos_list:
# Raise an error if the qos is not in the list of available qos
raise ValueError(f"QOS {v} is not in the list of available qos: \n {qos_list}")
raise QOSNotFoundError(
f"QOS {v} is not in the list of available qos: \n {qos_list}"
)
return v

@validator("hours")
Expand Down Expand Up @@ -254,7 +284,7 @@ def check_container_format(cls, v: str, values: Dict[Any, Any]) -> str:
str: The container to use.
"""
if not v.endswith(".simg"):
raise ValueError("Container must end with .simg")
raise FormatError("Container must end with .simg")
# Check if the container exists
partition: str = values.get("partition", "xenon1t")
if not os.path.exists(os.path.join(SINGULARITY_DIR[partition], v)):
Expand Down Expand Up @@ -293,7 +323,9 @@ def _create_singularity_jobstring(self) -> str:
file_discriptor = None
exec_file = f"{TMPDIR[self.partition]}/tmp.sh"
else:
file_discriptor, exec_file = tempfile.mkstemp(suffix=".sh", dir=TMPDIR[self.partition])
file_discriptor, exec_file = tempfile.mkstemp(
suffix=".sh", dir=TMPDIR[self.partition]
)
_make_executable(exec_file)
os.write(file_discriptor, bytes("#!/bin/bash\n" + self.jobstring, "utf-8"))
bind_string = " ".join(
Expand Down Expand Up @@ -405,7 +437,9 @@ def submit_job(
jobstring: str,
exclude_lc_nodes: bool = True,
log: str = "job.log",
partition: Literal["dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"] = "xenon1t",
partition: Literal[
"dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"
] = "xenon1t",
qos: str = "xenon1t",
account: str = "pi-lgrandi",
jobname: str = "somejob",
Expand Down

0 comments on commit db6af90

Please sign in to comment.