diff --git a/tests/test_batchq.py b/tests/test_batchq.py index dd019c1..9e86b2c 100644 --- a/tests/test_batchq.py +++ b/tests/test_batchq.py @@ -2,8 +2,42 @@ from utilix import batchq from utilix.batchq import JobSubmission, QOSNotFoundError, FormatError, submit_job import pytest +import os from unittest.mock import patch import inspect +import logging + + +# Get the SERVER type +def get_server_type(): + hostname = os.uname().nodename + if "midway2" in hostname: + return "Midway2" + elif "midway3" in hostname: + return "Midway3" + elif "dali" in hostname: + return "Dali" + else: + raise ValueError( + f"Unknown server type for hostname {hostname}. Please use midway2, midway3, or dali." + ) + + +SERVER = get_server_type() + + +def get_partition_and_qos(server): + if server == "Midway2": + return "xenon1t", "xenon1t" + elif server == "Midway3": + return "lgrandi", "lgrandi" + elif server == "Dali": + return "dali", "dali" + else: + raise ValueError(f"Unknown server: {server}") + + +PARTITION, QOS = get_partition_and_qos(SERVER) # Fixture to provide a sample valid JobSubmission instance @@ -11,8 +45,8 @@ def valid_job_submission() -> JobSubmission: return JobSubmission( jobstring="Hello World", - partition="xenon1t", - qos="xenon1t", + partition=PARTITION, + qos=QOS, hours=10, container="xenonnt-development.simg", ) @@ -23,49 +57,6 @@ def test_valid_jobstring(valid_job_submission: JobSubmission): assert valid_job_submission.jobstring == "Hello World" -def test_invalid_qos(): - """Test case to check if the appropriate validation error is raised when an invalid value is provided for the qos field.""" - with pytest.raises(QOSNotFoundError) as exc_info: - JobSubmission( - jobstring="Hello World", - qos="invalid_qos", - hours=10, - container="xenonnt-development.simg", - ) - assert "QOS invalid_qos is not in the list of available qos" in str(exc_info.value) - - -def test_valid_qos(valid_job_submission: JobSubmission): - """Test case to check if a valid qos is accepted.""" - assert valid_job_submission.qos == "xenon1t" - - -def test_invalid_hours(): - """Test case to check if the appropriate validation error is raised when an invalid value is provided for the hours field.""" - with pytest.raises(ValidationError) as exc_info: - JobSubmission( - jobstring="Hello World", - qos="xenon1t", - hours=100, - container="xenonnt-development.simg", - ) - assert "Hours must be between 0 and 72" in str(exc_info.value) - - -def test_valid_hours(valid_job_submission: JobSubmission): - """Test case to check if a valid hours value is accepted.""" - assert valid_job_submission.hours == 10 - - -def test_invalid_container(): - """Test case to check if the appropriate validation error is raised when an invalid value is provided for the container field.""" - with pytest.raises(FormatError) as exc_info: - JobSubmission( - jobstring="Hello World", qos="xenon1t", hours=10, container="invalid.ext" - ) - assert "Container must end with .simg" in str(exc_info.value) - - def test_valid_container(valid_job_submission: JobSubmission): """Test case to check if a valid path for the container is found.""" assert "xenonnt-development.simg" in valid_job_submission.container @@ -85,6 +76,55 @@ def test_container_exists(valid_job_submission: JobSubmission, tmp_path: str): assert f"Container {invalid_container} does not exist" in str(exc_info.value) +def test_invalid_container(valid_job_submission: JobSubmission): + """Test case to check if the appropriate validation error is raised when an invalid value is provided for the container field.""" + job_submission_data = valid_job_submission.dict().copy() + job_submission_data["container"] = "invalid.txt" + with pytest.raises(FormatError) as exc_info: + job_submission = JobSubmission(**job_submission_data) + assert "Container must end with .simg" in str(exc_info.value) + + +def test_invalid_qos(valid_job_submission: JobSubmission): + """Test case to check if the appropriate validation error is raised when an invalid value is provided for the qos field.""" + job_submission_data = valid_job_submission.dict().copy() + job_submission_data["qos"] = "invalid_qos" + with pytest.raises(QOSNotFoundError) as exc_info: + JobSubmission(**job_submission_data) + assert "QOS invalid_qos is not in the list of available qos" in str(exc_info.value) + + +def test_valid_qos(valid_job_submission: JobSubmission): + """Test case to check if a valid qos is accepted.""" + assert valid_job_submission.qos == valid_job_submission.qos + + +def test_invalid_bind(valid_job_submission: JobSubmission, caplog): + """Test case to check if the appropriate validation error is raised when an invalid value is provided for the bind field.""" + job_submission_data = valid_job_submission.dict().copy() + invalid_bind = "/project999" + job_submission_data["bind"].append(invalid_bind) + with caplog.at_level(logging.WARNING): + JobSubmission(**job_submission_data) + + assert "skipped mounting" in caplog.text + assert invalid_bind in caplog.text + + +def test_invalid_hours(valid_job_submission: JobSubmission): + """Test case to check if the appropriate validation error is raised when an invalid value is provided for the hours field.""" + job_submission_data = valid_job_submission.dict().copy() + job_submission_data["hours"] = 1000 + with pytest.raises(ValidationError) as exc_info: + JobSubmission(**job_submission_data) + assert "Hours must be between 0 and 72" in str(exc_info.value) + + +def test_valid_hours(valid_job_submission: JobSubmission): + """Test case to check if a valid hours value is accepted.""" + assert valid_job_submission.hours == valid_job_submission.hours + + def test_bypass_validation_qos(valid_job_submission: JobSubmission): """ Test case to check if the validation for the qos field is skipped when it is included in the bypass_validation list. diff --git a/utilix/batchq.py b/utilix/batchq.py index 7e026c1..d02435b 100644 --- a/utilix/batchq.py +++ b/utilix/batchq.py @@ -58,15 +58,19 @@ "/dali/lgrandi/grid_proxy/xenon_service_proxy:/project2/lgrandi/grid_proxy/xenon_service_proxy", ] + class QOSNotFoundError(Exception): """ Provided qos is not found in the qos list """ + + class FormatError(Exception): """ Format of file is not correct """ + def _make_executable(path: str) -> None: """ Make a file executable by the user, group and others. @@ -103,6 +107,7 @@ def _get_qos_list() -> List[str]: print(f"An error occurred while executing sacctmgr: {e}") return [] + class JobSubmission(BaseModel): """ Class to generate and submit a job to the SLURM queue. @@ -116,13 +121,13 @@ class JobSubmission(BaseModel): False, description="Exclude the loosely coupled nodes" ) log: str = Field("job.log", description="Where to store the log file of the job") + partition: Literal[ + "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build" + ] = Field("xenon1t", description="Partition to submit the job to") bind: List[str] = Field( default_factory=lambda: DEFAULT_BIND, description="Paths to add to the container. Immutable when specifying dali as partition", ) - partition: Literal[ - "dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build" - ] = Field("xenon1t", description="Partition to submit the job to") qos: str = Field("xenon1t", description="QOS to submit the job to") account: str = Field("pi-lgrandi", description="Account to submit the job to") jobname: str = Field("somejob", description="How to name this job") @@ -175,12 +180,13 @@ def _skip_validation(cls, field: str, values: Dict[Any, Any]) -> bool: bool: True if the field should be validated, False otherwise. """ return field in values.get("bypass_validation", []) + # validate the bypass_validation so that it can be reached in values @validator("bypass_validation", pre=True, each_item=True) def check_bypass_validation(cls, v: list) -> list: return v - @validator("bind", pre=True, each_item=True) + @validator("bind", pre=True) def check_bind(cls, v: str, values: Dict[Any, Any]) -> str: """ Check if the bind path exists. @@ -194,10 +200,20 @@ def check_bind(cls, v: str, values: Dict[Any, Any]) -> str: if cls._skip_validation("bind", values): return v - if not os.path.exists(v): - logger.warning("Bind path %s does not exist", v) - - return v + valid_bind = [] + invalid_bind = [] + for path in v: + if ":" in path: + actual_path = path.split(":")[0] + else: + actual_path = path + if os.path.exists(actual_path): + valid_bind.append(path) + else: + invalid_bind.append(path) + if len(invalid_bind) > 0: + logger.warning("Invalid bind paths: %s, skipped mounting", invalid_bind) + return valid_bind @validator("partition", pre=True, always=True) def overwrite_for_dali(cls, v: str, values: Dict[Any, Any]) -> str: