Skip to content

Commit

Permalink
Merge pull request #330 from EgorKraevTransferwise/experiment_plotting
Browse files Browse the repository at this point in the history
Experiment plotting improvements
  • Loading branch information
EgorKraevTransferwise authored Dec 6, 2024
2 parents 346aa84 + 5627d9a commit 1205430
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 226 deletions.
30 changes: 3 additions & 27 deletions causaltune/optimiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from joblib import Parallel, delayed

from causaltune.search.params import SimpleParamService
from causaltune.score.scoring import Scorer
from causaltune.score.scoring import Scorer, metrics_to_minimize
from causaltune.utils import treatment_is_multivalue
from causaltune.models.monkey_patches import (
AutoML,
Expand Down Expand Up @@ -514,19 +514,7 @@ def fit(
evaluated_rewards=(
[] if len(self.resume_scores) == 0 else self.resume_scores
),
mode=(
"min"
if self.metric
in [
"energy_distance",
"psw_energy_distance",
"frobenius_norm",
"psw_frobenius_norm",
"codec",
"policy_risk",
]
else "max"
),
mode=("min" if self.metric in metrics_to_minimize() else "max"),
low_cost_partial_config={},
**self._settings["tuner"],
)
Expand All @@ -547,19 +535,7 @@ def fit(
evaluated_rewards=(
[] if len(self.resume_scores) == 0 else self.resume_scores
),
mode=(
"min"
if self.metric
in [
"energy_distance",
"psw_energy_distance",
"frobenius_norm",
"psw_frobenius_norm",
"codec",
"policy_risk",
]
else "max"
),
mode=("min" if self.metric in metrics_to_minimize() else "max"),
low_cost_partial_config={},
**self._settings["tuner"],
)
Expand Down
4 changes: 2 additions & 2 deletions causaltune/score/erupt_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def erupt_with_std(
]
mean += np.mean(means)
std += np.std(means) / np.sqrt(num_splits) # Standard error of the mean

return mean / resamples, std / resamples
# 1.5 is an empirical factor to make the confidence interval wider
return mean / resamples, 1.5 * std / resamples


def erupt(
Expand Down
21 changes: 18 additions & 3 deletions causaltune/score/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def const_marginal_effect(self, X):
return self.cate_estimate


def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List[str]:
def supported_metrics(problem: str, multivalue: bool, scores_only: bool, constant_ptt: bool=False) -> List[str]:
if problem == "iv":
metrics = ["energy_distance", "frobenius_norm", "codec"]
if not scores_only:
Expand All @@ -52,12 +52,12 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
metrics = [
"erupt",
"norm_erupt",
"greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
# "greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
"policy_risk", # NEW
"qini",
"auc",
# "r_scorer",
"energy_distance", # is broken without propensity weighting
"energy_distance", # should only be used in iv problems
"psw_energy_distance",
"frobenius_norm", # NEW
"codec", # NEW
Expand All @@ -68,6 +68,17 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
return metrics


def metrics_to_minimize():
return [
"energy_distance",
"psw_energy_distance",
"codec",
"frobenius_norm",
"psw_frobenius_norm",
"policy_risk",
]


class Scorer:
def __init__(
self,
Expand All @@ -90,6 +101,10 @@ def __init__(
self.identified_estimand = causal_model.identify_effect(
proceed_when_unidentifiable=True
)
if "Dummy" in propensity_model.__class__.__name__:
self.constant_ptt = True
else:
self.constant_ptt = False

if problem == "backdoor":
print(
Expand Down
21 changes: 20 additions & 1 deletion causaltune/search/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
cfg, init_params, low_cost_init_params = flaml_config_to_tune_config(
cls.search_space(data_size=data_size, task=task)
)

cfg, init_params = tweak_config(cfg, init_params, name)
# Test if the estimator instantiates fine
try:
cls(task=task, **init_params)
Expand All @@ -76,6 +76,25 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
return tune.choice(joint_cfg), joint_init_params, joint_low_cost_init_params


def tweak_config(cfg: dict, init_params: dict, estimator_name: str):
"""
Tweak built-in FLAML search spaces to limit the number of estimators
:param cfg:
:param estimator_name:
:return:
"""
out = copy.deepcopy(cfg)
if "xgboost" in estimator_name or estimator_name in [
"random_forest",
"extra_trees",
"lgbm",
"catboost",
]:
out["n_estimators"] = tune.lograndint(4, 1000)
init_params["n_estimators"] = 100
return out, init_params


def model_from_cfg(cfg: dict):
cfg = copy.deepcopy(cfg)
model_name = cfg.pop("estimator_name")
Expand Down
130 changes: 65 additions & 65 deletions notebooks/ERUPT basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,45 +103,45 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.771976</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>1.066413</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1\n",
"0 0.452636 0 1.684484\n",
"1 0.380215 0 0.745268\n",
"2 0.584036 1 0.762300\n",
"3 0.505191 0 1.425354\n",
"4 0.384110 1 1.834628"
"0 0.898227 1 1.288637\n",
"1 0.462092 0 0.771976\n",
"2 0.858974 0 1.881019\n",
"3 0.228084 1 0.357797\n",
"4 0.962512 1 1.066413"
]
},
"execution_count": 2,
Expand Down Expand Up @@ -216,65 +216,65 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.726318</td>\n",
" <td>0</td>\n",
" <td>0.273682</td>\n",
" <td>0.904259</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" <td>0.949114</td>\n",
" <td>1</td>\n",
" <td>0.949114</td>\n",
" <td>2.229118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.690108</td>\n",
" <td>1</td>\n",
" <td>0.690108</td>\n",
" <td>1.930383</td>\n",
" <td>0.771976</td>\n",
" <td>0.731046</td>\n",
" <td>0</td>\n",
" <td>0.268954</td>\n",
" <td>0.572308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.792018</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" <td>0.929487</td>\n",
" <td>1</td>\n",
" <td>0.792018</td>\n",
" <td>0.959608</td>\n",
" <td>0.929487</td>\n",
" <td>2.601592</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.752596</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" <td>0.614042</td>\n",
" <td>1</td>\n",
" <td>0.752596</td>\n",
" <td>1.017777</td>\n",
" <td>0.614042</td>\n",
" <td>0.542638</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>0.692055</td>\n",
" <td>1.066413</td>\n",
" <td>0.981256</td>\n",
" <td>1</td>\n",
" <td>0.692055</td>\n",
" <td>2.374030</td>\n",
" <td>0.981256</td>\n",
" <td>2.401383</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1 p T2 p_of_actual Y2\n",
"0 0.452636 0 1.684484 0.726318 0 0.273682 0.904259\n",
"1 0.380215 0 0.745268 0.690108 1 0.690108 1.930383\n",
"2 0.584036 1 0.762300 0.792018 1 0.792018 0.959608\n",
"3 0.505191 0 1.425354 0.752596 1 0.752596 1.017777\n",
"4 0.384110 1 1.834628 0.692055 1 0.692055 2.374030"
"0 0.898227 1 1.288637 0.949114 1 0.949114 2.229118\n",
"1 0.462092 0 0.771976 0.731046 0 0.268954 0.572308\n",
"2 0.858974 0 1.881019 0.929487 1 0.929487 2.601592\n",
"3 0.228084 1 0.357797 0.614042 1 0.614042 0.542638\n",
"4 0.962512 1 1.066413 0.981256 1 0.981256 2.401383"
]
},
"execution_count": 3,
Expand Down Expand Up @@ -319,10 +319,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n",
"Estimated outcome of random assignment: 1.251567372523789\n",
"95% confidence interval for estimated outcome: 1.2311928820519622 1.2719418629956158\n",
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n"
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n",
"Estimated outcome of random assignment: 1.2594221770638483\n",
"95% confidence interval for estimated outcome: 1.230204391668238 1.2886399624594587\n",
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n"
]
}
],
Expand Down Expand Up @@ -360,17 +360,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n",
"Estimated outcome of biased assignment: 1.4147647990746988\n",
"Confidence interval for estimated outcome: 1.398423601541284 1.4311059966081134\n",
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n"
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n",
"Estimated outcome of biased assignment: 1.405112521603215\n",
"95% confidence interval for estimated outcome: 1.3814865905561569 1.428738452650273\n",
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n"
]
}
],
"source": [
"# Conversely, we can take the outcome of the fully random test and use it to estimate what the outcome of the biased assignment would have been\n",
"# Conversely, we can take the outcome of the fully random test and use it \n",
"# to estimate what the outcome of the biased assignment would have been\n",
"\n",
"# Let's use data from biased assignment experiment to estimate the average effect of fully random assignment\n",
"hypothetical_policy = df[\"T2\"]\n",
"est, std = erupt_with_std(actual_propensity=0.5*pd.Series(np.ones(len(df))), \n",
" actual_treatment=df[\"T1\"],\n",
Expand All @@ -379,7 +379,7 @@
"\n",
"print(\"Average outcome of the actual random assignment experiment:\", df[\"Y1\"].mean())\n",
"print(\"Estimated outcome of biased assignment:\", est)\n",
"print(\"Confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"95% confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"Average outcome of the actual biased assignment experiment:\", df[\"Y2\"].mean())"
]
},
Expand All @@ -388,7 +388,7 @@
"id": "f724dbc3",
"metadata": {},
"source": [
"As you can see, the actual outcome is well within the confidence interval estimated by ERUPT"
"As you can see, the actual outcome is within the confidence interval estimated by ERUPT"
]
},
{
Expand Down
Loading

0 comments on commit 1205430

Please sign in to comment.