Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tarball all needed templates from different folder #214

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 25 additions & 48 deletions alea/submitters/htcondor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import shlex
from datetime import datetime
import logging
import shutil
from pathlib import Path
from tqdm import tqdm
from utilix.x509 import _validate_x509_proxy
from Pegasus.api import (
Arch,
Expand All @@ -23,7 +25,7 @@
)
from alea.runner import Runner
from alea.submitter import Submitter
from alea.utils import load_yaml, dump_yaml
from alea.utils import TEMPLATE_RECORDS, load_yaml, dump_yaml


DEFAULT_IMAGE = "/cvmfs/singularity.opensciencegrid.org/xenonnt/base-environment:latest"
Expand All @@ -42,12 +44,12 @@ class SubmitterHTCondor(Submitter):
def __init__(self, *args, **kwargs):
# General start
self.htcondor_configurations = kwargs.get("htcondor_configurations", {})
self.template_path = self.htcondor_configurations.pop("template_path", None)
self.singularity_image = self.htcondor_configurations.pop(
"singularity_image", DEFAULT_IMAGE
)
self.top_dir = TOP_DIR
self.work_dir = WORK_DIR
self.template_path = self.htcondor_configurations.pop("template_path", None)
self.combine_n_outputs = self.htcondor_configurations.pop("combine_n_outputs", 100)

# A flag to check if limit_threshold is added to the rc
Expand All @@ -68,6 +70,7 @@ def __init__(self, *args, **kwargs):
self.dagman_maxjobs = self.htcondor_configurations.pop("dagman_maxjobs", 100_000)

super().__init__(*args, **kwargs)
TEMPLATE_RECORDS.lock()

# Job input configurations
self.config_file_path = os.path.abspath(self.config_file_path)
Expand All @@ -80,6 +83,7 @@ def __init__(self, *args, **kwargs):
self.runs_dir = os.path.join(self.workflow_dir, "runs")
self.outputs_dir = os.path.join(self.workflow_dir, "outputs")
self.scratch_dir = os.path.join(self.workflow_dir, "scratch")
self.templates_tarball_dir = os.path.join(self.generated_dir, "templates")

@property
def template_tarball(self):
Expand Down Expand Up @@ -123,41 +127,29 @@ def requirements(self):

return _requirements

def _validate_template_path(self):
"""Validate the template path."""
if self.template_path is None:
raise ValueError("Please provide a template path.")
# This path must exists locally, and it will be used to stage the input files
if not os.path.exists(self.template_path):
raise ValueError(f"Path {self.template_path} does not exist.")

# Printout the template path file structure
logger.info("Template path file structure:")
for dirpath, dirnames, filenames in os.walk(self.template_path):
for filename in filenames:
logger.info(f"File: {filename} in {dirpath}")
if self._contains_subdirectories(self.template_path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_contains_subdirectories is no longer used after this PR sop we might as well delete it or am I missing something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right.

logger.warning(
"The template path contains subdirectories. All templates files will be tarred."
)

def _tar_h5_files(self, directory, template_tarball="templates.tar.gz"):
"""Tar all .h5 templates in the directory and its subdirectories into a tarball."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring should be updated

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right.

# Create a tar.gz archive
with tarfile.open(template_tarball, "w:gz") as tar:
# Walk through the directory
for dirpath, dirnames, filenames in os.walk(directory):
for filename in filenames:
if filename.endswith(".h5"):
# Get the full path to the file
filepath = os.path.join(dirpath, filename)
# Add the file to the tar
# Specify the arcname to store relative path within the tar
tar.add(filepath, arcname=os.path.basename(filename))
tar.add(directory, arcname=os.path.basename(directory))
hammannr marked this conversation as resolved.
Show resolved Hide resolved

def _make_template_tarball(self):
"""Make tarball of the templates if not exists."""
self._tar_h5_files(self.template_path, self.template_tarball)
if not TEMPLATE_RECORDS.uniqueness:
raise RuntimeError("All files in the template path must have unique basenames.")
os.makedirs(self.templates_tarball_dir, exist_ok=True)
if os.listdir(self.templates_tarball_dir):
raise RuntimeError(
f"Directory {self.templates_tarball_dir} is not empty. "
"Please remove it before running the script."
)

logger.info(f"Copying templates into {self.templates_tarball_dir}")
for record in tqdm(TEMPLATE_RECORDS):
# Copy each file to the destination folder
shutil.copy(record, self.templates_tarball_dir)
self._tar_h5_files(self.templates_tarball_dir, self.template_tarball)
logger.info(f"Tarbal made at {self.template_tarball}")

def _modify_yaml(self):
"""Modify the statistical model config file to correct the 'template_filename' fields.
Expand Down Expand Up @@ -674,19 +666,6 @@ def _plan_and_submit(self):
**self.pegasus_config,
)

def _check_filename_unique(self):
"""Check if all the files in the template path are unique.

We assume two levels of the template folder.

"""
all_files = []
for _, _, filenames in os.walk(self.template_path):
for filename in filenames:
all_files.append(filename)
if len(all_files) != len(set(all_files)):
raise RuntimeError("All files in the template path must have unique names.")

def submit(self, **kwargs):
"""Serve as the main function to submit the workflow."""
if os.path.exists(self.workflow_dir):
Expand All @@ -704,8 +683,6 @@ def submit(self, **kwargs):
self._modify_yaml()

# Handling templates as part of the inputs
self._validate_template_path()
self._check_filename_unique()
self._make_template_tarball()

self._generate_workflow()
Expand All @@ -715,6 +692,6 @@ def submit(self, **kwargs):
self.wf.graph(
output=os.path.join(self.generated_dir, "workflow_graph.dot"), label="xform-id"
)
self.wf.graph(
output=os.path.join(self.generated_dir, "workflow_graph.svg"), label="xform-id"
)
# self.wf.graph(
# output=os.path.join(self.generated_dir, "workflow_graph.svg"), label="xform-id"
# )
hammannr marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion alea/submitters/run_toymc_wrapper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ METADATA=$(echo "$metadata" | sed "s/'/\"/g")
mkdir -p templates
START=$(date +%s)
for TAR in `ls *.tar.gz`; do
tar -xzf $TAR -C templates
tar -xzf $TAR -C templates --strip-components=1
done
rm *.tar.gz
END=$(date +%s)
Expand Down
46 changes: 43 additions & 3 deletions alea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,44 @@
MAX_FLOAT = np.sqrt(np.finfo(np.float32).max)


class CannotUpdate(Exception):
pass


class LockableSet(set):
"""A set whose `update` method can be locked."""

def __init__(self, *args):
super().__init__(*args)
self.locked = False

def lock(self):
"""Lock the set to prevent modifications."""
self.locked = True

def unlock(self):
"""Unlock the set to allow modifications."""
self.locked = False

def update(self, *args):
"""Update the set with elements if it is not locked."""
if not self.locked:
super().update(*args)
else:
raise CannotUpdate("LockableSet is locked so can not be updated!")

def uniqueness(self):
"""Check if the basenames contains unique elements."""
return len(set(self.basenames)) == len(self.basenames)

def basenames(self):
"""The basenames of the filenames in the set."""
return [os.path.basename(record) for record in self]


TEMPLATE_RECORDS = LockableSet()


class ReadOnlyDict:
"""A read-only dict."""

Expand Down Expand Up @@ -133,6 +171,7 @@ def _prefix_file_path(
if isinstance(config[key], str) and key not in ignore_keys:
try:
config[key] = get_file_path(config[key], template_folder_list)
TEMPLATE_RECORDS.update(glob(formatted_to_asterisked(config[key])))
except RuntimeError:
pass

Expand Down Expand Up @@ -165,6 +204,7 @@ def adapt_likelihood_config_for_blueice(
raise ValueError(f"Could not find {likelihood_config_copy['default_source_class']}!")
likelihood_config_copy["default_source_class"] = default_source_class

# Translation to blueice's language
for source in likelihood_config_copy["sources"]:
if "template_filename" in source:
source["templatename"] = get_file_path(
Expand Down Expand Up @@ -212,7 +252,7 @@ def dump_json(file_name: str, data: dict):
json.dump(data, file, indent=4)


def _get_abspath(file_name):
def _get_internal(file_name):
"""Get the abspath of the file.

Raise FileNotFoundError when not found in any subfolder
Expand Down Expand Up @@ -276,7 +316,7 @@ def get_file_path(fname, folder_list: Optional[List[str]] = None):

#. fname begin with '/', return absolute path
#. folder begin with '/', return folder + name
#. can get file from _get_abspath, return alea internal file path
#. can get file from _get_internal, return alea internal file path
#. can be found in local installed ntauxfiles, return ntauxfiles absolute path
#. can be downloaded from MongoDB, download and return cached path

Expand Down Expand Up @@ -312,7 +352,7 @@ def get_file_path(fname, folder_list: Optional[List[str]] = None):

# 3. From alea internal files
try:
return _get_abspath(fname)
return _get_internal(fname)
except FileNotFoundError:
pass

Expand Down
Loading