diff --git a/causaltune/scoring.py b/causaltune/scoring.py index af10c1ba..ac6e6ab3 100644 --- a/causaltune/scoring.py +++ b/causaltune/scoring.py @@ -513,16 +513,7 @@ def policy_risk_score( # Calculate propensity scores using the pre-fitted propensity model propensity_scores = ( self.psw_estimator.estimator.propensity_model.predict_proba( - df.drop( - [ - 'index', - 'variant', - 'Y', - 'dy', - 'yhat' - ], - axis=1 - ) + df[['random'] + self.psw_estimator._effect_modifier_names] ) ) if propensity_scores.ndim == 2: