Skip to content

Commit

Permalink
fix ERUPT std
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed Dec 6, 2024
1 parent e51b58c commit 5627d9a
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 75 deletions.
4 changes: 2 additions & 2 deletions causaltune/score/erupt_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def erupt_with_std(
]
mean += np.mean(means)
std += np.std(means) / np.sqrt(num_splits) # Standard error of the mean

return mean / resamples, std / resamples
# 1.5 is an empirical factor to make the confidence interval wider
return mean / resamples, 1.5 * std / resamples


def erupt(
Expand Down
130 changes: 65 additions & 65 deletions notebooks/ERUPT basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,45 +103,45 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.771976</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>1.066413</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1\n",
"0 0.452636 0 1.684484\n",
"1 0.380215 0 0.745268\n",
"2 0.584036 1 0.762300\n",
"3 0.505191 0 1.425354\n",
"4 0.384110 1 1.834628"
"0 0.898227 1 1.288637\n",
"1 0.462092 0 0.771976\n",
"2 0.858974 0 1.881019\n",
"3 0.228084 1 0.357797\n",
"4 0.962512 1 1.066413"
]
},
"execution_count": 2,
Expand Down Expand Up @@ -216,65 +216,65 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.726318</td>\n",
" <td>0</td>\n",
" <td>0.273682</td>\n",
" <td>0.904259</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" <td>0.949114</td>\n",
" <td>1</td>\n",
" <td>0.949114</td>\n",
" <td>2.229118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.690108</td>\n",
" <td>1</td>\n",
" <td>0.690108</td>\n",
" <td>1.930383</td>\n",
" <td>0.771976</td>\n",
" <td>0.731046</td>\n",
" <td>0</td>\n",
" <td>0.268954</td>\n",
" <td>0.572308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.792018</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" <td>0.929487</td>\n",
" <td>1</td>\n",
" <td>0.792018</td>\n",
" <td>0.959608</td>\n",
" <td>0.929487</td>\n",
" <td>2.601592</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.752596</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" <td>0.614042</td>\n",
" <td>1</td>\n",
" <td>0.752596</td>\n",
" <td>1.017777</td>\n",
" <td>0.614042</td>\n",
" <td>0.542638</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>0.692055</td>\n",
" <td>1.066413</td>\n",
" <td>0.981256</td>\n",
" <td>1</td>\n",
" <td>0.692055</td>\n",
" <td>2.374030</td>\n",
" <td>0.981256</td>\n",
" <td>2.401383</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1 p T2 p_of_actual Y2\n",
"0 0.452636 0 1.684484 0.726318 0 0.273682 0.904259\n",
"1 0.380215 0 0.745268 0.690108 1 0.690108 1.930383\n",
"2 0.584036 1 0.762300 0.792018 1 0.792018 0.959608\n",
"3 0.505191 0 1.425354 0.752596 1 0.752596 1.017777\n",
"4 0.384110 1 1.834628 0.692055 1 0.692055 2.374030"
"0 0.898227 1 1.288637 0.949114 1 0.949114 2.229118\n",
"1 0.462092 0 0.771976 0.731046 0 0.268954 0.572308\n",
"2 0.858974 0 1.881019 0.929487 1 0.929487 2.601592\n",
"3 0.228084 1 0.357797 0.614042 1 0.614042 0.542638\n",
"4 0.962512 1 1.066413 0.981256 1 0.981256 2.401383"
]
},
"execution_count": 3,
Expand Down Expand Up @@ -319,10 +319,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n",
"Estimated outcome of random assignment: 1.251567372523789\n",
"95% confidence interval for estimated outcome: 1.2311928820519622 1.2719418629956158\n",
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n"
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n",
"Estimated outcome of random assignment: 1.2594221770638483\n",
"95% confidence interval for estimated outcome: 1.230204391668238 1.2886399624594587\n",
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n"
]
}
],
Expand Down Expand Up @@ -360,17 +360,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n",
"Estimated outcome of biased assignment: 1.4147647990746988\n",
"Confidence interval for estimated outcome: 1.398423601541284 1.4311059966081134\n",
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n"
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n",
"Estimated outcome of biased assignment: 1.405112521603215\n",
"95% confidence interval for estimated outcome: 1.3814865905561569 1.428738452650273\n",
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n"
]
}
],
"source": [
"# Conversely, we can take the outcome of the fully random test and use it to estimate what the outcome of the biased assignment would have been\n",
"# Conversely, we can take the outcome of the fully random test and use it \n",
"# to estimate what the outcome of the biased assignment would have been\n",
"\n",
"# Let's use data from biased assignment experiment to estimate the average effect of fully random assignment\n",
"hypothetical_policy = df[\"T2\"]\n",
"est, std = erupt_with_std(actual_propensity=0.5*pd.Series(np.ones(len(df))), \n",
" actual_treatment=df[\"T1\"],\n",
Expand All @@ -379,7 +379,7 @@
"\n",
"print(\"Average outcome of the actual random assignment experiment:\", df[\"Y1\"].mean())\n",
"print(\"Estimated outcome of biased assignment:\", est)\n",
"print(\"Confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"95% confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"Average outcome of the actual biased assignment experiment:\", df[\"Y2\"].mean())"
]
},
Expand All @@ -388,7 +388,7 @@
"id": "f724dbc3",
"metadata": {},
"source": [
"As you can see, the actual outcome is well within the confidence interval estimated by ERUPT"
"As you can see, the actual outcome is within the confidence interval estimated by ERUPT"
]
},
{
Expand Down
18 changes: 10 additions & 8 deletions notebooks/RunExperiments/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def get_estimator_list(dataset_name):
return [est for est in estimator_list if "Dummy" not in est]



def run_experiment(args):
# Process datasets
data_sets = {}
Expand Down Expand Up @@ -281,10 +280,12 @@ def get_all_test_scores(out_dir, dataset_name):
return out


def generate_plots(out_dir: str,
log_scale: List[str]|None = None,
upper_bounds: dict|None = None,
lower_bounds: dict|None=None):
def generate_plots(
out_dir: str,
log_scale: List[str] | None = None,
upper_bounds: dict | None = None,
lower_bounds: dict | None = None,
):
if log_scale is None:
log_scale = ["energy_distance", "psw_energy_distance", "frobenius_norm"]
metrics, datasets = extract_metrics_datasets(out_dir)
Expand Down Expand Up @@ -404,14 +405,13 @@ def plot_mse_grid(title):
legend_elements = []
for j, dataset in enumerate(datasets):
df = get_all_test_scores(out_dir, dataset)
for m, value in upper_bounds.items():
for m, value in upper_bounds.items():
if m in df.columns:
df = df[df[m] < value].copy()
for m, value in lower_bounds.items():
if m in df.columns:
df = df[df[m] > value].copy()


for i, metric in enumerate(all_metrics):
ax = axs[i, j]
this_df = df[["estimator_name", metric, "MSE"]].dropna()
Expand Down Expand Up @@ -502,4 +502,6 @@ def plot_mse_grid(title):
out_dir = run_experiment(args)
# upper_bounds = {"MSE": 1e2, "policy_risk": 0.2}
# lower_bounds = {"erupt": 0.06, "bite": 0.75}
generate_plots(os.path.join(out_dir, "RCT"))#, upper_bounds=upper_bounds, lower_bounds=lower_bounds)
generate_plots(
os.path.join(out_dir, "RCT")
) # , upper_bounds=upper_bounds, lower_bounds=lower_bounds)

0 comments on commit 5627d9a

Please sign in to comment.