Skip to content

Commit

Permalink
add check_additivity warning for skorch
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Dec 16, 2023
1 parent 63f14f9 commit 3ae3fe6
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 29 deletions.
22 changes: 22 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
# Release Notes


## Version 0.4.4:
### Breaking Changes
-
-

### New Features
-
-

### Bug Fixes
- Add warning to set `shap_kwargs=dict(check_additivity=False)` for skorch models, and pass this setting in the skorch tests.
-

### Improvements
-
-

### Other Changes
-
-


## Version 0.4.3:
### Breaking Changes
-
Expand Down
21 changes: 12 additions & 9 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def __init__(
self.y.name = "Target"

self.metric = permutation_metric
self.shap_kwargs = shap_kwargs or {}

if shap == "guess":
shap_guess = guess_shap(self.model)
Expand Down Expand Up @@ -331,7 +332,13 @@ def __init__(
"not be calculated!"
)
self.interactions_should_work = False
self.shap_kwargs = shap_kwargs if shap_kwargs else {}
if self.shap == "skorch":
    # Tell the user up front how to disable the additivity check via
    # shap_kwargs, since the message notes it tends to fail for skorch.
    print(
        "WARNING: For shap='skorch' the additivity check tends to fail, "
        "you can set shap_kwargs=dict(check_additivity=False) to suppress "
        "this error (at your own risk)!"
    )

self.model_output = model_output

if idxs is not None:
Expand Down Expand Up @@ -2783,9 +2790,7 @@ def model_predict(data_asarray):
def shap_base_value(self, pos_label=None):
"""SHAP base value: average outcome of population"""
if not hasattr(self, "_shap_base_value"):
_ = (
self.get_shap_values_df()
) # CatBoost needs to have shap values calculated before expected value for some reason
_ = self.get_shap_values_df() # CatBoost needs to have shap values calculated before expected value for some reason
self._shap_base_value = self.shap_explainer.expected_value
if (
isinstance(self._shap_base_value, np.ndarray)
Expand Down Expand Up @@ -4605,8 +4610,8 @@ def get_decisionpath_df(self, tree_idx, index, pos_label=None):
dataframe with summary of the decision tree path
"""
assert tree_idx >= 0 and tree_idx < len(
self.shadow_trees
assert (
tree_idx >= 0 and tree_idx < len(self.shadow_trees)
), f"tree index {tree_idx} outside 0 and number of trees ({len(self.decision_trees)}) range"
X_row = self.get_X_row(index)
if self.is_classifier:
Expand Down Expand Up @@ -4756,9 +4761,7 @@ def shadow_trees(self):
"Calculating ShadowDecTree for each individual decision tree...",
flush=True,
)
assert hasattr(
self.model, "estimators_"
), """self.model does not have an estimators_ attribute, so probably not
assert hasattr(self.model, "estimators_"), """self.model does not have an estimators_ attribute, so probably not
actually a sklearn RandomForest?"""
y = self.y if self.y_missing else self.y.astype("int16")
self._shadow_trees = [
Expand Down
90 changes: 70 additions & 20 deletions tests/test_skorch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@

from explainerdashboard.explainers import RegressionExplainer, ClassifierExplainer


@pytest.fixture(scope="session")
def skorch_regressor():
X, y = make_regression(100, 5, n_informative=3, random_state=0)
X = X.astype(np.float32)
y = y / np.std(y)
y = y.reshape(-1, 1).astype(np.float32)

X_df = pd.DataFrame(X, columns=['col'+str(i) for i in range(X.shape[1])])
X_df = pd.DataFrame(X, columns=["col" + str(i) for i in range(X.shape[1])])

class MyModule(nn.Module):
def __init__(skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.ReLU()):
def __init__(
skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.ReLU()
):
super(MyModule, skorch_classifier_explainer).__init__()

skorch_classifier_explainer.dense0 = nn.Linear(input_units, num_units)
Expand All @@ -28,8 +31,12 @@ def __init__(skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.
skorch_classifier_explainer.output = nn.Linear(num_units, 1)

def forward(skorch_classifier_explainer, X, **kwargs):
X = skorch_classifier_explainer.nonlin(skorch_classifier_explainer.dense0(X))
X = skorch_classifier_explainer.nonlin(skorch_classifier_explainer.dense1(X))
X = skorch_classifier_explainer.nonlin(
skorch_classifier_explainer.dense0(X)
)
X = skorch_classifier_explainer.nonlin(
skorch_classifier_explainer.dense1(X)
)
X = skorch_classifier_explainer.output(X)
return X

Expand All @@ -43,17 +50,19 @@ def forward(skorch_classifier_explainer, X, **kwargs):
model.fit(X_df.values, y)
return model, X_df, y


@pytest.fixture(scope="session")
def skorch_classifier():
X, y = make_classification(200, 5, n_informative=3, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

X_df = pd.DataFrame(X, columns=['col'+str(i) for i in range(X.shape[1])])

X_df = pd.DataFrame(X, columns=["col" + str(i) for i in range(X.shape[1])])

class MyModule(nn.Module):
def __init__(skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.ReLU()):
def __init__(
skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.ReLU()
):
super(MyModule, skorch_classifier_explainer).__init__()

skorch_classifier_explainer.dense0 = nn.Linear(input_units, num_units)
Expand All @@ -63,12 +72,17 @@ def __init__(skorch_classifier_explainer, input_units=5, num_units=5, nonlin=nn.
skorch_classifier_explainer.softmax = nn.Softmax(dim=-1)

def forward(skorch_classifier_explainer, X, **kwargs):
X = skorch_classifier_explainer.nonlin(skorch_classifier_explainer.dense0(X))
X = skorch_classifier_explainer.nonlin(skorch_classifier_explainer.dense1(X))
X = skorch_classifier_explainer.softmax(skorch_classifier_explainer.output(X))
X = skorch_classifier_explainer.nonlin(
skorch_classifier_explainer.dense0(X)
)
X = skorch_classifier_explainer.nonlin(
skorch_classifier_explainer.dense1(X)
)
X = skorch_classifier_explainer.softmax(
skorch_classifier_explainer.output(X)
)
return X


model = NeuralNetClassifier(
MyModule,
max_epochs=20,
Expand All @@ -78,76 +92,112 @@ def forward(skorch_classifier_explainer, X, **kwargs):
model.fit(X_df.values, y)
return model, X_df, y


@pytest.fixture(scope="session")
def skorch_regressor_explainer(skorch_regressor):
model, X, y = skorch_regressor
return RegressionExplainer(model, X, y)
return RegressionExplainer(model, X, y, shap_kwargs=dict(check_additivity=False))


@pytest.fixture(scope="session")
def skorch_classifier_explainer(skorch_classifier):
model, X, y = skorch_classifier
return ClassifierExplainer(model, X, y)
return ClassifierExplainer(model, X, y, shap_kwargs=dict(check_additivity=False))


def test_preds(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.preds, np.ndarray)


def test_permutation_importances(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.get_permutation_importances_df(), pd.DataFrame)
assert isinstance(
skorch_regressor_explainer.get_permutation_importances_df(), pd.DataFrame
)


def test_shap_base_value(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.shap_base_value(), (np.floating, float))
assert isinstance(
skorch_regressor_explainer.shap_base_value(), (np.floating, float)
)


def test_shap_values_shape(skorch_regressor_explainer):
assert (skorch_regressor_explainer.get_shap_values_df().shape == (len(skorch_regressor_explainer), len(skorch_regressor_explainer.merged_cols)))
assert skorch_regressor_explainer.get_shap_values_df().shape == (
len(skorch_regressor_explainer),
len(skorch_regressor_explainer.merged_cols),
)


def test_shap_values(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.get_shap_values_df(), pd.DataFrame)


def test_mean_abs_shap(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.get_mean_abs_shap_df(), pd.DataFrame)


def test_calculate_properties(skorch_regressor_explainer):
skorch_regressor_explainer.calculate_properties(include_interactions=False)


def test_pdp_df(skorch_regressor_explainer):
assert isinstance(skorch_regressor_explainer.pdp_df("col1"), pd.DataFrame)


def test_preds(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.preds, np.ndarray)


def test_pred_probas(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.pred_probas(), np.ndarray)


def test_permutation_importances(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.get_permutation_importances_df(), pd.DataFrame)
assert isinstance(
skorch_classifier_explainer.get_permutation_importances_df(), pd.DataFrame
)


def test_shap_base_value(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.shap_base_value(), (np.floating, float))
assert isinstance(
skorch_classifier_explainer.shap_base_value(), (np.floating, float)
)


def test_shap_values_shape(skorch_classifier_explainer):
assert (skorch_classifier_explainer.get_shap_values_df().shape == (len(skorch_classifier_explainer), len(skorch_classifier_explainer.merged_cols)))
assert skorch_classifier_explainer.get_shap_values_df().shape == (
len(skorch_classifier_explainer),
len(skorch_classifier_explainer.merged_cols),
)


def test_shap_values(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.get_shap_values_df(), pd.DataFrame)


def test_mean_abs_shap(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.get_mean_abs_shap_df(), pd.DataFrame)


def test_calculate_properties(skorch_classifier_explainer):
skorch_classifier_explainer.calculate_properties(include_interactions=False)


def test_pdp_df(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.pdp_df("col1"), pd.DataFrame)


def test_metrics(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.metrics(), dict)


def test_precision_df(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.get_precision_df(), pd.DataFrame)



def test_lift_curve_df(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.get_liftcurve_df(), pd.DataFrame)


def test_prediction_result_df(skorch_classifier_explainer):
assert isinstance(skorch_classifier_explainer.prediction_result_df(0), pd.DataFrame)

0 comments on commit 3ae3fe6

Please sign in to comment.