Skip to content

Commit

Permalink
Merge pull request #836 from dstl/minor_plotting_additions
Browse files Browse the repository at this point in the history
Add Mapping to plot_sensors. Hide plotting elements in Plotterly
  • Loading branch information
sdhiscocks authored Sep 1, 2023
2 parents 24ebff1 + 156f9f2 commit f99986f
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@

from .models.base import LinearModel, Model

from enum import Enum
from enum import IntEnum


class Dimension(Enum):
class Dimension(IntEnum):
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
Used to sanitize inputs for the dimension attribute of Plotter().
Attributes
----------
TWO: str
TWO: int
Specifies 2D plotting for Plotter object
THREE: str
THREE: int
Specifies 3D plotting for Plotter object
"""
TWO = 2 # 2D plotting mode (original plotter.py functionality)
Expand All @@ -66,7 +66,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
raise NotImplementedError

@abstractmethod
def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs):
raise NotImplementedError

def _conv_measurements(self, measurements, mapping, measurement_model=None,
Expand Down Expand Up @@ -470,7 +470,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_

return artists

def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs):
"""Plots sensor(s)
Plots sensors. Users can change the color and marker of detections using keyword
Expand All @@ -480,6 +480,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
----------
sensors : Collection of :class:`~.Sensor`
Sensors to plot
mapping: list
List of items specifying the mapping of the position components of the
sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension`
sensor_label: str
Label to apply to all tracks for legend.
\\*\\*kwargs: dict
Expand All @@ -498,16 +501,19 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
if not isinstance(sensors, Collection):
sensors = {sensors} # Make a set of length 1

if mapping is None:
mapping = list(range(self.dimension))

artists = []
for sensor in sensors:
if self.dimension is Dimension.TWO: # plots the sensors in xy
artists.append(self.ax.scatter(sensor.position[0],
sensor.position[1],
artists.append(self.ax.scatter(sensor.position[mapping[0]],
sensor.position[mapping[1]],
**sensor_kwargs))
elif self.dimension is Dimension.THREE: # plots the sensors in xyz
artists.extend(self.ax.plot3D(sensor.position[0],
sensor.position[1],
sensor.position[2],
artists.extend(self.ax.plot3D(sensor.position[mapping[0]],
sensor.position[mapping[1]],
sensor.position[mapping[2]],
**sensor_kwargs))
else:
raise NotImplementedError('Unsupported dimension type for sensor plotting')
Expand Down Expand Up @@ -991,7 +997,7 @@ def func3(x):
points = rotational_matrix @ points.T
return points + state.mean[mapping[:2], :]

def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs):
"""Plots sensor(s)
Plots sensors. Users can change the color and marker of detections using keyword
Expand All @@ -1001,6 +1007,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
----------
sensors : Collection of :class:`~.Sensor`
Sensors to plot
mapping: list
List of items specifying the mapping of the position
components of the sensor's position.
sensor_label: str
Label to apply to all tracks for legend.
\\*\\*kwargs: dict
Expand All @@ -1022,9 +1031,27 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
else:
sensor_kwargs['showlegend'] = True

sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors])
sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors])
self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs)

def hide_plot_traces(self, items_to_hide: set):
"""Hide Plot Traces
This function allows plotting items to be invisible as default. Users can toggle the plot
trace to visible.
Parameters
----------
items_to_hide : set[str]
The legend label (`legendgroups`) for the plot traces that should be invisible as
default
"""
for fig_data in self.fig.data:
if fig_data.legendgroup in items_to_hide:
fig_data.visible = "legendonly"
else:
fig_data.visible = None


class _AnimationPlotterDataClass(Base):
plotting_data = Property(Iterable[State])
Expand Down

0 comments on commit f99986f

Please sign in to comment.