diff --git a/shapash/explainer/smart_explainer.py b/shapash/explainer/smart_explainer.py
index 4fbda92d..8d3143e7 100644
--- a/shapash/explainer/smart_explainer.py
+++ b/shapash/explainer/smart_explainer.py
@@ -1191,6 +1191,8 @@ def generate_report(
working_dir=None,
notebook_path=None,
kernel_name=None,
+ max_points=200,
+ nb_top_interactions=5,
):
"""
This method will generate an HTML report containing different information about the project.
@@ -1233,6 +1235,10 @@ def generate_report(
Name of the kernel used to generate the report. This parameter can be useful if
you have multiple Jupyter kernels and the method does not use the right kernel
by default.
+ max_points : int, optional
+ Maximum number of points displayed in the contribution and interaction plots.
+ nb_top_interactions : int, optional
+ Number of top feature interactions to display in the report.
Examples
--------
>>> xpl.generate_report(
@@ -1284,6 +1290,8 @@ def generate_report(
title_story=title_story,
title_description=title_description,
metrics=metrics,
+ max_points=max_points,
+ nb_top_interactions=nb_top_interactions,
),
notebook_path=notebook_path,
kernel_name=kernel_name,
diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py
index 9ec19bd1..708971a9 100644
--- a/shapash/explainer/smart_plotter.py
+++ b/shapash/explainer/smart_plotter.py
@@ -2639,6 +2639,7 @@ def generate_title_dict(col_name1, col_name2, addnote):
def correlations(
self,
df=None,
+ optimized=False,
max_features=20,
features_to_hide=None,
facet_col=None,
@@ -2658,6 +2659,9 @@ def correlations(
----------
df : pd.DataFrame, optional
DataFrame for which we want to compute correlations. Will use x_init by default.
+ optimized : boolean, optional
+ If True, speed up the computation of the correlation matrix by reducing the
+ length of the data and the number of modalities per column.
max_features : int (default: 20)
Max number of features to show on the matrix.
features_to_hide : list (optional)
@@ -2731,7 +2735,17 @@ def cluster_corr(corr, degree, inplace=False):
if df is None:
# Use x_init by default
- df = self.explainer.x_init
+ df = self.explainer.x_init.copy()
+
+ if optimized:
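+ # Limit each categorical column to its 200 most frequent modalities and group the rest as "Other"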
+ categorical_columns = df.select_dtypes(include=["object", "category"]).columns
+
+ for col in categorical_columns:
+ top_categories = df[col].value_counts().nlargest(200).index
+ df[col] = df[col].where(df[col].isin(top_categories), other="Other")
+
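+ # Subsample to at most 10,000 rows to keep the correlation computation tractable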
+ if len(df) > 10000:
+ df = df.sample(n=10000, random_state=1)
if facet_col:
features_to_hide += [facet_col]
@@ -2758,12 +2772,16 @@ def cluster_corr(corr, degree, inplace=False):
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = list(corr.columns)
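+ # Shorten feature names longer than 2*k characters by replacing their middle part with "..."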
+ k = 6
+ list_features_shorten = [
+ x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
+ ]
fig.add_trace(
go.Heatmap(
z=corr.loc[list_features, list_features].round(decimals).values,
- x=list_features,
- y=list_features,
+ x=list_features_shorten,
+ y=list_features_shorten,
coloraxis="coloraxis",
text=[
[
@@ -2784,12 +2802,16 @@ def cluster_corr(corr, degree, inplace=False):
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = [col for col in corr.columns if col in top_features]
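+ # Shorten feature names longer than 2*k characters by replacing their middle part with "..."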
+ k = 6
+ list_features_shorten = [
+ x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
+ ]
fig = go.Figure(
go.Heatmap(
z=corr.loc[list_features, list_features].round(decimals).values,
- x=list_features,
- y=list_features,
+ x=list_features_shorten,
+ y=list_features_shorten,
coloraxis="coloraxis",
text=[
[
diff --git a/shapash/report/html/explainability.html b/shapash/report/html/explainability.html
index 6ef1bb66..51a70292 100644
--- a/shapash/report/html/explainability.html
+++ b/shapash/report/html/explainability.html
@@ -8,14 +8,14 @@
Global feature importance plot
{% for label in labels %}
{{ label['feature_importance_plot'] }}
- {% with menuId='dropdownMenu2', menuText='Feature', values=label['features'], menuDivVisible='explain-contrib-'~label['index'] %}
- {% include "dropdown.html" %}
- {% endwith %}
{% endfor %}
Features contribution plots
{% for label in labels %}
+ {% with menuId='dropdownMenu2', menuText='Feature', values=label['features'], menuDivVisible='explain-contrib-'~label['index'] %}
+ {% include "dropdown.html" %}
+ {% endwith %}
{% for col in label['features'] %}
{{ col['name'] }} - {{ col['type'] }}
@@ -28,3 +28,21 @@ {{ col['name'] }} - {{ col['type'] }}
{% endfor %}
{% endfor %}
+
+ Features Top Interaction plots
+{% for label in labels %}
+
+ {% with menuId='dropdownMenu3', menuText='Interactions', values=label['features_interaction'], menuDivVisible='explain-contrib-interaction-'~label['index'] %}
+ {% include "dropdown.html" %}
+ {% endwith %}
+ {% for col in label['features_interaction'] %}
+
+
+ {{ col['name'] }} - {{ col['type'] }}
+ {% if col['name'] != col['description'] %}
+
+ {{ col['description'] }}
+ {% else %}
+ {% endif %}
+ {{ col['plot'] }}
+
+ {% endfor %}
+
+{% endfor %}
diff --git a/shapash/report/project_report.py b/shapash/report/project_report.py
index a646f9cd..d8835f67 100644
--- a/shapash/report/project_report.py
+++ b/shapash/report/project_report.py
@@ -24,7 +24,7 @@
)
from shapash.utils.io import load_yml
from shapash.utils.transform import apply_postprocessing, handle_categorical_missing, inverse_transform
-from shapash.utils.utils import get_project_root, truncate_str
+from shapash.utils.utils import compute_sorted_variables_interactions_list_indices, get_project_root, truncate_str
from shapash.webapp.utils.utils import round_to_k
logging.basicConfig(level=logging.INFO)
@@ -98,6 +98,16 @@ def __init__(
self.y_train, target_name_train = self._get_values_and_name(y_train, "target")
self.target_name = target_name_train or target_name_test
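+ # Report options: use the values from the config when provided, otherwise fall back to defaults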
+ if "max_points" in self.config.keys():
+ self.max_points = config["max_points"]
+ else:
+ self.max_points = 200
+
+ if "nb_top_interactions" in self.config.keys():
+ self.nb_top_interactions = config["nb_top_interactions"]
+ else:
+ self.nb_top_interactions = 5
+
if "title_story" in self.config.keys():
self.title_story = config["title_story"]
elif self.explainer.title_story != "":
@@ -308,6 +318,7 @@ def display_dataset_analysis(
print_md("### Multivariate analysis")
fig_corr = self.explainer.plot.correlations(
self.df_train_test,
+ optimized=True,
facet_col="data_train_test",
max_features=20,
width=900 if len(self.df_train_test["data_train_test"].unique()) > 1 else 500,
@@ -389,13 +400,16 @@ def display_model_explainability(self):
c_list = self.explainer._classes if multiclass else [1] # list just used for multiclass
for index_label, label in enumerate(c_list): # Iterating over all labels in multiclass case
label_value = self.explainer.check_label_name(label)[2] if multiclass else ""
+
+ # Feature Importance
fig_features_importance = self.explainer.plot.features_importance(label=label)
+ # Contribution Plot
explain_contrib_data = list()
list_cols_labels = [self.explainer.features_dict.get(col, col) for col in self.col_names]
for feature_label in sorted(list_cols_labels):
feature = self.explainer.inv_features_dict.get(feature_label, feature_label)
- fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=200)
+ fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=self.max_points)
# Apparently markers are not supported during conversion into HTML
for el in fig.data:
if el.type == "bar":
@@ -408,6 +422,37 @@ def display_model_explainability(self):
"plot": plotly.io.to_html(fig, include_plotlyjs=False, full_html=False),
}
)
+
+ # Interaction Plot
+ explain_contrib_data_interaction = list()
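+ # Rank feature pairs by interaction strength on a subsample of max_points rows and keep the nb_top_interactions strongest ones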
+ list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
+ selection=None, max_points=self.max_points
+ )
+ interaction_values = self.explainer.get_interaction_values(selection=list_ind)
+ sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
+ indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]
+
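+ # Build one interaction plot per selected pair of features and embed it as HTML in the report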
+ for i, ids in enumerate(indices_to_plot):
+ id0, id1 = ids
+
+ fig_one_interaction = self.explainer.plot.interactions_plot(
+ col1=self.explainer.columns_dict[id0],
+ col2=self.explainer.columns_dict[id1],
+ max_points=self.max_points,
+ )
+
+ explain_contrib_data_interaction.append(
+ {
+ "feature_index": i,
+ "name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
+ "description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
+ + " / "
+ + self.explainer.features_dict[self.explainer.columns_dict[id1]],
+ "plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
+ }
+ )
+
+ # Aggregating the data
explain_data.append(
{
"index": index_label,
@@ -416,6 +461,7 @@ def display_model_explainability(self):
fig_features_importance, include_plotlyjs=False, full_html=False
),
"features": explain_contrib_data,
+ "features_interaction": explain_contrib_data_interaction,
}
)
print_html(explainability_template.render(labels=explain_data))