Merge pull request #28 from dirac-institute/issue/27/use-sharded-wu
Changes to make use of sharded WorkUnits.
drewoldag authored Aug 13, 2024
2 parents 72710e3 + 598dfa6 commit f31d2ad
Showing 6 changed files with 104 additions and 128 deletions.
24 changes: 22 additions & 2 deletions src/kbmod_wf/resource_configs/klone_configuration.py
@@ -8,6 +8,7 @@
 walltimes = {
     "compute_bigmem": "01:00:00",
     "large_mem": "04:00:00",
+    "sharded_reproject": "04:00:00",
     "gpu_max": "08:00:00",
 }

@@ -26,7 +27,7 @@ def klone_resource_config():
             label="small_cpu",
             max_workers_per_node=1,
             provider=SlurmProvider(
-                partition="compute-bigmem",
+                partition="ckpt-g2",
                 account="astro",
                 min_blocks=0,
                 max_blocks=4,
@@ -52,14 +53,33 @@ def klone_resource_config():
                 init_blocks=0,
                 parallelism=1,
                 nodes_per_block=1,
-                cores_per_node=8,
+                cores_per_node=32,
                 mem_per_node=512,
                 exclusive=False,
                 walltime=walltimes["large_mem"],
                 # Command to run before starting worker - i.e. conda activate <special_env>
                 worker_init="",
             ),
         ),
+        HighThroughputExecutor(
+            label="sharded_reproject",
+            max_workers_per_node=1,
+            provider=SlurmProvider(
+                partition="ckpt-g2",
+                account="astro",
+                min_blocks=0,
+                max_blocks=2,
+                init_blocks=0,
+                parallelism=1,
+                nodes_per_block=1,
+                cores_per_node=32,
+                mem_per_node=128,  # ~2-4 GB per core
+                exclusive=False,
+                walltime=walltimes["sharded_reproject"],
+                # Command to run before starting worker - i.e. conda activate <special_env>
+                worker_init="",
+            ),
+        ),
         HighThroughputExecutor(
             label="gpu",
             max_workers_per_node=1,
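The new executor block above only takes effect once it is included in a Parsl Config and an app targets its label. A minimal, self-contained sketch of that wiring, assuming the parameter values shown in the hunk; the reproject_task app is hypothetical (in this repo, workflow.py selects executors via get_executors):

import parsl
from parsl import python_app
from parsl.config import Config
from parsl.executors import HighThroughputExecutor
from parsl.providers import SlurmProvider

# Pared-down stand-in for klone_resource_config(); values mirror the hunk above.
config = Config(
    executors=[
        HighThroughputExecutor(
            label="sharded_reproject",
            max_workers_per_node=1,
            provider=SlurmProvider(
                partition="ckpt-g2",
                account="astro",
                min_blocks=0,
                max_blocks=2,
                nodes_per_block=1,
                cores_per_node=32,
                mem_per_node=128,
                exclusive=False,
                walltime="04:00:00",
            ),
        ),
    ]
)

# Apps opt into an executor by its label.
@python_app(executors=["sharded_reproject"])
def reproject_task(x):
    return x

parsl.load(config)  # on a Slurm cluster this provisions up to max_blocks nodes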
21 changes: 5 additions & 16 deletions src/kbmod_wf/task_impls/ic_to_wu.py
@@ -7,19 +7,6 @@
 from logging import Logger


-def placeholder_ic_to_wu(ic_file=None, wu_file=None, logger=None):
-    logger.info("In the ic_to_wu task_impl")
-    with open(ic_file, "r") as f:
-        for line in f:
-            value = line.strip()
-            logger.info(line.strip())
-
-    with open(wu_file, "w") as f:
-        f.write(f"Logged: {value} - {time.time()}\n")
-
-    return wu_file
-
-
 def ic_to_wu(
     ic_filepath: str = None, wu_filepath: str = None, runtime_config: dict = {}, logger: Logger = None
 ):
@@ -78,13 +65,15 @@ def create_work_unit(self):
         self.logger.info(f"ImageCollection read from {self.ic_filepath}, creating work unit next.")

         last_time = time.time()
-        orig_wu = ic.toWorkUnit(config=SearchConfiguration.from_file(self.search_config_filepath))
+        #! This needs the butler.
+        orig_wu = ic.toWorkUnit(search_config=SearchConfiguration.from_file(self.search_config_filepath))
         elapsed = round(time.time() - last_time, 1)
         self.logger.debug(f"Required {elapsed}[s] to create WorkUnit.")

-        self.logger.info(f"Saving original work unit to: {self.wu_filepath}")
+        self.logger.info(f"Saving sharded work unit to: {self.wu_filepath}")
         last_time = time.time()
-        orig_wu.to_fits(self.wu_filepath, overwrite=True)
+        directory_containing_shards, wu_filename = os.path.split(self.wu_filepath)
+        orig_wu.to_sharded_fits(wu_filename, directory_containing_shards, overwrite=True)
         elapsed = round(time.time() - last_time, 1)
         self.logger.debug(f"Required {elapsed}[s] to write WorkUnit to disk: {self.wu_filepath}")

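The sharded write above pairs with the sharded reads introduced in kbmod_search.py and reproject_wu.py below. A minimal sketch of the round trip, using only the to_sharded_fits/from_sharded_fits signatures visible in this diff (the roundtrip helper itself is hypothetical):

import os

from kbmod.work_unit import WorkUnit


def roundtrip(wu: WorkUnit, wu_filepath: str) -> WorkUnit:
    # Shards are written into the directory containing wu_filepath,
    # keyed by the WorkUnit's filename.
    directory_containing_shards, wu_filename = os.path.split(wu_filepath)
    wu.to_sharded_fits(wu_filename, directory_containing_shards, overwrite=True)

    # lazy=False loads the full image stack up front (as in kbmod_search.py);
    # lazy=True defers image loading (as in reproject_wu.py).
    return WorkUnit.from_sharded_fits(wu_filename, directory_containing_shards, lazy=False)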
26 changes: 5 additions & 21 deletions src/kbmod_wf/task_impls/kbmod_search.py
@@ -2,26 +2,9 @@
 from kbmod.work_unit import WorkUnit

 import os
-import time
-import traceback
 from logging import Logger


-def placeholder_kbmod_search(input_wu=None, result_file=None, logger=None):
-    logger.info("In the kbmod_search task_impl")
-    with open(input_wu, "r") as f:
-        for line in f:
-            value = line.strip()
-            logger.info(line.strip())
-
-    time.sleep(5)
-
-    with open(result_file, "w") as f:
-        f.write(f"Logged: {value} - {time.time()}\n")
-
-    return result_file
-
-
 def kbmod_search(
     wu_filepath: str = None,
     result_filepath: str = None,
@@ -72,9 +55,12 @@ def __init__(

     def run_search(self):
         self.logger.info("Loading workunit from file")
-        wu = WorkUnit.from_fits(self.input_wu_filepath)
+        directory_containing_shards, wu_filename = os.path.split(self.input_wu_filepath)
+        wu = WorkUnit.from_sharded_fits(wu_filename, directory_containing_shards, lazy=False)

         self.logger.debug("Loaded work unit")

+        #! Seems odd that we extract, modify, and reset the config in the workunit.
+        #! Can we just modify the config in the workunit directly?
         if self.search_config_filepath is not None:
             # Load a search configuration, otherwise use the one loaded with the work unit
             wu.config = kbmod.configuration.SearchConfiguration.from_file(self.search_config_filepath)
@@ -90,8 +76,6 @@ def run_search(self):
         }
         config.set_multiple(input_parameters)

-        # Save the search config in the results directory for record keeping
-        config.to_file(os.path.join(self.results_directory, "search_config.yaml"))
         wu.config = config

         self.logger.info("Running KBMOD search")
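The #! note above questions the load-override-reassign dance around the search configuration. A sketch of that pattern as it stands, assuming a loaded WorkUnit wu and a hypothetical overrides dict; SearchConfiguration.from_file and set_multiple are the calls used in the hunk:

import kbmod.configuration


def apply_search_config(wu, search_config_filepath=None, overrides=None):
    if search_config_filepath is not None:
        # Replace the config that was bundled with the work unit.
        wu.config = kbmod.configuration.SearchConfiguration.from_file(search_config_filepath)

    # Extract, modify, and reset -- the step the #! comment suggests
    # could become a direct in-place update.
    config = wu.config
    config.set_multiple(overrides or {})
    wu.config = config
    return wu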
85 changes: 31 additions & 54 deletions src/kbmod_wf/task_impls/reproject_wu.py
@@ -5,26 +5,13 @@
 from astropy.wcs import WCS
 from astropy.io import fits
 from astropy.coordinates import EarthLocation
-import astropy.time
+from astropy.time import Time
 import numpy as np
 import os
 import time
 from logging import Logger


-def placeholder_reproject_wu(input_wu=None, reprojected_wu=None, logger=None):
-    logger.info("In the reproject_wu task_impl")
-    with open(input_wu, "r") as f:
-        for line in f:
-            value = line.strip()
-            logger.info(line.strip())
-
-    with open(reprojected_wu, "w") as f:
-        f.write(f"Logged: {value} - {time.time()}\n")
-
-    return reprojected_wu
-
-
 def reproject_wu(
     original_wu_filepath: str = None,
     uri_filepath: str = None,
@@ -104,10 +91,7 @@ def __init__(
                 f"When patch pixel dimensions are not specifified, the user must supply a pixel scale via the command line or the uri file."
             )

-        self.image_width, self.image_height = self._patch_arcmin_to_pixels(
-            patch_size_arcmin=self.patch_size,
-            pixel_scale_arcsec_per_pix=self.pixel_scale,
-        )
+        self.image_width, self.image_height = self._patch_arcmin_to_pixels()

         self.point_on_earth = EarthLocation.of_site(self.runtime_config.get("observation_site", "ctio"))
@@ -122,65 +106,58 @@ def reproject_workunit(self):
         )

         last_time = time.time()
-        self.logger.info(f"Reading existing WorkUnit from disk: {self.original_wu_filepath}")
-        orig_wu = WorkUnit.from_fits(self.original_wu_filepath)
+        self.logger.info(f"Lazy reading existing WorkUnit from disk: {self.original_wu_filepath}")
+        directory_containing_shards, wu_filename = os.path.split(self.original_wu_filepath)
+        wu = WorkUnit.from_sharded_fits(wu_filename, directory_containing_shards, lazy=True)
         elapsed = round(time.time() - last_time, 1)
-        self.logger.debug(f"Required {elapsed}[s] to read original WorkUnit {self.original_wu_filepath}.")
+        self.logger.debug(
+            f"Required {elapsed}[s] to lazy read original WorkUnit {self.original_wu_filepath}."
+        )

-        # gather elements needed for reproject phase
-        imgs = orig_wu.im_stack
+        #! This method to get image dimensions won't hold if the images are different sizes.
+        image_height, image_width = wu._per_image_wcs[0].array_shape

         # Find the EBD (estimated barycentric distance) WCS for each image
         last_time = time.time()
         ebd_per_image_wcs, geocentric_dists = transform_wcses_to_ebd(
-            orig_wu._per_image_wcs,
-            imgs.get_single_image(0).get_width(),
-            imgs.get_single_image(0).get_height(),
+            wu._per_image_wcs,
+            image_width,
+            image_height,
             self.guess_dist,
-            [astropy.time.Time(img.get_obstime(), format="mjd") for img in imgs.get_images()],
+            Time(wu.get_all_obstimes(), format="mjd"),
             self.point_on_earth,
             npoints=10,
             seed=None,
         )
         elapsed = round(time.time() - last_time, 1)
         self.logger.debug(f"Required {elapsed}[s] to transform WCS objects to EBD..")

-        if len(orig_wu._per_image_wcs) != len(ebd_per_image_wcs):
+        if len(wu._per_image_wcs) != len(ebd_per_image_wcs):
             raise ValueError(
-                f"Number of barycentric WCS objects ({len(ebd_per_image_wcs)}) does not match the original number of images ({len(orig_wu._per_image_wcs)})."
+                f"Number of barycentric WCS objects ({len(ebd_per_image_wcs)}) does not match the original number of images ({len(wu._per_image_wcs)})."
             )

-        # Construct a WorkUnit with the EBD WCS and provenance data
-        self.logger.debug(f"Creating Barycentric WorkUnit...")
-        last_time = time.time()
-        ebd_wu = WorkUnit(
-            im_stack=orig_wu.im_stack,
-            config=orig_wu.config,
-            per_image_wcs=orig_wu._per_image_wcs,
-            per_image_ebd_wcs=ebd_per_image_wcs,
-            heliocentric_distance=self.guess_dist,
-            geocentric_distances=geocentric_dists,
-        )
-        elapsed = round(time.time() - last_time, 1)
-        self.logger.debug(f"Required {elapsed}[s] to create EBD WorkUnit.")
+        wu._per_image_ebd_wcs = ebd_per_image_wcs
+        wu.heliocentric_distance = self.guess_dist
+        wu.geocentric_distances = geocentric_dists

         # Reproject to a common WCS using the WCS for our patch
         self.logger.debug(f"Reprojecting WorkUnit with {self.n_workers} workers...")
         last_time = time.time()
-        reprojected_wu = reprojection.reproject_work_unit(
-            ebd_wu, patch_wcs, frame="ebd", max_parallel_processes=self.n_workers
-        )
-        elapsed = round(time.time() - last_time, 1)
-        self.logger.debug(f"Required {elapsed}[s] to create the reprojected WorkUnit.")
-
-        # Save the reprojected WorkUnit
-        self.logger.debug(f"Saving reprojected work unit to: {self.reprojected_wu_filepath}")
-        last_time = time.time()
-        reprojected_wu.to_fits(self.reprojected_wu_filepath)
-        elapsed = round(time.time() - last_time, 1)
-        self.logger.debug(
-            f"Required {elapsed}[s] to create the reprojected WorkUnit: {self.reprojected_wu_filepath}"
-        )
+        directory_containing_reprojected_shards, reprojected_wu_filename = os.path.split(
+            self.reprojected_wu_filepath
+        )
+        reprojection.reproject_lazy_work_unit(
+            wu,
+            patch_wcs,
+            directory_containing_reprojected_shards,
+            reprojected_wu_filename,
+            frame="ebd",
+            max_parallel_processes=self.n_workers,
+        )
+        elapsed = round(time.time() - last_time, 1)
+        self.logger.debug(f"Required {elapsed}[s] to create the sharded reprojected WorkUnit.")

         return self.reprojected_wu_filepath

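Taken together, the new reproject path lazily loads shards, attaches EBD (estimated barycentric distance) WCSs in place, and writes reprojected shards without materializing the full image stack. A condensed sketch of that flow; the import paths for transform_wcses_to_ebd and the reprojection module are assumptions (they come from this file's header, which is not shown in the diff), and reproject_sharded is a hypothetical wrapper:

import os

from astropy.time import Time
from kbmod import reprojection  # assumed import path
from kbmod.reprojection_utils import transform_wcses_to_ebd  # assumed import path
from kbmod.work_unit import WorkUnit


def reproject_sharded(wu_filepath, out_filepath, patch_wcs, guess_dist, point_on_earth, n_workers=8):
    shard_dir, wu_filename = os.path.split(wu_filepath)
    wu = WorkUnit.from_sharded_fits(wu_filename, shard_dir, lazy=True)

    # Dimensions from the first WCS; assumes all images share one size.
    image_height, image_width = wu._per_image_wcs[0].array_shape

    ebd_wcs, geo_dists = transform_wcses_to_ebd(
        wu._per_image_wcs,
        image_width,
        image_height,
        guess_dist,
        Time(wu.get_all_obstimes(), format="mjd"),
        point_on_earth,
        npoints=10,
        seed=None,
    )

    # Annotate the lazy WorkUnit in place instead of constructing a new one.
    wu._per_image_ebd_wcs = ebd_wcs
    wu.heliocentric_distance = guess_dist
    wu.geocentric_distances = geo_dists

    out_dir, out_filename = os.path.split(out_filepath)
    reprojection.reproject_lazy_work_unit(
        wu, patch_wcs, out_dir, out_filename, frame="ebd", max_parallel_processes=n_workers
    )
    return out_filepath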
14 changes: 0 additions & 14 deletions src/kbmod_wf/task_impls/uri_to_ic.py
@@ -5,20 +5,6 @@
 from kbmod import ImageCollection


-def placeholder_uri_to_ic(
-    target_uris_file_path=None, uris_base_dir=None, ic_output_file_path=None, logger=None
-):
-    with open(target_uris_file_path, "r") as f:
-        for line in f:
-            value = line.strip()
-            logger.info(line.strip())
-
-    with open(ic_output_file_path, "w") as f:
-        f.write(f"Logged: {value} - {time.time()}\n")
-
-    return ic_output_file_path


 #! I believe that we can remove the `uris_base_dir` parameter from the function
 #! signature. It doesn't seem to be used in practice.
 def uri_to_ic(
def uri_to_ic(
Expand Down
62 changes: 41 additions & 21 deletions src/kbmod_wf/workflow.py
@@ -54,19 +54,25 @@ def create_uri_manifest(inputs=[], outputs=[], runtime_config={}, logging_file=None):
     cache=True, executors=get_executors(["local_dev_testing", "small_cpu"]), ignore_for_cache=["logging_file"]
 )
 def uri_to_ic(inputs=[], outputs=[], runtime_config={}, logging_file=None):
+    import traceback
     from kbmod_wf.utilities.logger_utilities import configure_logger
     from kbmod_wf.task_impls.uri_to_ic import uri_to_ic

     logger = configure_logger("task.uri_to_ic", logging_file.filepath)

     logger.info("Starting uri_to_ic")
-    uri_to_ic(
-        uris_filepath=inputs[0].filepath,
-        uris_base_dir=None,  # determine what, if any, value should be used.
-        ic_filepath=outputs[0].filepath,
-        runtime_config=runtime_config,
-        logger=logger,
-    )
+    try:
+        uri_to_ic(
+            uris_filepath=inputs[0].filepath,
+            uris_base_dir=None,  # determine what, if any, value should be used.
+            ic_filepath=outputs[0].filepath,
+            runtime_config=runtime_config,
+            logger=logger,
+        )
+    except Exception as e:
+        logger.error(f"Error running uri_to_ic: {e}")
+        logger.error(traceback.format_exc())
+        raise e
     logger.warning("Completed uri_to_ic")

     return outputs[0]
@@ -76,40 +82,54 @@ def uri_to_ic(inputs=[], outputs=[], runtime_config={}, logging_file=None):
     cache=True, executors=get_executors(["local_dev_testing", "large_mem"]), ignore_for_cache=["logging_file"]
 )
 def ic_to_wu(inputs=[], outputs=[], runtime_config={}, logging_file=None):
+    import traceback
     from kbmod_wf.utilities.logger_utilities import configure_logger
     from kbmod_wf.task_impls.ic_to_wu import ic_to_wu

     logger = configure_logger("task.ic_to_wu", logging_file.filepath)

     logger.info("Starting ic_to_wu")
-    ic_to_wu(
-        ic_filepath=inputs[0].filepath,
-        wu_filepath=outputs[0].filepath,
-        runtime_config=runtime_config,
-        logger=logger,
-    )
+    try:
+        ic_to_wu(
+            ic_filepath=inputs[0].filepath,
+            wu_filepath=outputs[0].filepath,
+            runtime_config=runtime_config,
+            logger=logger,
+        )
+    except Exception as e:
+        logger.error(f"Error running ic_to_wu: {e}")
+        logger.error(traceback.format_exc())
+        raise e
     logger.warning("Completed ic_to_wu")

     return outputs[0]


 @python_app(
-    cache=True, executors=get_executors(["local_dev_testing", "large_mem"]), ignore_for_cache=["logging_file"]
+    cache=True,
+    executors=get_executors(["local_dev_testing", "sharded_reproject"]),
+    ignore_for_cache=["logging_file"],
 )
 def reproject_wu(inputs=[], outputs=[], runtime_config={}, logging_file=None):
+    import traceback
     from kbmod_wf.utilities.logger_utilities import configure_logger
     from kbmod_wf.task_impls.reproject_wu import reproject_wu

     logger = configure_logger("task.reproject_wu", logging_file.filepath)

     logger.info("Starting reproject_ic")
-    reproject_wu(
-        original_wu_filepath=inputs[0].filepath,
-        uri_filepath=inputs[1].filepath,
-        reprojected_wu_filepath=outputs[0].filepath,
-        runtime_config=runtime_config,
-        logger=logger,
-    )
+    try:
+        reproject_wu(
+            original_wu_filepath=inputs[0].filepath,
+            uri_filepath=inputs[1].filepath,
+            reprojected_wu_filepath=outputs[0].filepath,
+            runtime_config=runtime_config,
+            logger=logger,
+        )
+    except Exception as e:
+        logger.error(f"Error running reproject_ic: {e}")
+        logger.error(traceback.format_exc())
+        raise e
     logger.warning("Completed reproject_ic")

     return outputs[0]
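The try/except/log/re-raise scaffolding is now repeated verbatim in each app. A sketch of how it could be hoisted into a shared helper; run_logged is hypothetical (the PR keeps the blocks inline, which also keeps each Parsl app self-contained):

import traceback


def run_logged(task_fn, task_name, log, **kwargs):
    # Log entry, failure (with traceback), and completion around a task impl.
    log.info(f"Starting {task_name}")
    try:
        result = task_fn(**kwargs)
    except Exception:
        log.error(f"Error running {task_name}")
        log.error(traceback.format_exc())
        raise
    log.warning(f"Completed {task_name}")
    return result

Each app body would then reduce to a single call, e.g. run_logged(uri_to_ic, "uri_to_ic", log=logger, uris_filepath=inputs[0].filepath, ic_filepath=outputs[0].filepath, runtime_config=runtime_config, logger=logger).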
