From 6fdf674d6606ad39bccfc6925e4b5defa23406d2 Mon Sep 17 00:00:00 2001 From: Camilo Laiton <36769694+camilolaiton@users.noreply.github.com> Date: Wed, 5 Jun 2024 07:05:51 -0700 Subject: [PATCH] updating closing slurm client --- .../compress/dask_utils.py | 35 +++++++++++++++++++ .../smartspim_job.py | 7 +++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/aind_smartspim_data_transformation/compress/dask_utils.py b/src/aind_smartspim_data_transformation/compress/dask_utils.py index 6f70187..bb74731 100644 --- a/src/aind_smartspim_data_transformation/compress/dask_utils.py +++ b/src/aind_smartspim_data_transformation/compress/dask_utils.py @@ -137,3 +137,38 @@ def cancel_slurm_job( response = requests.delete(endpoint, headers=headers) return response + + +def _cleanup(deployment: str) -> None: + """ + Clean up any resources that were created during the job. + + Parameters + ---------- + deployment : str + The type of deployment. Either "local" or "slurm" + """ + if deployment == Deployment.SLURM.value: + job_id = os.getenv("SLURM_JOBID") + if job_id is not None: + try: + api_url = f"http://{os.environ['HPC_HOST']}" + api_url += f":{os.environ['HPC_PORT']}" + api_url += f"/{os.environ['HPC_API_ENDPOINT']}" + headers = { + "X-SLURM-USER-NAME": os.environ["HPC_USERNAME"], + "X-SLURM-USER-PASSWORD": os.environ["HPC_PASSWORD"], + "X-SLURM-USER-TOKEN": os.environ["HPC_TOKEN"], + } + except KeyError as ke: + logging.error(f"Failed to get SLURM env vars to cleanup: {ke}") + return + logging.info(f"Cancelling SLURM job {job_id}") + response = cancel_slurm_job(job_id, api_url, headers) + if response.status_code != 200: + logging.error( + f"Failed to cancel SLURM job {job_id}: {response.text}" + ) + else: + # This might not run if the job is cancelled + logging.info(f"Cancelled SLURM job {job_id}") diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 0e61dbb..07177f5 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -13,10 +13,12 @@ JobResponse, get_parser, ) +from distributed.utils import silence_logging_cmgr from numcodecs.blosc import Blosc from pydantic import Field from aind_smartspim_data_transformation.compress.dask_utils import ( + _cleanup, get_client, get_deployment, ) @@ -159,7 +161,10 @@ def _compress_and_write_channels( ) # Closing client - client.shutdown() + with silence_logging_cmgr(logging.CRITICAL): + client.shutdown() + + _cleanup(deployment) def _compress_raw_data(self) -> None: """Compresses smartspim data"""