Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto check binding #110

Merged
merged 7 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 85 additions & 45 deletions tests/test_batchq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,51 @@
from utilix import batchq
from utilix.batchq import JobSubmission, QOSNotFoundError, FormatError, submit_job
import pytest
import os
from unittest.mock import patch
import inspect
import logging


# Get the SERVER type
def get_server_type():
hostname = os.uname().nodename
if "midway2" in hostname:
return "Midway2"
elif "midway3" in hostname:
return "Midway3"
elif "dali" in hostname:
return "Dali"
else:
raise ValueError(
f"Unknown server type for hostname {hostname}. Please use midway2, midway3, or dali."
)


SERVER = get_server_type()


def get_partition_and_qos(server):
if server == "Midway2":
return "xenon1t", "xenon1t"
elif server == "Midway3":
return "lgrandi", "lgrandi"
elif server == "Dali":
return "dali", "dali"
else:
raise ValueError(f"Unknown server: {server}")


PARTITION, QOS = get_partition_and_qos(SERVER)


# Fixture to provide a sample valid JobSubmission instance
@pytest.fixture
def valid_job_submission() -> JobSubmission:
return JobSubmission(
jobstring="Hello World",
partition="xenon1t",
qos="xenon1t",
partition=PARTITION,
qos=QOS,
hours=10,
container="xenonnt-development.simg",
)
Expand All @@ -23,49 +57,6 @@ def test_valid_jobstring(valid_job_submission: JobSubmission):
assert valid_job_submission.jobstring == "Hello World"


def test_invalid_qos():
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the qos field."""
with pytest.raises(QOSNotFoundError) as exc_info:
JobSubmission(
jobstring="Hello World",
qos="invalid_qos",
hours=10,
container="xenonnt-development.simg",
)
assert "QOS invalid_qos is not in the list of available qos" in str(exc_info.value)


def test_valid_qos(valid_job_submission: JobSubmission):
"""Test case to check if a valid qos is accepted."""
assert valid_job_submission.qos == "xenon1t"


def test_invalid_hours():
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the hours field."""
with pytest.raises(ValidationError) as exc_info:
JobSubmission(
jobstring="Hello World",
qos="xenon1t",
hours=100,
container="xenonnt-development.simg",
)
assert "Hours must be between 0 and 72" in str(exc_info.value)


def test_valid_hours(valid_job_submission: JobSubmission):
"""Test case to check if a valid hours value is accepted."""
assert valid_job_submission.hours == 10


def test_invalid_container():
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the container field."""
with pytest.raises(FormatError) as exc_info:
JobSubmission(
jobstring="Hello World", qos="xenon1t", hours=10, container="invalid.ext"
)
assert "Container must end with .simg" in str(exc_info.value)


def test_valid_container(valid_job_submission: JobSubmission):
"""Test case to check if a valid path for the container is found."""
assert "xenonnt-development.simg" in valid_job_submission.container
Expand All @@ -85,6 +76,55 @@ def test_container_exists(valid_job_submission: JobSubmission, tmp_path: str):
assert f"Container {invalid_container} does not exist" in str(exc_info.value)


def test_invalid_container(valid_job_submission: JobSubmission):
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the container field."""
job_submission_data = valid_job_submission.dict().copy()
job_submission_data["container"] = "invalid.txt"
with pytest.raises(FormatError) as exc_info:
job_submission = JobSubmission(**job_submission_data)
assert "Container must end with .simg" in str(exc_info.value)


def test_invalid_qos(valid_job_submission: JobSubmission):
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the qos field."""
job_submission_data = valid_job_submission.dict().copy()
job_submission_data["qos"] = "invalid_qos"
with pytest.raises(QOSNotFoundError) as exc_info:
JobSubmission(**job_submission_data)
assert "QOS invalid_qos is not in the list of available qos" in str(exc_info.value)


def test_valid_qos(valid_job_submission: JobSubmission):
"""Test case to check if a valid qos is accepted."""
assert valid_job_submission.qos == valid_job_submission.qos


def test_invalid_bind(valid_job_submission: JobSubmission, caplog):
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the bind field."""
job_submission_data = valid_job_submission.dict().copy()
invalid_bind = "/project999"
job_submission_data["bind"].append(invalid_bind)
with caplog.at_level(logging.WARNING):
JobSubmission(**job_submission_data)

assert "skipped mounting" in caplog.text
assert invalid_bind in caplog.text


def test_invalid_hours(valid_job_submission: JobSubmission):
"""Test case to check if the appropriate validation error is raised when an invalid value is provided for the hours field."""
job_submission_data = valid_job_submission.dict().copy()
job_submission_data["hours"] = 1000
with pytest.raises(ValidationError) as exc_info:
JobSubmission(**job_submission_data)
assert "Hours must be between 0 and 72" in str(exc_info.value)


def test_valid_hours(valid_job_submission: JobSubmission):
"""Test case to check if a valid hours value is accepted."""
assert valid_job_submission.hours == valid_job_submission.hours


def test_bypass_validation_qos(valid_job_submission: JobSubmission):
"""
Test case to check if the validation for the qos field is skipped when it is included in the bypass_validation list.
Expand Down
32 changes: 24 additions & 8 deletions utilix/batchq.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@
"/dali/lgrandi/grid_proxy/xenon_service_proxy:/project2/lgrandi/grid_proxy/xenon_service_proxy",
]


class QOSNotFoundError(Exception):
"""
Provided qos is not found in the qos list
"""


class FormatError(Exception):
"""
Format of file is not correct
"""


def _make_executable(path: str) -> None:
"""
Make a file executable by the user, group and others.
Expand Down Expand Up @@ -103,6 +107,7 @@ def _get_qos_list() -> List[str]:
print(f"An error occurred while executing sacctmgr: {e}")
return []


class JobSubmission(BaseModel):
"""
Class to generate and submit a job to the SLURM queue.
Expand All @@ -116,13 +121,13 @@ class JobSubmission(BaseModel):
False, description="Exclude the loosely coupled nodes"
)
log: str = Field("job.log", description="Where to store the log file of the job")
partition: Literal[
dachengx marked this conversation as resolved.
Show resolved Hide resolved
"dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"
] = Field("xenon1t", description="Partition to submit the job to")
bind: List[str] = Field(
default_factory=lambda: DEFAULT_BIND,
description="Paths to add to the container. Immutable when specifying dali as partition",
)
partition: Literal[
"dali", "lgrandi", "xenon1t", "broadwl", "kicp", "caslake", "build"
] = Field("xenon1t", description="Partition to submit the job to")
qos: str = Field("xenon1t", description="QOS to submit the job to")
account: str = Field("pi-lgrandi", description="Account to submit the job to")
jobname: str = Field("somejob", description="How to name this job")
Expand Down Expand Up @@ -175,12 +180,13 @@ def _skip_validation(cls, field: str, values: Dict[Any, Any]) -> bool:
bool: True if the field should be validated, False otherwise.
"""
return field in values.get("bypass_validation", [])

# validate the bypass_validation so that it can be reached in values
@validator("bypass_validation", pre=True, each_item=True)
def check_bypass_validation(cls, v: list) -> list:
return v

@validator("bind", pre=True, each_item=True)
@validator("bind", pre=True)
def check_bind(cls, v: str, values: Dict[Any, Any]) -> str:
"""
Check if the bind path exists.
Expand All @@ -194,10 +200,20 @@ def check_bind(cls, v: str, values: Dict[Any, Any]) -> str:
if cls._skip_validation("bind", values):
return v

if not os.path.exists(v):
logger.warning("Bind path %s does not exist", v)

return v
valid_bind = []
invalid_bind = []
for path in v:
if ":" in path:
actual_path = path.split(":")[0]
else:
actual_path = path
if os.path.exists(actual_path):
valid_bind.append(path)
else:
invalid_bind.append(path)
if len(invalid_bind) > 0:
logger.warning("Invalid bind paths: %s, skipped mounting", invalid_bind)
return valid_bind

@validator("partition", pre=True, always=True)
def overwrite_for_dali(cls, v: str, values: Dict[Any, Any]) -> str:
Expand Down
Loading