From d3a6b90eb8b80cc5be15fbdc8df1783fdec9db3b Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Wed, 5 Jun 2024 01:15:16 +0200 Subject: [PATCH] Legend grouping in bands --- .../viz_module/showcase/BandsPlot.ipynb | 30 ++++++- src/sisl/viz/figure/matplotlib.py | 21 +++++ src/sisl/viz/figure/plotly.py | 4 +- src/sisl/viz/plots/bands.py | 90 +++++++++++++------ src/sisl/viz/plots/pdos.py | 1 + src/sisl/viz/plotters/xarray.py | 34 +++---- src/sisl/viz/processors/bands.py | 39 +++++++- src/sisl/viz/processors/xarray.py | 3 - 8 files changed, 163 insertions(+), 59 deletions(-) diff --git a/docs/visualization/viz_module/showcase/BandsPlot.ipynb b/docs/visualization/viz_module/showcase/BandsPlot.ipynb index 779f35ae6..74ba7c7fa 100644 --- a/docs/visualization/viz_module/showcase/BandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/BandsPlot.ipynb @@ -149,6 +149,34 @@ "Notice that in spin polarized bands, **you can select the spins to display using the `spin` setting**, just pass a list of spin components (e.g. `spin=[0]`)." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Individual bands in legend\n", + "--------------------------\n", + "\n", + "Usually, showing all bands individually in the legend would be too messy. However, you might want to do it so that you can interactively hide show certain bands. If that is the case, you can set `group_legend` to `False`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bands_plot.update_inputs(group_legend=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bands_plot = bands_plot.update_inputs(group_legend=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -594,7 +622,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/src/sisl/viz/figure/matplotlib.py b/src/sisl/viz/figure/matplotlib.py index 2a5c4b9bb..0297dbc20 100644 --- a/src/sisl/viz/figure/matplotlib.py +++ b/src/sisl/viz/figure/matplotlib.py @@ -248,6 +248,7 @@ def draw_line( line={}, marker={}, text=None, + showlegend=True, row=None, col=None, _axes=None, @@ -258,6 +259,10 @@ def draw_line( axes = _axes or self._get_subplot_axes(row=row, col=col) + # Matplotlib doesn't show lines on the legend if their name starts + # with an underscore, so prepend the name with "_" if showlegend is False. + name = name if showlegend else f"_{name}" + return axes.plot( x, y, @@ -279,6 +284,7 @@ def draw_multicolor_line( line={}, marker={}, text=None, + showlegend=True, row=None, col=None, _axes=None, @@ -289,6 +295,10 @@ def draw_multicolor_line( color = line.get("color") + # Matplotlib doesn't show lines on the legend if their name starts + # with an underscore, so prepend the name with "_" if showlegend is False. + name = name if showlegend else f"_{name}" + if not np.issubdtype(np.array(color).dtype, np.number): return self.draw_multicolor_scatter( x, @@ -358,6 +368,7 @@ def draw_area_line( y, line={}, name=None, + showlegend=True, dependent_axis=None, row=None, col=None, @@ -371,6 +382,10 @@ def draw_area_line( axes = _axes or self._get_subplot_axes(row=row, col=col) + # Matplotlib doesn't show lines on the legend if their name starts + # with an underscore, so prepend the name with "_" if showlegend is False. + name = name if showlegend else f"_{name}" + if dependent_axis in ("y", None): axes.fill_between( x, y + spacing, y - spacing, color=line.get("color"), label=name @@ -392,6 +407,7 @@ def draw_scatter( marker={}, text=None, zorder=2, + showlegend=True, row=None, col=None, _axes=None, @@ -399,6 +415,11 @@ def draw_scatter( **kwargs, ): axes = _axes or self._get_subplot_axes(row=row, col=col) + + # Matplotlib doesn't show lines on the legend if their name starts + # with an underscore, so prepend the name with "_" if showlegend is False. + name = name if showlegend else f"_{name}" + try: return axes.scatter( x, diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 686889ff8..04a7ec512 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -643,10 +643,10 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, **kwargs): size=sp_size, color=sp_color, opacity=sp_opacity, - name=f"{name}_{i}", + name=name, legendgroup=name, showlegend=showlegend, - meta=meta, + meta={**meta, f"{name}_i": i}, ) showlegend = False diff --git a/src/sisl/viz/plots/bands.py b/src/sisl/viz/plots/bands.py index 209c8bea2..0bc138870 100644 --- a/src/sisl/viz/plots/bands.py +++ b/src/sisl/viz/plots/bands.py @@ -22,6 +22,34 @@ from .orbital_groups_plot import OrbitalGroupsPlot +def _default_random_color(x): + return x.get("color") or random_color() + + +def _group_traces(actions, group_legend: bool = True): + + if not group_legend: + return actions + + seen_groups = [] + + new_actions = [] + for action in actions: + if action["method"].startswith("draw_"): + group = action["kwargs"].get("name") + action = action.copy() + action["kwargs"]["legendgroup"] = group + + if group in seen_groups: + action["kwargs"]["showlegend"] = False + else: + seen_groups.append(group) + + new_actions.append(action) + + return new_actions + + def bands_plot( bands_data: BandsData, Erange: Optional[Tuple[float, float]] = None, @@ -44,6 +72,7 @@ def bands_plot( direct_gaps_only: bool = False, custom_gaps: Sequence[Dict] = [], line_mode: Literal["line", "scatter", "area_line"] = "line", + group_legend: bool = True, backend: str = "plotly", ) -> Figure: """Plots band structure energies, with plentiful of customization options. @@ -87,6 +116,10 @@ def bands_plot( List of custom gaps to display. See the showcase notebooks for examples. line_mode: The method used to draw the band lines. + group_legend: + Whether to group all bands in the legend to show a single legend item. + + If the bands are spin polarized, bands are grouped by spin channel. backend: The backend to use to generate the figure. """ @@ -95,12 +128,19 @@ def bands_plot( # Filter the bands filtered_bands = filter_bands( - bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin + bands_data, + Erange=Erange, + E0=E0, + bands_range=bands_range, + spin=spin, ) # Add the styles styled_bands = style_bands( - filtered_bands, bands_style=bands_style, spindown_style=spindown_style + filtered_bands, + bands_style=bands_style, + spindown_style=spindown_style, + group_legend=group_legend, ) # Determine what goes on each axis @@ -114,9 +154,11 @@ def bands_plot( y=y, set_axrange=True, what=line_mode, + name="line_name", colorscale=colorscale, dependent_axis=E_axis, ) + grouped_bands_plottings = _group_traces(bands_plottings, group_legend=group_legend) # Gap calculation gap_info = calculate_gap(filtered_bands) @@ -133,35 +175,13 @@ def bands_plot( E_axis=E_axis, ) - all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None) + all_plottings = combined( + grouped_bands_plottings, gaps_plottings, composite_method=None + ) return get_figure(backend=backend, plot_actions=all_plottings) -def _default_random_color(x): - return x.get("color") or random_color() - - -def _group_traces(actions): - seen_groups = [] - - new_actions = [] - for action in actions: - if action["method"].startswith("draw_"): - group = action["kwargs"].get("name") - action = action.copy() - action["kwargs"]["legendgroup"] = group - - if group in seen_groups: - action["kwargs"]["showlegend"] = False - else: - seen_groups.append(group) - - new_actions.append(action) - - return new_actions - - # I keep the fatbands plot here so that one can see how similar they are. # I am yet to find a nice solution for extending workflows. def fatbands_plot( @@ -180,6 +200,7 @@ def fatbands_plot( direct_gaps_only: bool = False, custom_gaps: Sequence[Dict] = [], bands_mode: Literal["line", "scatter", "area_line"] = "line", + bands_group_legend: bool = True, # Fatbands inputs groups: OrbitalQueries = [], fatbands_var: str = "norm2", @@ -225,6 +246,10 @@ def fatbands_plot( List of custom gaps to display. See the showcase notebooks for examples. bands_mode: The method used to draw the band lines. + bands_group_legend: + Whether to group all bands in the legend to show a single legend item. + + If the bands are spin polarized, bands are grouped by spin channel. groups: Orbital groups to plots. See showcase notebook for examples. fatbands_var: @@ -246,7 +271,10 @@ def fatbands_plot( # Add the styles styled_bands = style_bands( - filtered_bands, bands_style=bands_style, spindown_style=spindown_style + filtered_bands, + bands_style=bands_style, + spindown_style=spindown_style, + group_legend=bands_group_legend, ) # Process fatbands @@ -299,8 +327,12 @@ def fatbands_plot( y=y, set_axrange=True, what=bands_mode, + name="line_name", dependent_axis=E_axis, ) + grouped_bands_plottings = _group_traces( + bands_plottings, group_legend=bands_group_legend + ) # Gap calculation gap_info = calculate_gap(filtered_bands) @@ -319,7 +351,7 @@ def fatbands_plot( all_plottings = combined( grouped_fatbands_plottings, - bands_plottings, + grouped_bands_plottings, gaps_plottings, composite_method=None, ) diff --git a/src/sisl/viz/plots/pdos.py b/src/sisl/viz/plots/pdos.py index a2e7ceb41..834d2d583 100644 --- a/src/sisl/viz/plots/pdos.py +++ b/src/sisl/viz/plots/pdos.py @@ -84,6 +84,7 @@ def pdos_plot( x=x, y=y, width="size", + name="group", what=line_mode, dependent_axis=dependent_axis, ) diff --git a/src/sisl/viz/plotters/xarray.py b/src/sisl/viz/plotters/xarray.py index 02088536b..b56a1ee73 100644 --- a/src/sisl/viz/plotters/xarray.py +++ b/src/sisl/viz/plotters/xarray.py @@ -88,6 +88,8 @@ def _process_xarray_data(data, x=None, y=None, z=False, style={}): for key, value in style.items(): if value in data: styles[key] = data[value] + elif key == "name": + styles[key] = DataArray(value) else: styles[key] = None @@ -133,7 +135,13 @@ def draw_xarray_xy( x=x, y=y, z=z, - style={"color": color, "width": width, "opacity": opacity, "dash": dash}, + style={ + "color": color, + "width": width, + "opacity": opacity, + "dash": dash, + "name": name, + }, ) if plot_data is None: @@ -185,7 +193,7 @@ def _draw_xarray_lines( # Get the lines styles lines_style = {} extra_style_dims = False - for key in ("color", "width", "opacity", "dash"): + for key in ("color", "width", "opacity", "dash", "name"): lines_style[key] = style.get(key) if lines_style[key] is not None: @@ -262,27 +270,13 @@ def drawing_function(*args, **kwargs): lines_style["width"], lines_style["opacity"], lines_style["dash"], + lines_style["name"], ) fixed_coords_values = {k: arr.values for k, arr in fixed_coords.items()} - single_line = len(data.iterate_dim) == 1 - if name in data.iterate_dim.coords: - name_prefix = "" - else: - name_prefix = f"{name}_" if name and not single_line else name - # Now just iterate over each line and plot it. for values, *styles in iterator: - names = values.iterate_dim.values[()] - if name in values.iterate_dim.coords: - line_name = f"{name_prefix}{values.iterate_dim.coords[name].values[()]}" - elif single_line and not isinstance(names[0], str): - line_name = name_prefix - elif len(names) == 1: - line_name = f"{name_prefix}{names[0]}" - else: - line_name = f"{name_prefix}{names}" parsed_styles = [] for style in styles: @@ -292,7 +286,7 @@ def drawing_function(*args, **kwargs): style = style[()] parsed_styles.append(style) - line_color, line_width, line_opacity, line_dash = parsed_styles + line_color, line_width, line_opacity, line_dash, line_name = parsed_styles line_style = { "color": line_color, "width": line_width, @@ -307,7 +301,7 @@ def drawing_function(*args, **kwargs): } if not extra_style_dims: - drawing_function(**coords, line=line, name=line_name) + drawing_function(**coords, line=line, name=str(line_name)) else: for k, v in line_style.items(): if v is None or v.ndim == 0: @@ -325,7 +319,7 @@ def drawing_function(*args, **kwargs): "opacity": l_opacity, "dash": l_dash, } - drawing_function(**coords, line=line_style, name=line_name) + drawing_function(**coords, line=line_style, name=str(line_name)) return to_plot diff --git a/src/sisl/viz/processors/bands.py b/src/sisl/viz/processors/bands.py index 17fd3e1cf..3178aeea1 100644 --- a/src/sisl/viz/processors/bands.py +++ b/src/sisl/viz/processors/bands.py @@ -77,20 +77,25 @@ def style_bands( bands_data: xr.Dataset, bands_style: dict = {"color": "black", "width": 1}, spindown_style: dict = {"color": "blue", "width": 1}, + group_legend: bool = True, ) -> xr.Dataset: """Returns the bands dataset, with the style information added to it. Parameters ------------ - bands_data: xr.Dataset + bands_data: The dataset containing bands energy information. - bands_style: dict + bands_style: Dictionary containing the style information for the bands. - spindown_style: dict + spindown_style: Dictionary containing the style information for the spindown bands. Any style that is not present in this dictionary will be taken from the "bands_style" dictionary. + group_legend: + Whether the bands will be grouped in the legend. This will determine + how the names of each band are set """ + # If the user provided a styler function, apply it. if bands_style.get("styler") is not None: if callable(bands_style["styler"]): @@ -112,7 +117,7 @@ def style_bands( if "spin" in bands_data.dims: spindown_style = {**bands_style, **spindown_style} style_arrays = {} - for key in ["color", "width", "opacity"]: + for key in ["color", "width", "opacity", "dash"]: if isinstance(bands_style[key], xr.DataArray): if not isinstance(spindown_style[key], xr.DataArray): down_style = bands_style[key].copy(deep=True) @@ -126,11 +131,37 @@ def style_bands( style_arrays[key] = xr.DataArray( [bands_style[key], spindown_style[key]], dims=["spin"] ) + + # Determine the names of the bands + if group_legend: + style_arrays["line_name"] = xr.DataArray( + ["Spin up Bands", "Spin down Bands"], dims=["spin"] + ) + else: + names = [] + for s in bands_data.spin: + spin_string = "UP" if s == 0 else "DOWN" + for iband in bands_data.band.values: + names.append(f"{spin_string}_{iband}") + + style_arrays["line_name"] = xr.DataArray( + np.array(names).reshape(2, -1), + coords=[ + ("spin", bands_data.spin.values), + ("band", bands_data.band.values), + ], + ) else: style_arrays = {} for key in ["color", "width", "opacity", "dash"]: style_arrays[key] = xr.DataArray(bands_style[key]) + # Determine the names of the bands + if group_legend: + style_arrays["line_name"] = xr.DataArray("Bands") + else: + style_arrays["line_name"] = bands_data.band + # Merge the style arrays with the bands dataset and return the styled dataset return bands_data.assign(style_arrays) diff --git a/src/sisl/viz/processors/xarray.py b/src/sisl/viz/processors/xarray.py index 99640ba5c..df23db050 100644 --- a/src/sisl/viz/processors/xarray.py +++ b/src/sisl/viz/processors/xarray.py @@ -12,9 +12,6 @@ import xarray as xr from xarray import DataArray, Dataset -from sisl import Geometry -from sisl.messages import SislError - class XarrayData: @singledispatchmethod