Skip to content

Commit

Permalink
minor test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed Nov 28, 2024
1 parent be4985a commit ec7ca8f
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions causaltune/score/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
# print("backdoor")
if multivalue:
# TODO: support other metrics for the multivalue case
return ["energy_distance", "psw_energy_distance"]
return ["psw_energy_distance", "energy_distance"] # TODO: add erupt
else:
metrics = [
"erupt",
"norm_erupt",
"prob_erupt", # NEW
# "prob_erupt", # regular erupt was made probabilistic, no need for a separate one
"policy_risk", # NEW
"qini",
"auc",
# "r_scorer",
"energy_distance",
"energy_distance", # is broken without propensity weighting
"psw_energy_distance",
"frobenius_norm", # NEW
"codec", # NEW
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def make_scores(
)[:, 1]
values["policy"] = cate_estimate > 0
values["norm_policy"] = cate_estimate > simple_ate
values["weights"] = self.erupt.weights(df, lambda x: cate_estimate > 0)
# values["weights"] = self.erupt.weights(df, lambda x: cate_estimate > 0)
else:
pass
# TODO: what do we do here if multiple treatments?
Expand All @@ -1297,17 +1297,6 @@ def make_scores(
)
out["norm_erupt"] = norm_erupt_score

# if "prob_erupt" in metrics_to_report:
# out["prob_erupt"] = self.erupt.probabilistic_erupt_score(
# df, df[est._outcome_name], estimate, cate_estimate
# )

if "prob_erupt" in metrics_to_report:
prob_erupt_score = self.erupt.probabilistic_erupt_score(
df, df[outcome_name], estimate
)
out["prob_erupt"] = prob_erupt_score

# if "frobenius_norm" in metrics_to_report:
# out["frobenius_norm"] = self.frobenius_norm_score(estimate, df)

Expand Down

0 comments on commit ec7ca8f

Please sign in to comment.