Skip to content

Commit

Permalink
🐛 fixed sampling bug in shap explanations, add tests (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Sep 9, 2020
1 parent b4d52ad commit ca2c917
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 67 deletions.
3 changes: 3 additions & 0 deletions supervised/base_automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ def _build_dataframe(self, X, y=None):

X, y = ExcludeRowsMissingTarget.transform(X, y, warn=True)

X.reset_index(drop=True, inplace=True)
y.reset_index(drop=True, inplace=True)

return X, y

def _fit(self, X, y):
Expand Down
134 changes: 67 additions & 67 deletions supervised/utils/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def get_sample(X_validation, y_validation):
sample = y_validation.index.tolist()
np.random.shuffle(sample)
sample = sample[:SAMPLES_LIMIT]
X_vald = X_validation.iloc[sample]
y_vald = y_validation.iloc[sample]
X_vald = X_validation.loc[sample]
y_vald = y_validation.loc[sample]
else:
X_vald = X_validation
y_vald = y_validation
Expand Down Expand Up @@ -181,77 +181,77 @@ def compute(
):
if not PlotSHAP.is_available(algorithm, X_train, y_train, ml_task):
return
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
explainer = PlotSHAP.get_explainer(algorithm, X_train)
X_vald, y_vald = PlotSHAP.get_sample(X_validation, y_validation)
shap_values = explainer.shap_values(X_vald)

# fix problem with 1 or 2 dimensions for binary classification
expected_value = explainer.expected_value
if ml_task == BINARY_CLASSIFICATION and isinstance(shap_values, list):
shap_values = shap_values[1]
expected_value = explainer.expected_value[1]

# Summary SHAP plot
PlotSHAP.summary(
shap_values, X_vald, model_file_path, learner_name, class_names
)
# Dependence SHAP plots
if ml_task == MULTICLASS_CLASSIFICATION:
for t in np.unique(y_vald):
PlotSHAP.dependence(
shap_values[t],
X_vald,
model_file_path,
learner_name,
f"_class_{class_names[t]}",
)
else:
PlotSHAP.dependence(shap_values, X_vald, model_file_path, learner_name)

# Decision SHAP plots
df_preds = PlotSHAP.get_predictions(algorithm, X_vald, y_vald, ml_task)

if ml_task == REGRESSION:
PlotSHAP.decisions_regression(
df_preds,
shap_values,
expected_value,
X_vald,
y_vald,
model_file_path,
learner_name,
)
elif ml_task == BINARY_CLASSIFICATION:
PlotSHAP.decisions_binary(
df_preds,
shap_values,
expected_value,
X_vald,
y_vald,
model_file_path,
learner_name,
)
else:
PlotSHAP.decisions_multiclass(
df_preds,
shap_values,
expected_value,
#try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
explainer = PlotSHAP.get_explainer(algorithm, X_train)
X_vald, y_vald = PlotSHAP.get_sample(X_validation, y_validation)
shap_values = explainer.shap_values(X_vald)

# fix problem with 1 or 2 dimensions for binary classification
expected_value = explainer.expected_value
if ml_task == BINARY_CLASSIFICATION and isinstance(shap_values, list):
shap_values = shap_values[1]
expected_value = explainer.expected_value[1]

# Summary SHAP plot
PlotSHAP.summary(
shap_values, X_vald, model_file_path, learner_name, class_names
)
# Dependence SHAP plots
if ml_task == MULTICLASS_CLASSIFICATION:
for t in np.unique(y_vald):
PlotSHAP.dependence(
shap_values[t],
X_vald,
y_vald,
model_file_path,
learner_name,
class_names,
f"_class_{class_names[t]}",
)
except Exception as e:
print(
f"Exception while producing SHAP explanations. {str(e)}\nContinuing ..."
else:
PlotSHAP.dependence(shap_values, X_vald, model_file_path, learner_name)

# Decision SHAP plots
df_preds = PlotSHAP.get_predictions(algorithm, X_vald, y_vald, ml_task)

if ml_task == REGRESSION:
PlotSHAP.decisions_regression(
df_preds,
shap_values,
expected_value,
X_vald,
y_vald,
model_file_path,
learner_name,
)
logger.info(
f"Exception while producing SHAP explanations. {str(e)}\nContinuing ..."
elif ml_task == BINARY_CLASSIFICATION:
PlotSHAP.decisions_binary(
df_preds,
shap_values,
expected_value,
X_vald,
y_vald,
model_file_path,
learner_name,
)
else:
PlotSHAP.decisions_multiclass(
df_preds,
shap_values,
expected_value,
X_vald,
y_vald,
model_file_path,
learner_name,
class_names,
)
#except Exception as e:
# print(
# f"Exception while producing SHAP explanations. {str(e)}\nContinuing ..."
# )
# logger.info(
# f"Exception while producing SHAP explanations. {str(e)}\nContinuing ..."
# )

@staticmethod
def decisions_regression(
Expand Down
28 changes: 28 additions & 0 deletions tests/tests_utils/test_shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
import numpy as np
import pandas as pd

from supervised.utils.shap import PlotSHAP


class PlotSHAPTest(unittest.TestCase):
def test_get_sample_data_larger_1k(self):
""" Get sample when data is larger than 1k """
X = pd.DataFrame(np.random.uniform(size=(5763, 31)))
y = pd.Series(np.random.randint(0, 2, size=(5763, )))

X_, y_ = PlotSHAP.get_sample(X, y)

self.assertEqual(X_.shape[0], 1000)
self.assertEqual(y_.shape[0], 1000)

def test_get_sample_data_smaller_1k(self):
""" Get sample when data is smaller than 1k """
SAMPLES = 100
X = pd.DataFrame(np.random.uniform(size=(SAMPLES, 31)))
y = pd.Series(np.random.randint(0, 2, size=(SAMPLES, )))

X_, y_ = PlotSHAP.get_sample(X, y)

self.assertEqual(X_.shape[0], SAMPLES)
self.assertEqual(y_.shape[0], SAMPLES)

0 comments on commit ca2c917

Please sign in to comment.