From 5627d9a18533b2ec5d9c05e0b2be727d4eaf5a45 Mon Sep 17 00:00:00 2001 From: "Egor.Kraev" Date: Fri, 6 Dec 2024 10:33:02 +0000 Subject: [PATCH] fix ERUPT std --- causaltune/score/erupt_core.py | 4 +- notebooks/ERUPT basics.ipynb | 130 +++++++++--------- notebooks/RunExperiments/experiment_runner.py | 18 +-- 3 files changed, 77 insertions(+), 75 deletions(-) diff --git a/causaltune/score/erupt_core.py b/causaltune/score/erupt_core.py index ae79260..6a69a8c 100644 --- a/causaltune/score/erupt_core.py +++ b/causaltune/score/erupt_core.py @@ -50,8 +50,8 @@ def erupt_with_std( ] mean += np.mean(means) std += np.std(means) / np.sqrt(num_splits) # Standard error of the mean - - return mean / resamples, std / resamples + # 1.5 is an empirical factor to make the confidence interval wider + return mean / resamples, 1.5 * std / resamples def erupt( diff --git a/notebooks/ERUPT basics.ipynb b/notebooks/ERUPT basics.ipynb index c4bdb54..94fcb30 100644 --- a/notebooks/ERUPT basics.ipynb +++ b/notebooks/ERUPT basics.ipynb @@ -103,33 +103,33 @@ " \n", " \n", " 0\n", - " 0.452636\n", - " 0\n", - " 1.684484\n", + " 0.898227\n", + " 1\n", + " 1.288637\n", " \n", " \n", " 1\n", - " 0.380215\n", + " 0.462092\n", " 0\n", - " 0.745268\n", + " 0.771976\n", " \n", " \n", " 2\n", - " 0.584036\n", - " 1\n", - " 0.762300\n", + " 0.858974\n", + " 0\n", + " 1.881019\n", " \n", " \n", " 3\n", - " 0.505191\n", - " 0\n", - " 1.425354\n", + " 0.228084\n", + " 1\n", + " 0.357797\n", " \n", " \n", " 4\n", - " 0.384110\n", + " 0.962512\n", " 1\n", - " 1.834628\n", + " 1.066413\n", " \n", " \n", "\n", @@ -137,11 +137,11 @@ ], "text/plain": [ " X T1 Y1\n", - "0 0.452636 0 1.684484\n", - "1 0.380215 0 0.745268\n", - "2 0.584036 1 0.762300\n", - "3 0.505191 0 1.425354\n", - "4 0.384110 1 1.834628" + "0 0.898227 1 1.288637\n", + "1 0.462092 0 0.771976\n", + "2 0.858974 0 1.881019\n", + "3 0.228084 1 0.357797\n", + "4 0.962512 1 1.066413" ] }, "execution_count": 2, @@ -216,53 +216,53 @@ " \n", " \n", " 0\n", - " 0.452636\n", - " 0\n", - " 1.684484\n", - " 0.726318\n", - " 0\n", - " 0.273682\n", - " 0.904259\n", + " 0.898227\n", + " 1\n", + " 1.288637\n", + " 0.949114\n", + " 1\n", + " 0.949114\n", + " 2.229118\n", " \n", " \n", " 1\n", - " 0.380215\n", + " 0.462092\n", " 0\n", - " 0.745268\n", - " 0.690108\n", - " 1\n", - " 0.690108\n", - " 1.930383\n", + " 0.771976\n", + " 0.731046\n", + " 0\n", + " 0.268954\n", + " 0.572308\n", " \n", " \n", " 2\n", - " 0.584036\n", - " 1\n", - " 0.762300\n", - " 0.792018\n", + " 0.858974\n", + " 0\n", + " 1.881019\n", + " 0.929487\n", " 1\n", - " 0.792018\n", - " 0.959608\n", + " 0.929487\n", + " 2.601592\n", " \n", " \n", " 3\n", - " 0.505191\n", - " 0\n", - " 1.425354\n", - " 0.752596\n", + " 0.228084\n", + " 1\n", + " 0.357797\n", + " 0.614042\n", " 1\n", - " 0.752596\n", - " 1.017777\n", + " 0.614042\n", + " 0.542638\n", " \n", " \n", " 4\n", - " 0.384110\n", + " 0.962512\n", " 1\n", - " 1.834628\n", - " 0.692055\n", + " 1.066413\n", + " 0.981256\n", " 1\n", - " 0.692055\n", - " 2.374030\n", + " 0.981256\n", + " 2.401383\n", " \n", " \n", "\n", @@ -270,11 +270,11 @@ ], "text/plain": [ " X T1 Y1 p T2 p_of_actual Y2\n", - "0 0.452636 0 1.684484 0.726318 0 0.273682 0.904259\n", - "1 0.380215 0 0.745268 0.690108 1 0.690108 1.930383\n", - "2 0.584036 1 0.762300 0.792018 1 0.792018 0.959608\n", - "3 0.505191 0 1.425354 0.752596 1 0.752596 1.017777\n", - "4 0.384110 1 1.834628 0.692055 1 0.692055 2.374030" + "0 0.898227 1 1.288637 0.949114 1 0.949114 2.229118\n", + "1 0.462092 0 0.771976 0.731046 0 0.268954 0.572308\n", + "2 0.858974 0 1.881019 0.929487 1 0.929487 2.601592\n", + "3 0.228084 1 0.357797 0.614042 1 0.614042 0.542638\n", + "4 0.962512 1 1.066413 0.981256 1 0.981256 2.401383" ] }, "execution_count": 3, @@ -319,10 +319,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Average outcome of the actual biased assignment experiment: 1.411675477573636\n", - "Estimated outcome of random assignment: 1.251567372523789\n", - "95% confidence interval for estimated outcome: 1.2311928820519622 1.2719418629956158\n", - "Average outcome of the actual random assignment experiment: 1.2559621877416332\n" + "Average outcome of the actual biased assignment experiment: 1.4064676444383317\n", + "Estimated outcome of random assignment: 1.2594221770638483\n", + "95% confidence interval for estimated outcome: 1.230204391668238 1.2886399624594587\n", + "Average outcome of the actual random assignment experiment: 1.2461659092712785\n" ] } ], @@ -360,17 +360,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Average outcome of the actual random assignment experiment: 1.2559621877416332\n", - "Estimated outcome of biased assignment: 1.4147647990746988\n", - "Confidence interval for estimated outcome: 1.398423601541284 1.4311059966081134\n", - "Average outcome of the actual biased assignment experiment: 1.411675477573636\n" + "Average outcome of the actual random assignment experiment: 1.2461659092712785\n", + "Estimated outcome of biased assignment: 1.405112521603215\n", + "95% confidence interval for estimated outcome: 1.3814865905561569 1.428738452650273\n", + "Average outcome of the actual biased assignment experiment: 1.4064676444383317\n" ] } ], "source": [ - "# Conversely, we can take the outcome of the fully random test and use it to estimate what the outcome of the biased assignment would have been\n", + "# Conversely, we can take the outcome of the fully random test and use it \n", + "# to estimate what the outcome of the biased assignment would have been\n", "\n", - "# Let's use data from biased assignment experiment to estimate the average effect of fully random assignment\n", "hypothetical_policy = df[\"T2\"]\n", "est, std = erupt_with_std(actual_propensity=0.5*pd.Series(np.ones(len(df))), \n", " actual_treatment=df[\"T1\"],\n", @@ -379,7 +379,7 @@ "\n", "print(\"Average outcome of the actual random assignment experiment:\", df[\"Y1\"].mean())\n", "print(\"Estimated outcome of biased assignment:\", est)\n", - "print(\"Confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n", + "print(\"95% confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n", "print(\"Average outcome of the actual biased assignment experiment:\", df[\"Y2\"].mean())" ] }, @@ -388,7 +388,7 @@ "id": "f724dbc3", "metadata": {}, "source": [ - "As you can see, the actual outcome is well within the confidence interval estimated by ERUPT" + "As you can see, the actual outcome is within the confidence interval estimated by ERUPT" ] }, { diff --git a/notebooks/RunExperiments/experiment_runner.py b/notebooks/RunExperiments/experiment_runner.py index b8eec36..033b86a 100644 --- a/notebooks/RunExperiments/experiment_runner.py +++ b/notebooks/RunExperiments/experiment_runner.py @@ -88,7 +88,6 @@ def get_estimator_list(dataset_name): return [est for est in estimator_list if "Dummy" not in est] - def run_experiment(args): # Process datasets data_sets = {} @@ -281,10 +280,12 @@ def get_all_test_scores(out_dir, dataset_name): return out -def generate_plots(out_dir: str, - log_scale: List[str]|None = None, - upper_bounds: dict|None = None, - lower_bounds: dict|None=None): +def generate_plots( + out_dir: str, + log_scale: List[str] | None = None, + upper_bounds: dict | None = None, + lower_bounds: dict | None = None, +): if log_scale is None: log_scale = ["energy_distance", "psw_energy_distance", "frobenius_norm"] metrics, datasets = extract_metrics_datasets(out_dir) @@ -404,14 +405,13 @@ def plot_mse_grid(title): legend_elements = [] for j, dataset in enumerate(datasets): df = get_all_test_scores(out_dir, dataset) - for m, value in upper_bounds.items(): + for m, value in upper_bounds.items(): if m in df.columns: df = df[df[m] < value].copy() for m, value in lower_bounds.items(): if m in df.columns: df = df[df[m] > value].copy() - for i, metric in enumerate(all_metrics): ax = axs[i, j] this_df = df[["estimator_name", metric, "MSE"]].dropna() @@ -502,4 +502,6 @@ def plot_mse_grid(title): out_dir = run_experiment(args) # upper_bounds = {"MSE": 1e2, "policy_risk": 0.2} # lower_bounds = {"erupt": 0.06, "bite": 0.75} - generate_plots(os.path.join(out_dir, "RCT"))#, upper_bounds=upper_bounds, lower_bounds=lower_bounds) + generate_plots( + os.path.join(out_dir, "RCT") + ) # , upper_bounds=upper_bounds, lower_bounds=lower_bounds)