From 044b6ea02583012cc1b9a71efa67f74f40438d58 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Thu, 24 Oct 2024 11:01:45 +0200 Subject: [PATCH] fix bugs of the 2.7.2 --- pyproject.toml | 2 +- shapash/explainer/consistency.py | 12 +++---- shapash/explainer/smart_plotter.py | 2 +- shapash/plots/plot_bar_chart.py | 2 +- shapash/plots/plot_compacity.py | 2 +- shapash/plots/plot_contribution.py | 8 +++-- shapash/plots/plot_correlations.py | 41 ++++++++++++------------ shapash/plots/plot_feature_importance.py | 6 ++-- shapash/plots/plot_line_comparison.py | 2 +- shapash/plots/plot_scatter_prediction.py | 2 +- shapash/plots/plot_stability.py | 2 +- shapash/utils/utils.py | 33 +++++++++++++++++++ 12 files changed, 72 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf2c32a0..a277a1d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ where = ["."] [tool.setuptools.package-data] -"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png", "*.ico"] +"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png", "*.ico", "*.ipynb", "*.html", "*.j2"] [tool.pytest.ini_options] pythonpath = ["."] diff --git a/shapash/explainer/consistency.py b/shapash/explainer/consistency.py index 27bdc7c8..8c5a9ead 100644 --- a/shapash/explainer/consistency.py +++ b/shapash/explainer/consistency.py @@ -1,4 +1,3 @@ -import copy import itertools import matplotlib.pyplot as plt @@ -671,13 +670,10 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis """ title = "Pairwise comparison of Consistency:" title += "\ -
How are differences in contributions distributed across features?
" - dict_t = copy.deepcopy(self._style_dict["dict_title_stability"]) - dict_xaxis = copy.deepcopy(self._style_dict["dict_xaxis"]) - dict_yaxis = copy.deepcopy(self._style_dict["dict_yaxis"]) - dict_xaxis["text"] = xaxis_title - dict_yaxis["text"] = yaxis_title - dict_t["text"] = title +
How are differences in contributions distributed across features?" + dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"} + dict_xaxis = self._style_dict["dict_xaxis"] | {"text": xaxis_title} + dict_yaxis = self._style_dict["dict_yaxis"] | {"text": yaxis_title} fig.layout.yaxis.update(showticklabels=True) fig.layout.yaxis2.update(showticklabels=False) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 695c9e81..d4e8607e 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -1513,7 +1513,7 @@ def ordinal(n): title = f"Comparing local explanations in a neighborhood - Id: {index}" title += "
How similar are explanations for closeby neighbours?
" - dict_t = self._style_dict["dict_title_stability"] | {"text": title} + dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"} dict_xaxis = self._style_dict["dict_xaxis"] | {"text": "Normalized contribution values"} dict_yaxis = self._style_dict["dict_yaxis"] | {"text": ""} diff --git a/shapash/plots/plot_bar_chart.py b/shapash/plots/plot_bar_chart.py index 3de6c272..ec7ba9fd 100644 --- a/shapash/plots/plot_bar_chart.py +++ b/shapash/plots/plot_bar_chart.py @@ -83,7 +83,7 @@ def plot_bar_chart( if subtitle: title += "
" + subtitle + "" topmargin += 15 - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": "Contribution"} dict_yaxis = style_dict["dict_yaxis"] | {"text": None} dict_local_plot_colors = style_dict["dict_local_plot_colors"] | {"text": None} diff --git a/shapash/plots/plot_compacity.py b/shapash/plots/plot_compacity.py index bf3836ed..f5e2389f 100644 --- a/shapash/plots/plot_compacity.py +++ b/shapash/plots/plot_compacity.py @@ -104,7 +104,7 @@ def plot_compacity( title += ( "
How many variables are enough to produce accurate explanations?
" ) - dict_t = style_dict["dict_title_stability"] | {"text": title} + dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"} fig.update_layout( template="none", diff --git a/shapash/plots/plot_contribution.py b/shapash/plots/plot_contribution.py index 492d71b4..466ce606 100644 --- a/shapash/plots/plot_contribution.py +++ b/shapash/plots/plot_contribution.py @@ -493,7 +493,7 @@ def _update_contributions_fig( title += "
" + subtitle + "" else: title += "
" + addnote + "" - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str(feature_name, 110)} dict_yaxis = style_dict["dict_yaxis"] | {"text": "Contribution"} @@ -571,13 +571,15 @@ def _update_xaxis_labels(fig, xs, zoom=False): k = 6 # Shorten labels that exceed the threshold - feature_val = [x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in feature_val] + feature_val = [ + x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x for x in feature_val + ] else: k = 10 feature_val = [] for feature_name in xs: feature_name_splited = [ - x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x + x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x for x in feature_name.split("
") ] feature_val_name = "
".join(feature_name_splited) diff --git a/shapash/plots/plot_correlations.py b/shapash/plots/plot_correlations.py index e716fcfc..5222de49 100644 --- a/shapash/plots/plot_correlations.py +++ b/shapash/plots/plot_correlations.py @@ -6,7 +6,7 @@ from plotly.subplots import make_subplots from shapash.manipulation.summarize import compute_corr -from shapash.utils.utils import compute_top_correlations_features +from shapash.utils.utils import compute_top_correlations_features, suffix_duplicates def plot_correlations( @@ -108,6 +108,21 @@ def cluster_corr(corr, degree, inplace=False): return corr[idx, :][:, idx] + # Function to compute correlation matrix and prepare top features + def prepare_corr_matrix(df_subset): + corr = compute_corr(df_subset.drop(features_to_hide, axis=1), compute_method) + top_features = compute_top_correlations_features(corr=corr, max_features=max_features) + corr = cluster_corr(corr.loc[top_features, top_features], degree=degree) + list_features = [col for col in corr.columns if col in top_features] + + # Shorten long feature names and handle duplicates + k = 12 + list_features_shorten = [ + x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x for x in list_features + ] + list_features_shorten = suffix_duplicates(list_features_shorten) + return corr, list_features, list_features_shorten + if features_dict is None: features_dict = {} @@ -142,17 +157,8 @@ def cluster_corr(corr, degree, inplace=False): ) # Used for the Shapash report to get train then test set for i, col_v in enumerate(facet_col_values): - corr = compute_corr(df.loc[df[facet_col] == col_v].drop(features_to_hide, axis=1), compute_method) - - # Keep the same list of features for each subplot - if len(list_features) == 0: - top_features = compute_top_correlations_features(corr=corr, max_features=max_features) - corr = cluster_corr(corr.loc[top_features, top_features], degree=degree) - list_features = list(corr.columns) - k = 6 - list_features_shorten = [ - x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features - ] + df_subset = df[df[facet_col] == col_v] + corr, list_features, list_features_shorten = prepare_corr_matrix(df_subset) fig.add_trace( go.Heatmap( @@ -174,14 +180,7 @@ def cluster_corr(corr, degree, inplace=False): ) else: - corr = compute_corr(df.drop(features_to_hide, axis=1), compute_method) - top_features = compute_top_correlations_features(corr=corr, max_features=max_features) - corr = cluster_corr(corr.loc[top_features, top_features], degree=degree) - list_features = [col for col in corr.columns if col in top_features] - k = 6 - list_features_shorten = [ - x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features - ] + corr, list_features, list_features_shorten = prepare_corr_matrix(df) fig = go.Figure( go.Heatmap( @@ -204,7 +203,7 @@ def cluster_corr(corr, degree, inplace=False): if len(list_features) < len(df.drop(features_to_hide, axis=1).columns): subtitle = f"Top {len(list_features)} correlations" title += f"
{subtitle}
" - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} fig.update_layout( coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict["init_contrib_colorscale"][5:-1]), diff --git a/shapash/plots/plot_feature_importance.py b/shapash/plots/plot_feature_importance.py index 7628b840..cda16b4b 100644 --- a/shapash/plots/plot_feature_importance.py +++ b/shapash/plots/plot_feature_importance.py @@ -234,7 +234,7 @@ def _plot_features_import( else: title += "
" + addnote + "" topmargin = topmargin + 15 - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"} dict_yaxis = style_dict["dict_yaxis"] | {"text": None} dict_style_bar1 = style_dict["dict_featimp_colors"][1] @@ -357,7 +357,7 @@ def _plot_local_features_import( else: title += "
" + addnote + "" topmargin = topmargin + 15 - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"} dict_yaxis = style_dict["dict_yaxis"] | {"text": None} dict_style_bar = {} @@ -537,7 +537,7 @@ def _plot_feature_contributions_cumulative( else: title += "
" + addnote + "" topmargin = topmargin + 15 - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} if (isinstance(lst_feat[0], str)) & (not zoom): # change index to abc...abc if its length is upper than 30 diff --git a/shapash/plots/plot_line_comparison.py b/shapash/plots/plot_line_comparison.py index 7d4c7dde..c6712d6a 100644 --- a/shapash/plots/plot_line_comparison.py +++ b/shapash/plots/plot_line_comparison.py @@ -67,7 +67,7 @@ def plot_line_comparison( else: title = "Compare plot - index : " + " ; ".join(["" + str(id) + "" for id in index]) dict_xaxis["text"] = "Contributions" - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} if subtitle is not None: topmargin += 15 * height / 275 diff --git a/shapash/plots/plot_scatter_prediction.py b/shapash/plots/plot_scatter_prediction.py index 65849a92..18a3e8f1 100644 --- a/shapash/plots/plot_scatter_prediction.py +++ b/shapash/plots/plot_scatter_prediction.py @@ -116,7 +116,7 @@ def plot_scatter_prediction( title += "
" + subtitle + "" else: title += "
" + addnote + "" - dict_t = style_dict["dict_title"] | {"text": title} + dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str("True Values", 110)} dict_yaxis = style_dict["dict_yaxis"] | {"text": "Predicted Values"} diff --git a/shapash/plots/plot_stability.py b/shapash/plots/plot_stability.py index b46a8eb0..77930c96 100644 --- a/shapash/plots/plot_stability.py +++ b/shapash/plots/plot_stability.py @@ -172,7 +172,7 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t """ title = "Importance & Local Stability of explanations:" title += "
How similar are explanations for closeby neighbours?
" - dict_t = style_dict["dict_title_stability"] | {"text": title} + dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"} dict_xaxis = style_dict["dict_xaxis"] | {"text": xaxis_title} dict_yaxis = style_dict["dict_yaxis"] | {"text": yaxis_title} diff --git a/shapash/utils/utils.py b/shapash/utils/utils.py index b103b3d7..5f34c04d 100644 --- a/shapash/utils/utils.py +++ b/shapash/utils/utils.py @@ -12,6 +12,39 @@ from shapash.explainer.smart_state import SmartState +def suffix_duplicates(lst): + """ + Adds suffixes (_2, _3, ...) to non-unique elements in a list to make them unique. + + Args: + lst (list): The input list of elements (strings) which may contain duplicates. + + Returns: + list: A new list where non-unique elements have suffixes to ensure uniqueness. + + Example: + Input: ["feature1", "feature2", "feature1", "feature2", "feature3"] + Output: ["feature1", "feature2", "feature1_2", "feature2_2", "feature3"] + """ + + seen = {} + result = [] + + for item in lst: + if item in seen: + # If the item has been seen before, increment its count and add a suffix + seen[item] += 1 + new_item = f"{item}_{seen[item] + 1}" + else: + # If the item is seen for the first time, add it without a suffix + seen[item] = 0 + new_item = item + + result.append(new_item) + + return result + + def get_host_name(): """ Get the url of the current host