diff --git a/requirements.dev.txt b/requirements.dev.txt index b201c7d0..89553a5c 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -13,7 +13,7 @@ dash-table==5.0.0 lightgbm==2.3.1 pandas>=2.1.0 plotly==5.6.0 -shap>=0.38.1,<0.45.0 +shap>=0.45.0 Sphinx==4.5.0 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 diff --git a/setup.py b/setup.py index 93bb8fb9..b1cd1884 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ "matplotlib>=3.2.0", "numpy>1.18.0", "pandas>=2.1.0", - "shap>=0.38.1,<0.45.0", + "shap>=0.45.0", "Flask<2.3.0", "dash>=2.3.1", "dash-bootstrap-components>=1.1.0", diff --git a/tests/integration_tests/test_contributions_multiclass.py b/tests/integration_tests/test_contributions_multiclass.py index 0a2e7bc1..ad1d68a7 100644 --- a/tests/integration_tests/test_contributions_multiclass.py +++ b/tests/integration_tests/test_contributions_multiclass.py @@ -97,7 +97,10 @@ def test_rank_contributions_1(self): model.fit(self.x_train, self.y_train) explainer = shap.TreeExplainer(model) shap_values = explainer.shap_values(self.x_test) - slist = [pd.DataFrame(data=tab, index=self.x_test.index, columns=self.x_test.columns) for tab in shap_values] + slist = [ + pd.DataFrame(data=shap_values[:, :, i], index=self.x_test.index, columns=self.x_test.columns) + for i in range(3) + ] for i in range(3): s_ord, x_ord, s_dict = rank_contributions(slist[i], pd.DataFrame(data=self.x_test))