Persistent runner
whimo committed Dec 26, 2024
1 parent 145e6cd commit 4b54349
Showing 1 changed file with 97 additions and 63 deletions.
notebooks/RunExperiments/runners/experiment_runner.py
@@ -4,6 +4,7 @@
import os
import pickle
import sys
import time
import warnings
from typing import List, Union
import cloudpickle
@@ -35,6 +36,8 @@
# Configure warnings
warnings.filterwarnings("ignore")

RAY_NAMESPACE = "causaltune_experiments"


def parse_arguments():
parser = argparse.ArgumentParser(description="Run CausalTune experiments")
@@ -92,7 +95,7 @@ def get_estimator_list(dataset_name):
return [est for est in estimator_list if "Dummy" not in est]


def run_experiment(args, dataset_path: str, use_ray: bool, ray_get_retries: int = 3):
def run_experiment(args, dataset_path: str, use_ray: bool):
# Process datasets
data_sets = {}
for dataset in args.datasets:
@@ -114,32 +117,68 @@ def run_experiment(args, dataset_path: str, use_ray: bool, ray_get_retries: int

print(f"Loaded datasets: {list(data_sets.keys())}")

tasks = []
out = []
i_run = 1
already_running = False
if use_ray:
try:
runner = ray.get_actor("TaskRunner")
print("\n" * 4)
print(
"!!! Found an existing detached TaskRunner. Will assume the tasks have already been submitted."
)
print(
"!!! If you want to re-run the experiments from scratch, "
'run ray.kill(ray.get_actor("TaskRunner", namespace="{}")) or recreate the cluster.'.format(
RAY_NAMESPACE
)
)
print("\n" * 4)
already_running = True
except ValueError:
print("Ray: no detached TaskRunner found, creating...")
# This thing will be alive even if the host program exits
# Must be killed explicitly: ray.kill(ray.get_actor("TaskRunner"))
runner = TaskRunner.options(name="TaskRunner", lifetime="detached").remote()

for dataset_name, cd in data_sets.items():
estimators = get_estimator_list(dataset_name)
# Extract case while preserving original string checking logic
if "KCKP" in dataset_name:
case = "KCKP"
elif "KC" in dataset_name:
case = "KC"
elif "IV" in dataset_name:
case = "IV"
else:
case = "RCT"

os.makedirs(f"{out_dir}/{case}", exist_ok=True)
for metric in args.metrics:
fn = make_filename(metric, dataset_name, i_run)
out_fn = os.path.join(out_dir, case, fn)
if os.path.isfile(out_fn):
print(f"File {out_fn} exists, skipping...")
continue
if use_ray:
tasks.append(
remote_single_run.remote(
out = []
if not already_running:
tasks = []
i_run = 1

for dataset_name, cd in data_sets.items():
estimators = get_estimator_list(dataset_name)
# Extract case while preserving original string checking logic
if "KCKP" in dataset_name:
case = "KCKP"
elif "KC" in dataset_name:
case = "KC"
elif "IV" in dataset_name:
case = "IV"
else:
case = "RCT"

os.makedirs(f"{out_dir}/{case}", exist_ok=True)
for metric in args.metrics:
fn = make_filename(metric, dataset_name, i_run)
out_fn = os.path.join(out_dir, case, fn)
if os.path.isfile(out_fn):
print(f"File {out_fn} exists, skipping...")
continue
if use_ray:
tasks.append(
runner.remote_single_run.remote(
dataset_name,
cd,
metric,
args.test_size,
args.num_samples,
args.components_time_budget,
out_dir,
out_fn,
estimators,
)
)
else:
results = single_run(
dataset_name,
cd,
metric,
@@ -148,48 +187,32 @@ def run_experiment(args, dataset_path: str, use_ray: bool, ray_get_retries: int
args.components_time_budget,
out_dir,
out_fn,
estimators,
)
)
else:
results = single_run(
dataset_name,
cd,
metric,
args.test_size,
args.num_samples,
args.components_time_budget,
out_dir,
out_fn,
)
out.append(results)
out.append(results)

if use_ray:
remaining_tasks = tasks
n_fetch_errors = 0
out = []

while remaining_tasks:
print(f"Ray: {len(remaining_tasks)} tasks still running...")
ready_tasks, remaining_tasks = ray.wait(remaining_tasks, num_returns=1, timeout=5)
for ready_task in ready_tasks:
print(f"Ray: task ready: {ready_task}")
for retry in range(ray_get_retries + 1):
try:
result = ray.get(ready_task)
results.append(result)
break
except Exception as e:
print(
f"Ray: error fetching task {ready_task} result (retry {retry} of {ray_get_retries}): {e}"
)
if retry == ray_get_retries:
print("Ray: error: task result could not be fetched")
print(f"Ray: tasks completed with {n_fetch_errors} fetch errors")
while True:
completed, in_progress = ray.get(runner.get_progress.remote())
print(f"Ray: {completed}/{completed + in_progress} tasks completed")
if not in_progress:
print("Ray: all tasks completed!")
break
time.sleep(10)

print("Ray: fetching results...")
out = ray.get(runner.get_results.remote())

for out_fn, results in out:
with open(out_fn, "wb") as f:
pickle.dump(results, f)

if use_ray:
destroy = input("Ray: seems like the results fetched OK. Destroy TaskRunner? ")
if destroy.lower().startswith("y"):
print("Destroying TaskRunner... ", end="")
ray.kill(runner)
print("success!")

return out_dir


@@ -544,6 +567,7 @@ def run_batch(
ray.init(
"ray://localhost:10001",
runtime_env={"working_dir": ".", "pip": ["causaltune", "catboost"]},
namespace=RAY_NAMESPACE,
)

out_dir = run_experiment(args, dataset_path=dataset_path, use_ray=use_ray)
@@ -561,7 +585,7 @@ def __init__(self):
self.futures = {}

def remote_single_run(self, *args):
ref = remote_single_run(*args)
ref = remote_single_run.remote(*args)
self.futures[ref.hex()] = ref
return ref.hex()

Expand All @@ -572,9 +596,19 @@ def get_single_result(self, ref_hex):
return ray.get(self.futures[ref_hex])

def is_ready(self, ref_hex):
ready, _ = ray.wait([self.futures[ref_hex]], timeout=0)
ready, _ = ray.wait([self.futures[ref_hex]], timeout=0, fetch_local=False)
return bool(ready)

def all_tasks_ready(self):
_, in_progress = ray.wait(list(self.futures.values()), timeout=0, fetch_local=False)
return not bool(in_progress)

def get_progress(self):
completed, in_progress = ray.wait(
list(self.futures.values()), num_returns=len(self.futures), timeout=0, fetch_local=False
)
return len(completed), len(in_progress)


def single_run(
dataset_name: str,

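The pattern this commit introduces is a detached Ray actor that owns the task refs, so a long experiment batch survives the driver process: a fresh session reattaches by actor name within the namespace, polls progress, and fetches results, instead of resubmitting everything. Below is a minimal, self-contained sketch of that lifecycle under stated assumptions: `slow_square` and `submit` are illustrative stand-ins (the real runner wraps `remote_single_run` and carries the experiment arguments), while the actor name, namespace, `lifetime="detached"`, and the `get_progress`/`get_results` polling mirror the diff.

```python
import time

import ray

RAY_NAMESPACE = "causaltune_experiments"


@ray.remote
def slow_square(x):
    # Hypothetical stand-in for remote_single_run: any long-running task.
    time.sleep(5)
    return x * x


@ray.remote
class TaskRunner:
    """Holds task refs so a reconnecting driver can recover them."""

    def __init__(self):
        self.futures = {}

    def submit(self, x):
        ref = slow_square.remote(x)
        self.futures[ref.hex()] = ref
        return ref.hex()

    def get_progress(self):
        # fetch_local=False checks completion without pulling results here.
        completed, in_progress = ray.wait(
            list(self.futures.values()),
            num_returns=len(self.futures),
            timeout=0,
            fetch_local=False,
        )
        return len(completed), len(in_progress)

    def get_results(self):
        return ray.get(list(self.futures.values()))


ray.init(namespace=RAY_NAMESPACE)

try:
    # Reattach after a driver restart: tasks were already submitted.
    runner = ray.get_actor("TaskRunner")
except ValueError:
    # No detached runner yet: create one and submit the work exactly once.
    runner = TaskRunner.options(name="TaskRunner", lifetime="detached").remote()
    for i in range(10):
        runner.submit.remote(i)

while True:
    completed, in_progress = ray.get(runner.get_progress.remote())
    print(f"{completed}/{completed + in_progress} tasks completed")
    if not in_progress:
        break
    time.sleep(10)

print(ray.get(runner.get_results.remote()))
# A detached actor outlives this script; clean it up explicitly, as the
# diff's interactive prompt does:
ray.kill(runner)
```

Note the design choice: because a detached actor is never garbage-collected with the driver, the `ray.get_actor` probe doubles as an idempotency check, which is why the new `run_experiment` skips submission entirely when a `TaskRunner` already exists.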