Skip to content

Commit

Permalink
Tarball all needed templates from different folder (#214)
Browse files Browse the repository at this point in the history
* Tarball all needed templates from different folders, as long as their
names are different

* Minor change
  • Loading branch information
dachengx authored Sep 19, 2024
1 parent aa380c9 commit 0b7c580
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 52 deletions.
73 changes: 25 additions & 48 deletions alea/submitters/htcondor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import shlex
from datetime import datetime
import logging
import shutil
from pathlib import Path
from tqdm import tqdm
from utilix.x509 import _validate_x509_proxy
from Pegasus.api import (
Arch,
Expand All @@ -23,7 +25,7 @@
)
from alea.runner import Runner
from alea.submitter import Submitter
from alea.utils import load_yaml, dump_yaml
from alea.utils import TEMPLATE_RECORDS, load_yaml, dump_yaml


DEFAULT_IMAGE = "/cvmfs/singularity.opensciencegrid.org/xenonnt/base-environment:latest"
Expand All @@ -42,12 +44,12 @@ class SubmitterHTCondor(Submitter):
def __init__(self, *args, **kwargs):
# General start
self.htcondor_configurations = kwargs.get("htcondor_configurations", {})
self.template_path = self.htcondor_configurations.pop("template_path", None)
self.singularity_image = self.htcondor_configurations.pop(
"singularity_image", DEFAULT_IMAGE
)
self.top_dir = TOP_DIR
self.work_dir = WORK_DIR
self.template_path = self.htcondor_configurations.pop("template_path", None)
self.combine_n_outputs = self.htcondor_configurations.pop("combine_n_outputs", 100)

# A flag to check if limit_threshold is added to the rc
Expand All @@ -68,6 +70,7 @@ def __init__(self, *args, **kwargs):
self.dagman_maxjobs = self.htcondor_configurations.pop("dagman_maxjobs", 100_000)

super().__init__(*args, **kwargs)
TEMPLATE_RECORDS.lock()

# Job input configurations
self.config_file_path = os.path.abspath(self.config_file_path)
Expand All @@ -80,6 +83,7 @@ def __init__(self, *args, **kwargs):
self.runs_dir = os.path.join(self.workflow_dir, "runs")
self.outputs_dir = os.path.join(self.workflow_dir, "outputs")
self.scratch_dir = os.path.join(self.workflow_dir, "scratch")
self.templates_tarball_dir = os.path.join(self.generated_dir, "templates")

@property
def template_tarball(self):
Expand Down Expand Up @@ -123,41 +127,29 @@ def requirements(self):

return _requirements

def _validate_template_path(self):
"""Validate the template path."""
if self.template_path is None:
raise ValueError("Please provide a template path.")
# This path must exists locally, and it will be used to stage the input files
if not os.path.exists(self.template_path):
raise ValueError(f"Path {self.template_path} does not exist.")

# Printout the template path file structure
logger.info("Template path file structure:")
for dirpath, dirnames, filenames in os.walk(self.template_path):
for filename in filenames:
logger.info(f"File: {filename} in {dirpath}")
if self._contains_subdirectories(self.template_path):
logger.warning(
"The template path contains subdirectories. All templates files will be tarred."
)

def _tar_h5_files(self, directory, template_tarball="templates.tar.gz"):
    """Tar all .h5 templates in the directory and its subdirectories into a tarball.

    Args:
        directory: Folder that is walked for ``.h5`` files.
        template_tarball: Path of the gzip-compressed tarball to write.
            An existing file at this path is overwritten ("w:gz" mode).
    """
    # Create a tar.gz archive
    with tarfile.open(template_tarball, "w:gz") as tar:
        # Walk through the directory
        for dirpath, dirnames, filenames in os.walk(directory):
            for filename in filenames:
                if filename.endswith(".h5"):
                    # Get the full path to the file
                    filepath = os.path.join(dirpath, filename)
                    # Add the file to the tar
                    # arcname is the bare file name, so every file is stored
                    # at the top level of the archive regardless of its folder
                    tar.add(filepath, arcname=os.path.basename(filename))
        # NOTE(review): this adds the whole directory under its own basename,
        # in addition to the per-file entries above. The two look like
        # alternative versions of the same step; confirm which one is intended.
        tar.add(directory, arcname=os.path.basename(directory))

def _make_template_tarball(self):
    """Copy every recorded template into one folder and tar that folder.

    The files listed in ``TEMPLATE_RECORDS`` are copied into
    ``self.templates_tarball_dir``, which is then archived to
    ``self.template_tarball``.

    Raises:
        RuntimeError: if ``TEMPLATE_RECORDS`` reports non-unique basenames,
            or if ``self.templates_tarball_dir`` is not empty.
    """
    # All templates are copied into a single flat folder, so two files with
    # the same basename would silently overwrite each other.
    if not TEMPLATE_RECORDS.uniqueness:
        raise RuntimeError("All files in the template path must have unique basenames.")
    os.makedirs(self.templates_tarball_dir, exist_ok=True)
    # Refuse to mix leftovers from an earlier run into the tarball
    if os.listdir(self.templates_tarball_dir):
        raise RuntimeError(
            f"Directory {self.templates_tarball_dir} is not empty. "
            "Please remove it before running the script."
        )

    logger.info(f"Copying templates into {self.templates_tarball_dir}")
    for record in tqdm(TEMPLATE_RECORDS):
        # Copy each file to the destination folder
        shutil.copy(record, self.templates_tarball_dir)
    self._tar_h5_files(self.templates_tarball_dir, self.template_tarball)
    logger.info(f"Tarball made at {self.template_tarball}")

def _modify_yaml(self):
"""Modify the statistical model config file to correct the 'template_filename' fields.
Expand Down Expand Up @@ -674,19 +666,6 @@ def _plan_and_submit(self):
**self.pegasus_config,
)

def _check_filename_unique(self):
"""Check if all the files in the template path are unique.
We assume two levels of the template folder.
"""
all_files = []
for _, _, filenames in os.walk(self.template_path):
for filename in filenames:
all_files.append(filename)
if len(all_files) != len(set(all_files)):
raise RuntimeError("All files in the template path must have unique names.")

def submit(self, **kwargs):
"""Serve as the main function to submit the workflow."""
if os.path.exists(self.workflow_dir):
Expand All @@ -704,8 +683,6 @@ def submit(self, **kwargs):
self._modify_yaml()

# Handling templates as part of the inputs
self._validate_template_path()
self._check_filename_unique()
self._make_template_tarball()

self._generate_workflow()
Expand All @@ -715,6 +692,6 @@ def submit(self, **kwargs):
self.wf.graph(
output=os.path.join(self.generated_dir, "workflow_graph.dot"), label="xform-id"
)
self.wf.graph(
output=os.path.join(self.generated_dir, "workflow_graph.svg"), label="xform-id"
)
# self.wf.graph(
# output=os.path.join(self.generated_dir, "workflow_graph.svg"), label="xform-id"
# )
2 changes: 1 addition & 1 deletion alea/submitters/run_toymc_wrapper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ METADATA=$(echo "$metadata" | sed "s/'/\"/g")
mkdir -p templates
START=$(date +%s)
for TAR in `ls *.tar.gz`; do
tar -xzf $TAR -C templates
tar -xzf $TAR -C templates --strip-components=1
done
rm *.tar.gz
END=$(date +%s)
Expand Down
46 changes: 43 additions & 3 deletions alea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,44 @@
MAX_FLOAT = np.sqrt(np.finfo(np.float32).max)


class CannotUpdate(Exception):
    """Raised when `update` is called on a locked LockableSet."""


class LockableSet(set):
    """A set whose `update` method can be locked.

    Only `update` honours the lock; the other mutating `set` methods
    (for example `add`) are not overridden by this class.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # A freshly created set accepts updates
        self.locked = False

    def lock(self):
        """Lock the set to prevent modifications."""
        self.locked = True

    def unlock(self):
        """Unlock the set to allow modifications."""
        self.locked = False

    def update(self, *args):
        """Update the set with elements if it is not locked.

        Raises:
            CannotUpdate: if the set is currently locked.
        """
        if self.locked:
            raise CannotUpdate("LockableSet is locked so can not be updated!")
        super().update(*args)

    @property
    def uniqueness(self):
        """Whether the basenames of all stored paths are distinct."""
        names = self.basenames
        return len(set(names)) == len(names)

    @property
    def basenames(self):
        """The basenames of the filenames in the set, as a list."""
        return [os.path.basename(record) for record in self]


# Module-level record of resolved template file paths; `_prefix_file_path`
# adds the files matched for each config entry it resolves.
TEMPLATE_RECORDS = LockableSet()


class ReadOnlyDict:
"""A read-only dict."""

Expand Down Expand Up @@ -133,6 +171,7 @@ def _prefix_file_path(
if isinstance(config[key], str) and key not in ignore_keys:
try:
config[key] = get_file_path(config[key], template_folder_list)
TEMPLATE_RECORDS.update(glob(formatted_to_asterisked(config[key])))
except RuntimeError:
pass

Expand Down Expand Up @@ -165,6 +204,7 @@ def adapt_likelihood_config_for_blueice(
raise ValueError(f"Could not find {likelihood_config_copy['default_source_class']}!")
likelihood_config_copy["default_source_class"] = default_source_class

# Translation to blueice's language
for source in likelihood_config_copy["sources"]:
if "template_filename" in source:
source["templatename"] = get_file_path(
Expand Down Expand Up @@ -212,7 +252,7 @@ def dump_json(file_name: str, data: dict):
json.dump(data, file, indent=4)


def _get_abspath(file_name):
def _get_internal(file_name):
"""Get the abspath of the file.
Raise FileNotFoundError when not found in any subfolder
Expand Down Expand Up @@ -276,7 +316,7 @@ def get_file_path(fname, folder_list: Optional[List[str]] = None):
#. fname begin with '/', return absolute path
#. folder begin with '/', return folder + name
#. can get file from _get_abspath, return alea internal file path
#. can get file from _get_internal, return alea internal file path
#. can be found in local installed ntauxfiles, return ntauxfiles absolute path
#. can be downloaded from MongoDB, download and return cached path
Expand Down Expand Up @@ -312,7 +352,7 @@ def get_file_path(fname, folder_list: Optional[List[str]] = None):

# 3. From alea internal files
try:
return _get_abspath(fname)
return _get_internal(fname)
except FileNotFoundError:
pass

Expand Down

0 comments on commit 0b7c580

Please sign in to comment.