From 156f9f213449cc8fab64d10e618e9f2a3b69e185 Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 15 Aug 2023 09:44:04 +0100 Subject: [PATCH] Added `hide_plot_traces` in Plotterly class Added mapping to `plot_sensors` in Plotterly and Plotter class --- stonesoup/plotter.py | 53 +++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py index d3e840e94..cb77d2527 100644 --- a/stonesoup/plotter.py +++ b/stonesoup/plotter.py @@ -31,18 +31,18 @@ from .models.base import LinearModel, Model -from enum import Enum +from enum import IntEnum -class Dimension(Enum): +class Dimension(IntEnum): """Dimension Enum class for specifying plotting parameters in the Plotter class. Used to sanitize inputs for the dimension attribute of Plotter(). Attributes ---------- - TWO: str + TWO: int Specifies 2D plotting for Plotter object - THREE: str + THREE: int Specifies 3D plotting for Plotter object """ TWO = 2 # 2D plotting mode (original plotter.py functionality) @@ -66,7 +66,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ raise NotImplementedError @abstractmethod - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs): raise NotImplementedError def _conv_measurements(self, measurements, mapping, measurement_model=None, @@ -470,7 +470,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ return artists - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): """Plots sensor(s) Plots sensors. Users can change the color and marker of detections using keyword @@ -480,6 +480,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): ---------- sensors : Collection of :class:`~.Sensor` Sensors to plot + mapping: list + List of items specifying the mapping of the position components of the + sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension` sensor_label: str Label to apply to all tracks for legend. \\*\\*kwargs: dict @@ -498,16 +501,19 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): if not isinstance(sensors, Collection): sensors = {sensors} # Make a set of length 1 + if mapping is None: + mapping = list(range(self.dimension)) + artists = [] for sensor in sensors: if self.dimension is Dimension.TWO: # plots the sensors in xy - artists.append(self.ax.scatter(sensor.position[0], - sensor.position[1], + artists.append(self.ax.scatter(sensor.position[mapping[0]], + sensor.position[mapping[1]], **sensor_kwargs)) elif self.dimension is Dimension.THREE: # plots the sensors in xyz - artists.extend(self.ax.plot3D(sensor.position[0], - sensor.position[1], - sensor.position[2], + artists.extend(self.ax.plot3D(sensor.position[mapping[0]], + sensor.position[mapping[1]], + sensor.position[mapping[2]], **sensor_kwargs)) else: raise NotImplementedError('Unsupported dimension type for sensor plotting') @@ -991,7 +997,7 @@ def func3(x): points = rotational_matrix @ points.T return points + state.mean[mapping[:2], :] - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs): """Plots sensor(s) Plots sensors. Users can change the color and marker of detections using keyword @@ -1001,6 +1007,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): ---------- sensors : Collection of :class:`~.Sensor` Sensors to plot + mapping: list + List of items specifying the mapping of the position + components of the sensor's position. sensor_label: str Label to apply to all tracks for legend. \\*\\*kwargs: dict @@ -1022,9 +1031,27 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): else: sensor_kwargs['showlegend'] = True - sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors]) + sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors]) self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs) + def hide_plot_traces(self, items_to_hide: set): + """Hide Plot Traces + + This function allows plotting items to be invisible as default. Users can toggle the plot + trace to visible. + + Parameters + ---------- + items_to_hide : set[str] + The legend label (`legendgroups`) for the plot traces that should be invisible as + default + """ + for fig_data in self.fig.data: + if fig_data.legendgroup in items_to_hide: + fig_data.visible = "legendonly" + else: + fig_data.visible = None + class _AnimationPlotterDataClass(Base): plotting_data = Property(Iterable[State])