From 73c46b507d1d89091ca4be7eca00f42af3e3037b Mon Sep 17 00:00:00 2001 From: AlxdrPolyakov <122611538+AlxdrPolyakov@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:13:01 +0100 Subject: [PATCH 1/2] fix policy score Signed-off-by: AlxdrPolyakov <122611538+AlxdrPolyakov@users.noreply.github.com> --- causaltune/scoring.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/causaltune/scoring.py b/causaltune/scoring.py index af10c1ba..e69970b3 100644 --- a/causaltune/scoring.py +++ b/causaltune/scoring.py @@ -513,17 +513,7 @@ def policy_risk_score( # Calculate propensity scores using the pre-fitted propensity model propensity_scores = ( self.psw_estimator.estimator.propensity_model.predict_proba( - df.drop( - [ - 'index', - 'variant', - 'Y', - 'dy', - 'yhat' - ], - axis=1 - ) - ) + df[['random'] + self.psw_estimator._effect_modifier_names] ) if propensity_scores.ndim == 2: # Use second column if 2D array From d1d1416f026a9efa589319f5ba5a56d2e40ed3bb Mon Sep 17 00:00:00 2001 From: AlxdrPolyakov <122611538+AlxdrPolyakov@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:20:20 +0100 Subject: [PATCH 2/2] Update scoring.py Signed-off-by: AlxdrPolyakov <122611538+AlxdrPolyakov@users.noreply.github.com> --- causaltune/scoring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/causaltune/scoring.py b/causaltune/scoring.py index e69970b3..ac6e6ab3 100644 --- a/causaltune/scoring.py +++ b/causaltune/scoring.py @@ -514,6 +514,7 @@ def policy_risk_score( propensity_scores = ( self.psw_estimator.estimator.propensity_model.predict_proba( df[['random'] + self.psw_estimator._effect_modifier_names] + ) ) if propensity_scores.ndim == 2: # Use second column if 2D array