From 4b543492b2779dd2f737220bf6fa2656f078dcd5 Mon Sep 17 00:00:00 2001
From: whimo
Date: Thu, 26 Dec 2024 23:09:54 +0800
Subject: [PATCH] Persistent runner

---
 .../runners/experiment_runner.py              | 160 +++++++++++-------
 1 file changed, 97 insertions(+), 63 deletions(-)

diff --git a/notebooks/RunExperiments/runners/experiment_runner.py b/notebooks/RunExperiments/runners/experiment_runner.py
index 7cdd751..a77af58 100644
--- a/notebooks/RunExperiments/runners/experiment_runner.py
+++ b/notebooks/RunExperiments/runners/experiment_runner.py
@@ -4,6 +4,7 @@
 import os
 import pickle
 import sys
+import time
 import warnings
 from typing import List, Union
 import cloudpickle
@@ -35,6 +36,8 @@
 # Configure warnings
 warnings.filterwarnings("ignore")
 
+RAY_NAMESPACE = "causaltune_experiments"
+
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Run CausalTune experiments")
@@ -92,7 +95,7 @@ def get_estimator_list(dataset_name):
     return [est for est in estimator_list if "Dummy" not in est]
 
 
-def run_experiment(args, dataset_path: str, use_ray: bool, ray_get_retries: int = 3):
+def run_experiment(args, dataset_path: str, use_ray: bool):
     # Process datasets
     data_sets = {}
     for dataset in args.datasets:
@@ -114,32 +117,68 @@
     print(f"Loaded datasets: {list(data_sets.keys())}")
 
-    tasks = []
-    out = []
-    i_run = 1
+    already_running = False
+    if use_ray:
+        try:
+            runner = ray.get_actor("TaskRunner")
+            print("\n" * 4)
+            print(
+                "!!! Found an existing detached TaskRunner. Will assume the tasks have already been submitted."
+            )
+            print(
+                "!!! If you want to re-run the experiments from scratch, "
+                'run ray.kill(ray.get_actor("TaskRunner", namespace="{}")) or recreate the cluster.'.format(
+                    RAY_NAMESPACE
+                )
+            )
+            print("\n" * 4)
+            already_running = True
+        except ValueError:
+            print("Ray: no detached TaskRunner found, creating...")
+            # This thing will be alive even if the host program exits
+            # Must be killed explicitly: ray.kill(ray.get_actor("TaskRunner"))
+            runner = TaskRunner.options(name="TaskRunner", lifetime="detached").remote()
 
-    for dataset_name, cd in data_sets.items():
-        estimators = get_estimator_list(dataset_name)
-        # Extract case while preserving original string checking logic
-        if "KCKP" in dataset_name:
-            case = "KCKP"
-        elif "KC" in dataset_name:
-            case = "KC"
-        elif "IV" in dataset_name:
-            case = "IV"
-        else:
-            case = "RCT"
-
-        os.makedirs(f"{out_dir}/{case}", exist_ok=True)
-        for metric in args.metrics:
-            fn = make_filename(metric, dataset_name, i_run)
-            out_fn = os.path.join(out_dir, case, fn)
-            if os.path.isfile(out_fn):
-                print(f"File {out_fn} exists, skipping...")
-                continue
-            if use_ray:
-                tasks.append(
-                    remote_single_run.remote(
+    out = []
+    if not already_running:
+        tasks = []
+        i_run = 1
+
+        for dataset_name, cd in data_sets.items():
+            estimators = get_estimator_list(dataset_name)
+            # Extract case while preserving original string checking logic
+            if "KCKP" in dataset_name:
+                case = "KCKP"
+            elif "KC" in dataset_name:
+                case = "KC"
+            elif "IV" in dataset_name:
+                case = "IV"
+            else:
+                case = "RCT"
+
+            os.makedirs(f"{out_dir}/{case}", exist_ok=True)
+            for metric in args.metrics:
+                fn = make_filename(metric, dataset_name, i_run)
+                out_fn = os.path.join(out_dir, case, fn)
+                if os.path.isfile(out_fn):
+                    print(f"File {out_fn} exists, skipping...")
+                    continue
+                if use_ray:
+                    tasks.append(
+                        runner.remote_single_run.remote(
+                            dataset_name,
+                            cd,
+                            metric,
+                            args.test_size,
+                            args.num_samples,
+                            args.components_time_budget,
+                            out_dir,
+                            out_fn,
+                            estimators,
+                        )
+                    )
+                else:
+                    results = single_run(
                         dataset_name,
                         cd,
                         metric,
@@ -148,48 +187,32 @@ def run_experiment(args, dataset_path: str, use_ray: bool, ray_get_retries: int
                         args.components_time_budget,
                         out_dir,
                         out_fn,
-                        estimators,
                     )
-                )
-            else:
-                results = single_run(
-                    dataset_name,
-                    cd,
-                    metric,
-                    args.test_size,
-                    args.num_samples,
-                    args.components_time_budget,
-                    out_dir,
-                    out_fn,
-                )
-                out.append(results)
+                    out.append(results)
+
     if use_ray:
-        remaining_tasks = tasks
-        n_fetch_errors = 0
-        out = []
-
-        while remaining_tasks:
-            print(f"Ray: {len(remaining_tasks)} tasks still running...")
-            ready_tasks, remaining_tasks = ray.wait(remaining_tasks, num_returns=1, timeout=5)
-            for ready_task in ready_tasks:
-                print(f"Ray: task ready: {ready_task}")
-                for retry in range(ray_get_retries + 1):
-                    try:
-                        result = ray.get(ready_task)
-                        results.append(result)
-                        break
-                    except Exception as e:
-                        print(
-                            f"Ray: error fetching task {ready_task} result (retry {retry} of {ray_get_retries}): {e}"
-                        )
-                        if retry == ray_get_retries:
-                            print("Ray: error: task result could not be fetched")
-        print(f"Ray: tasks completed with {n_fetch_errors} fetch errors")
+        while True:
+            completed, in_progress = ray.get(runner.get_progress.remote())
+            print(f"Ray: {completed}/{completed + in_progress} tasks completed")
+            if not in_progress:
+                print("Ray: all tasks completed!")
+                break
+            time.sleep(10)
+
+        print("Ray: fetching results...")
+        out = ray.get(runner.get_results.remote())
 
     for out_fn, results in out:
        with open(out_fn, "wb") as f:
            pickle.dump(results, f)
 
+    if use_ray:
+        destroy = input("Ray: seems like the results fetched OK. Destroy TaskRunner? ")
+        if destroy.lower().startswith("y"):
+            print("Destroying TaskRunner... ", end="")
+            ray.kill(runner)
+            print("success!")
+
     return out_dir
@@ -544,6 +567,7 @@ def run_batch(
         ray.init(
             "ray://localhost:10001",
             runtime_env={"working_dir": ".", "pip": ["causaltune", "catboost"]},
+            namespace=RAY_NAMESPACE,
         )
 
     out_dir = run_experiment(args, dataset_path=dataset_path, use_ray=use_ray)
@@ -561,7 +585,7 @@ def __init__(self):
         self.futures = {}
 
     def remote_single_run(self, *args):
-        ref = remote_single_run(*args)
+        ref = remote_single_run.remote(*args)
         self.futures[ref.hex()] = ref
         return ref.hex()
 
@@ -572,9 +596,19 @@ def get_single_result(self, ref_hex):
         return ray.get(self.futures[ref_hex])
 
     def is_ready(self, ref_hex):
-        ready, _ = ray.wait([self.futures[ref_hex]], timeout=0)
+        ready, _ = ray.wait([self.futures[ref_hex]], timeout=0, fetch_local=False)
         return bool(ready)
 
+    def all_tasks_ready(self):
+        _, in_progress = ray.wait(list(self.futures.values()), timeout=0, fetch_local=False)
+        return not bool(in_progress)
+
+    def get_progress(self):
+        completed, in_progress = ray.wait(
+            list(self.futures.values()), num_returns=len(self.futures), timeout=0, fetch_local=False
+        )
+        return len(completed), len(in_progress)
+
 
 def single_run(
     dataset_name: str,
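
Note (not part of the patch): a minimal sketch of the reconnect-and-poll workflow this change enables, showing how a fresh driver process can pick up the detached TaskRunner by name after the original driver has exited. It reuses the cluster address, namespace, and actor name from the diff above; TaskRunner.get_results() is assumed to exist elsewhere in the file, since run_experiment() calls it but its definition is not shown in this diff.

import time

import ray

RAY_NAMESPACE = "causaltune_experiments"

# Connect in the same namespace run_batch() uses; named detached actors
# are only visible within their namespace.
ray.init("ray://localhost:10001", namespace=RAY_NAMESPACE)

try:
    # Detached actors outlive the driver that created them, so a new
    # driver can look the runner up by name and resume monitoring.
    runner = ray.get_actor("TaskRunner")
except ValueError:
    raise SystemExit("No detached TaskRunner found; submit the experiments first.")

# Poll progress; get_progress() uses ray.wait(..., fetch_local=False) inside
# the actor, so no task outputs are shipped to this driver while polling.
while True:
    completed, in_progress = ray.get(runner.get_progress.remote())
    print(f"Ray: {completed}/{completed + in_progress} tasks completed")
    if not in_progress:
        break
    time.sleep(10)

out = ray.get(runner.get_results.remote())  # assumed method, as called in run_experiment()

# The actor pins the submitted ObjectRefs in its futures dict, so kill it
# only after the results have been fetched.
ray.kill(runner)

Because destroying the TaskRunner releases the ObjectRefs it holds, the patch sensibly offers to kill the actor only after the pickle files have been written.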