From c79861b8712b11be00112eb490d56ad92d622537 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Thu, 4 Jul 2024 11:06:16 +0200 Subject: [PATCH] Fix color style. --- shapash/explainer/smart_plotter.py | 4 ++-- shapash/style/style_utils.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index d933586c..d9799215 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -3645,7 +3645,7 @@ def _prediction_classification_plot( x=df_correct_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_correct_predict)), y=df_correct_predict["proba_values"].values.flatten(), mode="markers", - marker_color=self._style_dict["prediction_plot"][0], + marker_color=self._style_dict["prediction_plot"][1], showlegend=True, name="Correct Prediction", hovertext=hv_text_correct_predict, @@ -3659,7 +3659,7 @@ def _prediction_classification_plot( x=df_wrong_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_wrong_predict)), y=df_wrong_predict["proba_values"].values.flatten(), mode="markers", - marker_color=self._style_dict["prediction_plot"][1], + marker_color=self._style_dict["prediction_plot"][0], showlegend=True, name="Wrong Prediction", hovertext=hv_text_wrong_predict, diff --git a/shapash/style/style_utils.py b/shapash/style/style_utils.py index 407e98ff..59968fba 100644 --- a/shapash/style/style_utils.py +++ b/shapash/style/style_utils.py @@ -1,6 +1,7 @@ """ functions for loading and manipulating colors """ + import json import os @@ -100,11 +101,11 @@ def define_style(palette): 1: {"color": featureimp_bar[1], "line": {"color": palette["featureimp_line"], "width": 0.5}}, 2: {"color": featureimp_bar[2]}, } - style_dict["featureimp_groups"] = list(palette["featureimp_groups"].values()) + style_dict["featureimp_groups"] = convert_string_to_int_keys(palette["featureimp_groups"]) style_dict["init_contrib_colorscale"] = palette["contrib_colorscale"] style_dict["contrib_distribution"] = palette["contrib_distribution"] - style_dict["violin_area_classif"] = list(palette["violin_area_classif"].values()) - style_dict["prediction_plot"] = list(palette["prediction_plot"].values()) + style_dict["violin_area_classif"] = convert_string_to_int_keys(palette["violin_area_classif"]) + style_dict["prediction_plot"] = convert_string_to_int_keys(palette["prediction_plot"]) style_dict["violin_default"] = palette["violin_default"] style_dict["dict_title_compacity"] = {"font": {"size": 14, "family": "Arial", "color": palette["title_color"]}} style_dict["dict_xaxis"] = {"font": {"size": 16, "family": "Arial Black", "color": palette["axis_color"]}}