From 3055054c7fb729053b1cb0c8bb993c88f0dc8fe2 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Fri, 28 Jun 2024 16:31:09 +0200 Subject: [PATCH 1/2] fix: line breaks in interaction plots --- shapash/explainer/smart_plotter.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 3905bfe2..87fda432 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -2151,15 +2151,6 @@ def _plot_interactions_scatter(self, x_name, y_name, col_name, x_values, y_value ------- go.Figure """ - # add break line to X label if necessary - max_len_by_row = max([round(50 / self.explainer.features_desc[x_values.columns.values[0]]), 8]) - x_values.iloc[:, 0] = x_values.iloc[:, 0].apply( - add_line_break, - args=( - max_len_by_row, - 120, - ), - ) data_df = pd.DataFrame( { @@ -2210,16 +2201,6 @@ def _plot_interactions_violin(self, x_name, y_name, col_name, x_values, y_values fig = go.Figure() - # add break line to X label - max_len_by_row = max([round(50 / self.explainer.features_desc[x_values.columns.values[0]]), 8]) - x_values.iloc[:, 0] = x_values.iloc[:, 0].apply( - add_line_break, - args=( - max_len_by_row, - 120, - ), - ) - uniq_l = list(pd.unique(x_values.values.flatten())) uniq_l.sort() @@ -2445,6 +2426,16 @@ def interactions_plot( if col_id1 != col_id2: interaction_values = interaction_values * 2 + # add break line to X label if necessary + max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values1.columns.values[0]]), 8]) + feature_values1.iloc[:, 0] = feature_values1.iloc[:, 0].apply( + add_line_break, + args=( + max_len_by_row, + 120, + ), + ) + # selecting the best plot : Scatter, Violin? if col_value_count1 > violin_maxf: fig = self._plot_interactions_scatter( From a467eaae648ea5f16a619b366a0c09bde7aad656 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Mon, 1 Jul 2024 15:11:27 +0200 Subject: [PATCH 2/2] fix somme future warning --- shapash/report/project_report.py | 13 ++++++++----- shapash/utils/utils.py | 9 +++++++-- tests/unit_tests/explainer/test_smart_plotter.py | 4 ++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/shapash/report/project_report.py b/shapash/report/project_report.py index d7bc8b39..a646f9cd 100644 --- a/shapash/report/project_report.py +++ b/shapash/report/project_report.py @@ -1,3 +1,4 @@ +import importlib.metadata import logging import os import sys @@ -219,11 +220,13 @@ def display_model_analysis(self): print_md(f"**Library :** {self.explainer.model.__class__.__module__}") for _, module in sorted(sys.modules.items()): - if ( - hasattr(module, "__version__") - and self.explainer.model.__class__.__module__.split(".")[0] == module.__name__ - ): - print_md(f"**Library version :** {module.__version__}") + module_name = module.__name__.split(".")[0] + if self.explainer.model.__class__.__module__.split(".")[0] == module_name: + try: + version = importlib.metadata.version(module_name) + print_md(f"**Library version :** {version}") + except importlib.metadata.PackageNotFoundError: + print_md(f"**Library version :** not found for {module_name}") print_md("**Model parameters :** ") model_params = self.explainer.model.__dict__ diff --git a/shapash/utils/utils.py b/shapash/utils/utils.py index 0fc4b673..4e3e87f4 100644 --- a/shapash/utils/utils.py +++ b/shapash/utils/utils.py @@ -160,11 +160,16 @@ def compute_digit_number(value): int number of digits """ + if isinstance(value, np.ndarray): + scalar_value = value.item() + else: + scalar_value = value + # fix for 0 value - if value == 0: + if scalar_value == 0: first_nz = 1 else: - first_nz = int(math.log10(abs(value))) + first_nz = int(math.log10(abs(scalar_value))) digit = abs(min(3, first_nz) - 3) return digit diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index cf61aad1..ffd3620a 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -1330,9 +1330,9 @@ def test_plot_line_comparison_1(self): name=f"Id: {index[i]}", hovertext=[ f"Id: {index[i]}
X1
Contribution: {contributions[0][i]:.4f}
" - + f"Value: {data.iloc[i][0]}", + + f"Value: {data.iloc[i,0]}", f"Id: {index[i]}
X2
Contribution: {contributions[1][i]:.4f}
" - + f"Value: {data.iloc[i][1]}", + + f"Value: {data.iloc[i,1]}", ], marker={"color": colors[i]}, )