Skip to content

Commit

Permalink
Merge pull request #784 from pfebrer/matplotlib_grouplegend
Browse files Browse the repository at this point in the history
Legend grouping in bands plots
  • Loading branch information
zerothi authored Jun 5, 2024
2 parents 1d3f4da + d3a6b90 commit 8659a4a
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 59 deletions.
30 changes: 29 additions & 1 deletion docs/visualization/viz_module/showcase/BandsPlot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,34 @@
"Notice that in spin polarized bands, **you can select the spins to display using the `spin` setting**, just pass a list of spin components (e.g. `spin=[0]`)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Individual bands in legend\n",
"--------------------------\n",
"\n",
"Usually, showing all bands individually in the legend would be too messy. However, you might want to do it so that you can interactively hide show certain bands. If that is the case, you can set `group_legend` to `False`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bands_plot.update_inputs(group_legend=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bands_plot = bands_plot.update_inputs(group_legend=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -594,7 +622,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
21 changes: 21 additions & 0 deletions src/sisl/viz/figure/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def draw_line(
line={},
marker={},
text=None,
showlegend=True,
row=None,
col=None,
_axes=None,
Expand All @@ -258,6 +259,10 @@ def draw_line(

axes = _axes or self._get_subplot_axes(row=row, col=col)

# Matplotlib doesn't show lines on the legend if their name starts
# with an underscore, so prepend the name with "_" if showlegend is False.
name = name if showlegend else f"_{name}"

return axes.plot(
x,
y,
Expand All @@ -279,6 +284,7 @@ def draw_multicolor_line(
line={},
marker={},
text=None,
showlegend=True,
row=None,
col=None,
_axes=None,
Expand All @@ -289,6 +295,10 @@ def draw_multicolor_line(

color = line.get("color")

# Matplotlib doesn't show lines on the legend if their name starts
# with an underscore, so prepend the name with "_" if showlegend is False.
name = name if showlegend else f"_{name}"

if not np.issubdtype(np.array(color).dtype, np.number):
return self.draw_multicolor_scatter(
x,
Expand Down Expand Up @@ -358,6 +368,7 @@ def draw_area_line(
y,
line={},
name=None,
showlegend=True,
dependent_axis=None,
row=None,
col=None,
Expand All @@ -371,6 +382,10 @@ def draw_area_line(

axes = _axes or self._get_subplot_axes(row=row, col=col)

# Matplotlib doesn't show lines on the legend if their name starts
# with an underscore, so prepend the name with "_" if showlegend is False.
name = name if showlegend else f"_{name}"

if dependent_axis in ("y", None):
axes.fill_between(
x, y + spacing, y - spacing, color=line.get("color"), label=name
Expand All @@ -392,13 +407,19 @@ def draw_scatter(
marker={},
text=None,
zorder=2,
showlegend=True,
row=None,
col=None,
_axes=None,
meta={},
**kwargs,
):
axes = _axes or self._get_subplot_axes(row=row, col=col)

# Matplotlib doesn't show lines on the legend if their name starts
# with an underscore, so prepend the name with "_" if showlegend is False.
name = name if showlegend else f"_{name}"

try:
return axes.scatter(
x,
Expand Down
4 changes: 2 additions & 2 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,10 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, **kwargs):
size=sp_size,
color=sp_color,
opacity=sp_opacity,
name=f"{name}_{i}",
name=name,
legendgroup=name,
showlegend=showlegend,
meta=meta,
meta={**meta, f"{name}_i": i},
)
showlegend = False

Expand Down
90 changes: 61 additions & 29 deletions src/sisl/viz/plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@
from .orbital_groups_plot import OrbitalGroupsPlot


def _default_random_color(x):
return x.get("color") or random_color()


def _group_traces(actions, group_legend: bool = True):

if not group_legend:
return actions

seen_groups = []

new_actions = []
for action in actions:
if action["method"].startswith("draw_"):
group = action["kwargs"].get("name")
action = action.copy()
action["kwargs"]["legendgroup"] = group

if group in seen_groups:
action["kwargs"]["showlegend"] = False
else:
seen_groups.append(group)

new_actions.append(action)

return new_actions


def bands_plot(
bands_data: BandsData,
Erange: Optional[Tuple[float, float]] = None,
Expand All @@ -44,6 +72,7 @@ def bands_plot(
direct_gaps_only: bool = False,
custom_gaps: Sequence[Dict] = [],
line_mode: Literal["line", "scatter", "area_line"] = "line",
group_legend: bool = True,
backend: str = "plotly",
) -> Figure:
"""Plots band structure energies, with plentiful of customization options.
Expand Down Expand Up @@ -87,6 +116,10 @@ def bands_plot(
List of custom gaps to display. See the showcase notebooks for examples.
line_mode:
The method used to draw the band lines.
group_legend:
Whether to group all bands in the legend to show a single legend item.
If the bands are spin polarized, bands are grouped by spin channel.
backend:
The backend to use to generate the figure.
"""
Expand All @@ -95,12 +128,19 @@ def bands_plot(

# Filter the bands
filtered_bands = filter_bands(
bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin
bands_data,
Erange=Erange,
E0=E0,
bands_range=bands_range,
spin=spin,
)

# Add the styles
styled_bands = style_bands(
filtered_bands, bands_style=bands_style, spindown_style=spindown_style
filtered_bands,
bands_style=bands_style,
spindown_style=spindown_style,
group_legend=group_legend,
)

# Determine what goes on each axis
Expand All @@ -114,9 +154,11 @@ def bands_plot(
y=y,
set_axrange=True,
what=line_mode,
name="line_name",
colorscale=colorscale,
dependent_axis=E_axis,
)
grouped_bands_plottings = _group_traces(bands_plottings, group_legend=group_legend)

# Gap calculation
gap_info = calculate_gap(filtered_bands)
Expand All @@ -133,35 +175,13 @@ def bands_plot(
E_axis=E_axis,
)

all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None)
all_plottings = combined(
grouped_bands_plottings, gaps_plottings, composite_method=None
)

return get_figure(backend=backend, plot_actions=all_plottings)


def _default_random_color(x):
return x.get("color") or random_color()


def _group_traces(actions):
seen_groups = []

new_actions = []
for action in actions:
if action["method"].startswith("draw_"):
group = action["kwargs"].get("name")
action = action.copy()
action["kwargs"]["legendgroup"] = group

if group in seen_groups:
action["kwargs"]["showlegend"] = False
else:
seen_groups.append(group)

new_actions.append(action)

return new_actions


# I keep the fatbands plot here so that one can see how similar they are.
# I am yet to find a nice solution for extending workflows.
def fatbands_plot(
Expand All @@ -180,6 +200,7 @@ def fatbands_plot(
direct_gaps_only: bool = False,
custom_gaps: Sequence[Dict] = [],
bands_mode: Literal["line", "scatter", "area_line"] = "line",
bands_group_legend: bool = True,
# Fatbands inputs
groups: OrbitalQueries = [],
fatbands_var: str = "norm2",
Expand Down Expand Up @@ -225,6 +246,10 @@ def fatbands_plot(
List of custom gaps to display. See the showcase notebooks for examples.
bands_mode:
The method used to draw the band lines.
bands_group_legend:
Whether to group all bands in the legend to show a single legend item.
If the bands are spin polarized, bands are grouped by spin channel.
groups:
Orbital groups to plots. See showcase notebook for examples.
fatbands_var:
Expand All @@ -246,7 +271,10 @@ def fatbands_plot(

# Add the styles
styled_bands = style_bands(
filtered_bands, bands_style=bands_style, spindown_style=spindown_style
filtered_bands,
bands_style=bands_style,
spindown_style=spindown_style,
group_legend=bands_group_legend,
)

# Process fatbands
Expand Down Expand Up @@ -299,8 +327,12 @@ def fatbands_plot(
y=y,
set_axrange=True,
what=bands_mode,
name="line_name",
dependent_axis=E_axis,
)
grouped_bands_plottings = _group_traces(
bands_plottings, group_legend=bands_group_legend
)

# Gap calculation
gap_info = calculate_gap(filtered_bands)
Expand All @@ -319,7 +351,7 @@ def fatbands_plot(

all_plottings = combined(
grouped_fatbands_plottings,
bands_plottings,
grouped_bands_plottings,
gaps_plottings,
composite_method=None,
)
Expand Down
1 change: 1 addition & 0 deletions src/sisl/viz/plots/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def pdos_plot(
x=x,
y=y,
width="size",
name="group",
what=line_mode,
dependent_axis=dependent_axis,
)
Expand Down
Loading

0 comments on commit 8659a4a

Please sign in to comment.