diff --git a/scikit_posthocs/_plotting.py b/scikit_posthocs/_plotting.py index 622e171..90f4737 100644 --- a/scikit_posthocs/_plotting.py +++ b/scikit_posthocs/_plotting.py @@ -361,6 +361,7 @@ def critical_difference_diagram( marker_props: dict = None, elbow_props: dict = None, crossbar_props: dict = None, + color_palette:Union[Dict[str, str], List] = {}, text_h_margin: float = 0.01) -> Dict[str, list]: """Plot a Critical Difference diagram from ranks and post-hoc results. @@ -428,6 +429,8 @@ def critical_difference_diagram( text_h_margin : float, optional Space between the text labels and the nearest vertical line of an elbow, by default 0.01. + color_palette: dict, optional + Parameters to be passed when you need specific colors for each category Returns ------- @@ -445,6 +448,16 @@ def critical_difference_diagram( .. [2] https://mirkobunse.github.io/CriticalDifferenceDiagrams.jl/stable/ """ + ## check color_palette consistency + if len(color_palette) == 0: + pass + elif isinstance(color_palette, Dict) and ((len(set(ranks.keys()) & set(color_palette.keys())))== len(ranks)): + pass + elif isinstance(color_palette, List) and (len(ranks) <= len(color_palette)) : + pass + else: + raise ValueError("color_palette keys are not consistent, or list size too small") + elbow_props = elbow_props or {} marker_props = {"zorder": 3, **(marker_props or {})} label_props = {"va": "center", **(label_props or {})} @@ -509,15 +522,24 @@ def critical_difference_diagram( lowest_crossbar_ypos = -len(crossbar_levels) - def plot_items(points, xpos, label_fmt, label_props): + def plot_items(points, xpos, label_fmt,color_palette, label_props): """Plot each marker + elbow + label.""" ypos = lowest_crossbar_ypos - 1 - for label, rank in points.items(): - elbow, *_ = ax.plot( - [xpos, rank, rank], - [ypos, ypos, 0], - **elbow_props, - ) + for idx, (label, rank) in enumerate(points.items()): + if len(color_palette) == 0: + elbow, *_ = ax.plot( + [xpos, rank, rank], + [ypos, ypos, 0], + **elbow_props, + ) + else: + elbow, *_ = ax.plot( + [xpos, rank, rank], + [ypos, ypos, 0], + c=color_palette[label] if isinstance(color_palette, Dict) else color_palette[idx], + **elbow_props, + ) + elbows.append(elbow) curr_color = elbow.get_color() markers.append( @@ -537,12 +559,15 @@ def plot_items(points, xpos, label_fmt, label_props): points_left, xpos=points_left.iloc[0] - text_h_margin, label_fmt=label_fmt_left, - label_props={"ha": "right", **label_props}, + color_palette = color_palette, + label_props={"ha": "right", **label_props, + }, ) plot_items( points_right[::-1], xpos=points_right.iloc[-1] + text_h_margin, label_fmt=label_fmt_right, + color_palette = color_palette, label_props={"ha": "left", **label_props}, )