Skip to content

Commit

Permalink
Merge pull request #607 from guillaume-vignal/feature/fix_title_height
Browse files Browse the repository at this point in the history
Dynamic Title Height Adjustment for Feature Importance Plot
  • Loading branch information
guillaume-vignal authored Oct 25, 2024
2 parents d4f783c + 3764170 commit 9b4931a
Show file tree
Hide file tree
Showing 16 changed files with 56 additions and 31 deletions.
6 changes: 4 additions & 2 deletions shapash/explainer/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.manifold import MDS

from shapash.style.style_utils import colors_loading, define_style, select_palette
from shapash.utils.utils import adjust_title_height


class Consistency:
Expand Down Expand Up @@ -668,10 +669,11 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
auto_open: bool
open automatically the plot
"""
height = max(500, 40 * len(top_features))
title = "Pairwise comparison of Consistency:"
title += "<span style='font-size: 16px;'>\
<br />How are differences in contributions distributed across features?</span>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": yaxis_title}

Expand All @@ -684,7 +686,7 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
yaxis_title=dict_yaxis,
yaxis=dict(range=[-0.7, len(top_features) - 0.3]),
yaxis2=dict(range=[-0.7, len(top_features) - 0.3]),
height=max(500, 40 * len(top_features)),
height=height,
)

fig.update_yaxes(automargin=True, zeroline=False)
Expand Down
2 changes: 1 addition & 1 deletion shapash/explainer/multi_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def summarize(self, s_contribs, var_dicts, xs_sorted, masks, columns_dict, featu
def compute_features_import(self, contributions, norm=1):
"""
Compute a relative features importance, sum of absolute values
\u200b\u200bof the contributions for each
of the contributions for each
features importance compute in base 100
Parameters
Expand Down
2 changes: 1 addition & 1 deletion shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class SmartExplainer:
model: model object
model used to check the different values of target estimate predict proba
features_desc: dict
Dictionary that references the numbers of feature values \u200b\u200bin the x_init
Dictionary that references the numbers of feature values in the x_init
features_imp: pandas.Series (regression) or list (classification)
Features importance values
local_neighbors: dict
Expand Down
6 changes: 4 additions & 2 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from shapash.utils.utils import (
add_line_break,
add_text,
adjust_title_height,
compute_digit_number,
compute_sorted_variables_interactions_list_indices,
maximum_difference_sort_value,
Expand Down Expand Up @@ -1511,9 +1512,10 @@ def ordinal(n):
]
)

height = max(500, 11 * g_df.shape[0] * g_df.shape[1])
title = f"Comparing local explanations in a neighborhood - Id: <b>{index}</b>"
title += "<span style='font-size: 16px;'><br />How similar are explanations for closeby neighbours?</span>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": "Normalized contribution values"}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": ""}

Expand All @@ -1524,7 +1526,7 @@ def ordinal(n):
yaxis_title=dict_yaxis,
hovermode="closest",
barmode="group",
height=max(500, 11 * g_df.shape[0] * g_df.shape[1]),
height=height,
legend={"traceorder": "reversed"},
xaxis={"side": "bottom"},
)
Expand Down
2 changes: 1 addition & 1 deletion shapash/explainer/smart_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def summarize(self, s_contrib, var_dict, x_sorted, mask, columns_dict, features_
def compute_features_import(self, contributions, norm=1):
"""
Compute a relative features importance, sum of absolute values
\u200b\u200bof the contributions for each
of the contributions for each
features importance compute in base 100
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion shapash/manipulation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def summarize_el(dataframe, mask, prefix):
def compute_features_import(dataframe, norm=1):
"""
Compute a relative features importance, sum of absolute values
\u200b\u200bof the contributions for each
of the contributions for each
features importance compute in base 100
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_bar_chart.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.utils.utils import add_line_break, truncate_str
from shapash.utils.utils import add_line_break, adjust_title_height, truncate_str


def plot_bar_chart(
Expand Down Expand Up @@ -83,7 +83,7 @@ def plot_bar_chart(
if subtitle:
title += "<br><sup>" + subtitle + "</sup>"
topmargin += 15
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_local_plot_colors = style_dict["dict_local_plot_colors"] | {"text": None}
Expand Down
4 changes: 3 additions & 1 deletion shapash/plots/plot_compacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from plotly.offline import plot
from plotly.subplots import make_subplots

from shapash.utils.utils import adjust_title_height


def plot_compacity(
features_needed, distance_reached, style_dict, approx=0.9, nb_features=5, file_name=None, auto_open=False
Expand Down Expand Up @@ -104,7 +106,7 @@ def plot_compacity(
title += (
"<span style='font-size: 16px;'><br />How many variables are enough to produce accurate explanations?</span>"
)
dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height()}

fig.update_layout(
template="none",
Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_contribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.utils.utils import add_line_break, truncate_str
from shapash.utils.utils import add_line_break, adjust_title_height, truncate_str
from shapash.webapp.utils.utils import round_to_k


Expand Down Expand Up @@ -493,7 +493,7 @@ def _update_contributions_fig(
title += "<br><sup>" + subtitle + "</sup>"
else:
title += "<br><sup>" + addnote + "</sup>"
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str(feature_name, 110)}
dict_yaxis = style_dict["dict_yaxis"] | {"text": "Contribution"}

Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from plotly.subplots import make_subplots

from shapash.manipulation.summarize import compute_corr
from shapash.utils.utils import compute_top_correlations_features, suffix_duplicates
from shapash.utils.utils import adjust_title_height, compute_top_correlations_features, suffix_duplicates


def plot_correlations(
Expand Down Expand Up @@ -203,7 +203,7 @@ def prepare_corr_matrix(df_subset):
if len(list_features) < len(df.drop(features_to_hide, axis=1).columns):
subtitle = f"Top {len(list_features)} correlations"
title += f"<span style='font-size: 12px;'><br />{subtitle}</span>"
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}

fig.update_layout(
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict["init_contrib_colorscale"][5:-1]),
Expand Down
7 changes: 4 additions & 3 deletions shapash/plots/plot_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from plotly.offline import plot

from shapash.style.style_utils import get_pyplot_color
from shapash.utils.utils import adjust_title_height


def plot_feature_importance(
Expand Down Expand Up @@ -234,7 +235,7 @@ def _plot_features_import(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_style_bar1 = style_dict["dict_featimp_colors"][1]
Expand Down Expand Up @@ -357,7 +358,7 @@ def _plot_local_features_import(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": "Mean absolute Contribution"}
dict_yaxis = style_dict["dict_yaxis"] | {"text": None}
dict_style_bar = {}
Expand Down Expand Up @@ -537,7 +538,7 @@ def _plot_feature_contributions_cumulative(
else:
title += "<br><sup>" + addnote + "</sup>"
topmargin = topmargin + 15
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}

if (isinstance(lst_feat[0], str)) & (not zoom):
# change index to abc...abc if its length is upper than 30
Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.utils.utils import add_text, truncate_str
from shapash.utils.utils import add_text, adjust_title_height, truncate_str


def plot_interactions_scatter(x_name, y_name, col_name, x_values, y_values, col_values, col_scale, style_dict):
Expand Down Expand Up @@ -160,7 +160,7 @@ def update_interactions_fig(fig, col_name1, col_name2, addnote, width, height, f
title = f"<b>{truncate_str(col_name1)} and {truncate_str(col_name2)}</b> shap interaction values"
if addnote:
title += f"<span style='font-size: 12px;'><br />{add_text([addnote], sep=' - ')}</span>"
dict_t = style_dict["dict_title"] | {"text": title}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str(col_name1, 110)}
dict_yaxis = style_dict["dict_yaxis"] | {"text": "Shap interaction value"}

Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_line_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.utils.utils import add_line_break, truncate_str
from shapash.utils.utils import add_line_break, adjust_title_height, truncate_str


def plot_line_comparison(
Expand Down Expand Up @@ -67,7 +67,7 @@ def plot_line_comparison(
else:
title = "Compare plot - index : " + " ; ".join(["<b>" + str(id) + "</b>" for id in index])
dict_xaxis["text"] = "Contributions"
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}

if subtitle is not None:
topmargin += 15 * height / 275
Expand Down
4 changes: 2 additions & 2 deletions shapash/plots/plot_scatter_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from plotly.offline import plot

from shapash.utils.sampling import subset_sampling
from shapash.utils.utils import truncate_str, tuning_colorscale
from shapash.utils.utils import adjust_title_height, truncate_str, tuning_colorscale


def plot_scatter_prediction(
Expand Down Expand Up @@ -116,7 +116,7 @@ def plot_scatter_prediction(
title += "<br><sup>" + subtitle + "</sup>"
else:
title += "<br><sup>" + addnote + "</sup>"
dict_t = style_dict["dict_title"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": truncate_str("True Values", 110)}
dict_yaxis = style_dict["dict_yaxis"] | {"text": "Predicted Values"}

Expand Down
14 changes: 7 additions & 7 deletions shapash/plots/plot_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.utils.utils import tuning_colorscale
from shapash.utils.utils import adjust_title_height, tuning_colorscale


def plot_stability_distribution(
Expand Down Expand Up @@ -124,10 +124,6 @@ def plot_stability_distribution(

fig.add_trace(colorbar_trace)

fig.update_layout(
height=height_value,
)

_update_stability_fig(
fig=fig,
x_barlen=len(mean_amplitude),
Expand All @@ -137,12 +133,13 @@ def plot_stability_distribution(
yaxis_title=yaxis_title,
file_name=file_name,
auto_open=auto_open,
height=height_value,
)

return fig


def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_title, file_name, auto_open):
def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_title, file_name, auto_open, height=500):
"""
Function used for the `plot_stability_distribution` and `plot_amplitude_vs_stability`
to update the layout of the plotly figure.
Expand All @@ -165,14 +162,16 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool (default=False)
open automatically the plot
height: int
Plotly figure - layout height
Returns
-------
go.Figure
"""
title = "Importance & Local Stability of explanations:"
title += "<span style='font-size: 16px;'><br />How similar are explanations for closeby neighbours?</span>"
dict_t = style_dict["dict_title_stability"] | {"text": title, "yref": "paper"}
dict_t = style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}

dict_xaxis = style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = style_dict["dict_yaxis"] | {"text": yaxis_title}
Expand Down Expand Up @@ -206,6 +205,7 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
yaxis_title=dict_yaxis,
coloraxis_showscale=False,
hovermode="closest",
height=height,
)

fig.update_yaxes(automargin=True)
Expand Down
18 changes: 18 additions & 0 deletions shapash/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@
from shapash.explainer.smart_state import SmartState


def adjust_title_height(figure_height=500):
"""
Adjust the height of the title according to height of the figure
Parameters
----------
figure_height : int
height of the figure
Returns
-------
int
height of the title
"""

return 1 - 0.1 * 500 / figure_height


def suffix_duplicates(lst):
"""
Adds suffixes (_2, _3, ...) to non-unique elements in a list to make them unique.
Expand Down

0 comments on commit 9b4931a

Please sign in to comment.