diff --git a/tests/test_batchq.py b/tests/test_batchq.py index 4dbc130..d208c70 100644 --- a/tests/test_batchq.py +++ b/tests/test_batchq.py @@ -29,16 +29,35 @@ def get_server_type(): def get_partition_and_qos(server): if server == "Midway2": - return "xenon1t", "xenon1t" + return [ + "xenon1t", + "broadwl", + "kicp", + "build", + "bigmem2", + "gpu2", + ] elif server == "Midway3": - return "lgrandi", "lgrandi" + return [ + "kicp", + "lgrandi", + "caslake", + "build", + ] elif server == "Dali": - return "dali", "dali" + return [ + "dali", + "xenon1t", + "broadwl", + "kicp", + "build", + "bigmem2", + "gpu2", + ] else: raise ValueError(f"Unknown server: {server}") - -PARTITION, QOS = get_partition_and_qos(SERVER) +PARTITIONS = get_partition_and_qos(SERVER) # Fixture to provide a sample valid JobSubmission instance @@ -46,35 +65,46 @@ def get_partition_and_qos(server): def valid_job_submission() -> JobSubmission: return JobSubmission( jobstring="Hello World", - partition=PARTITION, - qos=QOS, + partition=PARTITIONS[0], + qos=PARTITIONS[0], hours=10, container="xenonnt-development.simg", ) -def test_job_submission_submit(valid_job_submission: JobSubmission): +import pytest + +@pytest.mark.parametrize("partition", PARTITIONS) +def test_job_submission_submit_all_partitions(partition): + job_submission = JobSubmission( + jobstring="echo 'Job started'; sleep 10; echo 'Job completed'", + partition=partition, + qos=partition, # QOS is the same as the partition + hours=1, + container="xenonnt-development.simg", + ) + with patch("utilix.batchq.Slurm") as mock_slurm_class: mock_slurm = MagicMock() mock_slurm_class.return_value = mock_slurm - valid_job_submission.jobstring = "echo 'Job started'; sleep 10; echo 'Job completed'" - valid_job_submission.submit() + job_submission.submit() mock_slurm_class.assert_called_once_with( - job_name=valid_job_submission.jobname, - output=valid_job_submission.log, - qos=valid_job_submission.qos, - error=valid_job_submission.log, - account=valid_job_submission.account, - partition=valid_job_submission.partition, - mem_per_cpu=valid_job_submission.mem_per_cpu, - cpus_per_task=valid_job_submission.cpus_per_task, - time=datetime.timedelta(hours=valid_job_submission.hours), + job_name=job_submission.jobname, + output=job_submission.log, + qos=partition, + error=job_submission.log, + account=job_submission.account, + partition=partition, + mem_per_cpu=job_submission.mem_per_cpu, + cpus_per_task=job_submission.cpus_per_task, + time=datetime.timedelta(hours=job_submission.hours), ) mock_slurm.add_cmd.assert_called_once() mock_slurm.sbatch.assert_called_once_with(shell="/bin/bash") -def test_submit_job_function(): +@pytest.mark.parametrize("partition", PARTITIONS) +def test_submit_job_function_all_partitions(partition): jobstring = "echo 'Job started'; sleep 10; echo 'Job completed'" with patch("utilix.batchq.JobSubmission") as mock_job_submission_class: @@ -83,9 +113,9 @@ def test_submit_job_function(): submit_job( jobstring=jobstring, - partition=PARTITION, - qos=QOS, - hours=10, + partition=partition, + qos=partition, + hours=1, container="xenonnt-development.simg", ) @@ -93,8 +123,8 @@ def test_submit_job_function(): jobstring=jobstring, exclude_lc_nodes=False, log="job.log", - partition=PARTITION, - qos=QOS, + partition=partition, + qos=partition, account="pi-lgrandi", jobname="somejob", sbatch_file=None, @@ -103,7 +133,7 @@ def test_submit_job_function(): container="xenonnt-development.simg", bind=batchq.DEFAULT_BIND, cpus_per_task=1, - hours=10, + hours=1, node=None, exclude_nodes=None, dependency=None, diff --git a/utilix/batchq.py b/utilix/batchq.py index d02435b..6512f92 100644 --- a/utilix/batchq.py +++ b/utilix/batchq.py @@ -32,6 +32,8 @@ "kicp", "caslake", "build", + "bigmem2", + "gpu2", ] TMPDIR: Dict[str, str] = { "dali": f"/dali/lgrandi/{USER}/tmp", @@ -41,6 +43,8 @@ "kicp": os.path.join(SCRATCH_DIR, "tmp"), "caslake": os.path.join(SCRATCH_DIR, "tmp"), "build": os.path.join(SCRATCH_DIR, "tmp"), + "bigmem2": os.path.join(SCRATCH_DIR, "tmp"), + "gpu2": os.path.join(SCRATCH_DIR, "tmp"), } SINGULARITY_DIR: str = "lgrandi/xenonnt/singularity-images" @@ -122,7 +126,7 @@ class JobSubmission(BaseModel): ) log: str = Field("job.log", description="Where to store the log file of the job") partition: Literal[ - "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build" + "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build", "bigmem2", "gpu2" ] = Field("xenon1t", description="Partition to submit the job to") bind: List[str] = Field( default_factory=lambda: DEFAULT_BIND, @@ -488,7 +492,7 @@ def submit_job( exclude_lc_nodes: bool = False, log: str = "job.log", partition: Literal[ - "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build" + "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build", "bigmem2", "gpu2" ] = "xenon1t", qos: str = "xenon1t", account: str = "pi-lgrandi", @@ -513,7 +517,7 @@ def submit_job( jobstring (str): The command to execute. exclude_lc_nodes (bool): Exclude the loosely coupled nodes. Default is True. log (str): Where to store the log file of the job. Default is "job.log". - partition (Literal["dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"]): + partition (Literal["dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build", "bigmem2", "gpu2" (the only GPU node)]): Partition to submit the job to. Default is "xenon1t". qos (str): QOS to submit the job to. Default is "xenon1t". account (str): Account to submit the job to. Default is "pi-lgrandi".