Skip to content

Commit

Permalink
fix bugs of the 2.7.2
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Oct 24, 2024
1 parent c8fd7bd commit 044b6ea
Show file tree
Hide file tree
Showing 12 changed files with 72 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ where = ["."]


[tool.setuptools.package-data]
"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png", "*.ico"]
"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png", "*.ico", "*.ipynb", "*.html", "*.j2"]

[tool.pytest.ini_options]
pythonpath = ["."]
Expand Down
12 changes: 4 additions & 8 deletions shapash/explainer/consistency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import itertools

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -671,13 +670,10 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
"""
title = "Pairwise comparison of Consistency:"
title += "<span style='font-size: 16px;'>\
<br />How are differences in contributions distributed across features?</span>"
dict_t = copy.deepcopy(self._style_dict["dict_title_stability"])
dict_xaxis = copy.deepcopy(self._style_dict["dict_xaxis"])
dict_yaxis = copy.deepcopy(self._style_dict["dict_yaxis"])
dict_xaxis["text"] = xaxis_title
dict_yaxis["text"] = yaxis_title
dict_t["text"] = title
<br />How are differences in contributions distributed across features?</span>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": yaxis_title}

fig.layout.yaxis.update(showticklabels=True)
fig.layout.yaxis2.update(showticklabels=False)
Expand Down
2 changes: 1 addition & 1 deletion shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ def ordinal(n):

title = f"Comparing local explanations in a neighborhood - Id: <b>{index}</b>"
title += "<span style='font-size: 16px;'><br />How similar are explanations for closeby neighbours?</span>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title}
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": "Normalized contribution values"}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": ""}

Expand Down
2 changes: 1 addition & 1 deletion shapash/plots/plot_bar_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def plot_bar_chart(
if subtitle:
title += "<br><sup>" + subtitle + "</sup>"
topmargin += 15
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_local_plot_colors = style_dict["dict_local_plot_colors"] | {"text": None}
Expand Down
2 changes: 1 addition & 1 deletion shapash/plots/plot_compacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def plot_compacity(
title += (
"<span style='font-size: 16px;'><br />How many variables are enough to produce accurate explanations?</span>"
)
dict_t = style_dict["dict_title_stability"] | {"text": title}
dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}

fig.update_layout(
template="none",
Expand Down
8 changes: 5 additions & 3 deletions shapash/plots/plot_contribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def _update_contributions_fig(
title += "<br><sup>" + subtitle + "</sup>"
else:
title += "<br><sup>" + addnote + "</sup>"
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str(feature_name, 110)}
dict_yaxis = style_dict["dict_yaxis"] | {"text": "Contribution"}

Expand Down Expand Up @@ -571,13 +571,15 @@ def _update_xaxis_labels(fig, xs, zoom=False):
k = 6

# Shorten labels that exceed the threshold
feature_val = [x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in feature_val]
feature_val = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x for x in feature_val
]
else:
k = 10
feature_val = []
for feature_name in xs:
feature_name_splited = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x
for x in feature_name.split("<br />")
]
feature_val_name = "<br />".join(feature_name_splited)
Expand Down
41 changes: 20 additions & 21 deletions shapash/plots/plot_correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from plotly.subplots import make_subplots

from shapash.manipulation.summarize import compute_corr
from shapash.utils.utils import compute_top_correlations_features
from shapash.utils.utils import compute_top_correlations_features, suffix_duplicates


def plot_correlations(
Expand Down Expand Up @@ -108,6 +108,21 @@ def cluster_corr(corr, degree, inplace=False):

return corr[idx, :][:, idx]

# Function to compute correlation matrix and prepare top features
def prepare_corr_matrix(df_subset):
corr = compute_corr(df_subset.drop(features_to_hide, axis=1), compute_method)
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = [col for col in corr.columns if col in top_features]

# Shorten long feature names and handle duplicates
k = 12
list_features_shorten = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k + 3 else x for x in list_features
]
list_features_shorten = suffix_duplicates(list_features_shorten)
return corr, list_features, list_features_shorten

if features_dict is None:
features_dict = {}

Expand Down Expand Up @@ -142,17 +157,8 @@ def cluster_corr(corr, degree, inplace=False):
)
# Used for the Shapash report to get train then test set
for i, col_v in enumerate(facet_col_values):
corr = compute_corr(df.loc[df[facet_col] == col_v].drop(features_to_hide, axis=1), compute_method)

# Keep the same list of features for each subplot
if len(list_features) == 0:
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = list(corr.columns)
k = 6
list_features_shorten = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
]
df_subset = df[df[facet_col] == col_v]
corr, list_features, list_features_shorten = prepare_corr_matrix(df_subset)

fig.add_trace(
go.Heatmap(
Expand All @@ -174,14 +180,7 @@ def cluster_corr(corr, degree, inplace=False):
)

else:
corr = compute_corr(df.drop(features_to_hide, axis=1), compute_method)
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = [col for col in corr.columns if col in top_features]
k = 6
list_features_shorten = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
]
corr, list_features, list_features_shorten = prepare_corr_matrix(df)

fig = go.Figure(
go.Heatmap(
Expand All @@ -204,7 +203,7 @@ def cluster_corr(corr, degree, inplace=False):
if len(list_features) < len(df.drop(features_to_hide, axis=1).columns):
subtitle = f"Top {len(list_features)} correlations"
title += f"<span style='font-size: 12px;'><br />{subtitle}</span>"
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}

fig.update_layout(
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict["init_contrib_colorscale"][5:-1]),
Expand Down
6 changes: 3 additions & 3 deletions shapash/plots/plot_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _plot_features_import(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_style_bar1 = style_dict["dict_featimp_colors"][1]
Expand Down Expand Up @@ -357,7 +357,7 @@ def _plot_local_features_import(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_style_bar = {}
Expand Down Expand Up @@ -537,7 +537,7 @@ def _plot_feature_contributions_cumulative(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}

if (isinstance(lst_feat[0], str)) & (not zoom):
# change index to abc...abc if its length is upper than 30
Expand Down
2 changes: 1 addition & 1 deletion shapash/plots/plot_line_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def plot_line_comparison(
else:
title = "Compare plot - index : " + " ; ".join(["<b>" + str(id) + "</b>" for id in index])
dict_xaxis["text"] = "Contributions"
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}

if subtitle is not None:
topmargin += 15 * height / 275
Expand Down
2 changes: 1 addition & 1 deletion shapash/plots/plot_scatter_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def plot_scatter_prediction(
title += "<br><sup>" + subtitle + "</sup>"
else:
title += "<br><sup>" + addnote + "</sup>"
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str("True Values", 110)}
dict_yaxis = style_dict["dict_yaxis"] | {"text": "Predicted Values"}

Expand Down
2 changes: 1 addition & 1 deletion shapash/plots/plot_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
"""
title = "Importance & Local Stability of explanations:"
title += "<span style='font-size: 16px;'><br />How similar are explanations for closeby neighbours?</span>"
dict_t = style_dict["dict_title_stability"] | {"text": title}
dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}

dict_xaxis = style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = style_dict["dict_yaxis"] | {"text": yaxis_title}
Expand Down
33 changes: 33 additions & 0 deletions shapash/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,39 @@
from shapash.explainer.smart_state import SmartState


def suffix_duplicates(lst):
"""
Adds suffixes (_2, _3, ...) to non-unique elements in a list to make them unique.
Args:
lst (list): The input list of elements (strings) which may contain duplicates.
Returns:
list: A new list where non-unique elements have suffixes to ensure uniqueness.
Example:
Input: ["feature1", "feature2", "feature1", "feature2", "feature3"]
Output: ["feature1", "feature2", "feature1_2", "feature2_2", "feature3"]
"""

seen = {}
result = []

for item in lst:
if item in seen:
# If the item has been seen before, increment its count and add a suffix
seen[item] += 1
new_item = f"{item}_{seen[item] + 1}"
else:
# If the item is seen for the first time, add it without a suffix
seen[item] = 0
new_item = item

result.append(new_item)

return result


def get_host_name():
"""
Get the url of the current host
Expand Down

0 comments on commit 044b6ea

Please sign in to comment.