Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AnimatedPolarPlotterly Class #1105

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3160,3 +3160,215 @@ def plot_sensors(self, sensors, label="Sensors", resize=True, **kwargs):

# we have called a plotting function so update flag (used in _resize)
self.plotting_function_called = True


class AnimatedPolarPlotterly(PolarPlotterly):
"""Class to produce 2D animated polar plots."""

def __init__(self, timesteps, tail_length=1, dimension=Dimension.TWO,
sim_duration=6, **kwargs):
if go is None:
raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")
if isinstance(dimension, type(Dimension.TWO)):
self.dimension = dimension
elif isinstance(dimension, int):
self.dimension = Dimension(dimension)
else:
raise TypeError(f"{type(dimension)} is an unsupported type for \'dimension\'; "
f"expected type {type(Dimension.TWO)}")
if self.dimension != dimension.TWO:
raise TypeError("Only 2D plotting currently supported")

if len(timesteps) < 2:
raise ValueError("Must be at least 2 timesteps for animation.")

# checking that timesteps are evenly spaced
time_spaces = np.unique(np.diff(timesteps))

# gives the unique values of time gaps between timesteps. If this contains more than
# one value, then timesteps are not all evenly spaced which is an issue.
if len(time_spaces) != 1:
warnings.warn("Timesteps are not equally spaced, so the passage of time is not linear")
self.timesteps = timesteps

# checking input to tail_length
if tail_length > 1 or tail_length < 0:
raise ValueError("Tail length should be between 0 and 1")
self.tail_length = tail_length

# checking sim_duration
if sim_duration <= 0:
raise ValueError("Simulation duration must be positive")

# time window is calculated as sim_length * tail_length. This is
# the window of time for which past plots are still visible
self.time_window = (timesteps[-1] - timesteps[0]) * tail_length

layout_kwargs = dict()
self.colorway = colors.qualitative.Plotly[1:] # plotting colours

# Generate plot axes
self.fig = go.Figure(layout=layout_kwargs)
self.fig.frames = [dict(
name=str(time),
data=[],
traces=[]
) for time in timesteps]

frame_duration = sim_duration * 1000 / len(self.fig.frames)

# if the gap between timesteps is greater than a day, it isn't necessary
# to display hour and minute information, so remove this to give a cleaner display.
# a and b are used in the slider steps label later
if time_spaces[0] >= timedelta(days=1):
start_cut_off = None
end_cut_off = 10

# if the simulation is over a day long, display all information which
# looks clunky but is necessary
elif timesteps[-1] - timesteps[0] > timedelta(days=1):
start_cut_off = None
end_cut_off = None

# otherwise, remove day information and just show
# hours, mins, etc. which is cleaner to look at
else:
start_cut_off = 11
end_cut_off = None

# create button and slider
updatemenus = [dict(type='buttons',
buttons=[{
"args": [None,
{"frame": {"duration": frame_duration, "redraw": True},
"fromcurrent": True, "transition": {"duration": 0}}],
"label": "Play",
"method": "animate"
}, {
"args": [[None], {"frame": {"duration": 0, "redraw": True},
"mode": "immediate",
"transition": {"duration": 0}}],
"label": "Stop",
"method": "animate"
}],
direction='left',
pad=dict(r=10, t=75),
showactive=True, x=0.1, y=0, xanchor='right', yanchor='top')
]
sliders = [{'yanchor': 'top',
'xanchor': 'left',
'currentvalue': {'font': {'size': 16}, 'prefix': 'Time: ', 'visible': True,
'xanchor': 'right'},
'transition': {'duration': frame_duration, 'easing': 'linear'},
'pad': {'b': 10, 't': 50},
'len': 0.9, 'x': 0.1, 'y': 0,
'steps': [{'args': [[frame.name], {
'frame': {'duration': 1.0, 'easing': 'linear', 'redraw': True},
'transition': {'duration': 0, 'easing': 'linear'}}],
'label': frame.name[start_cut_off: end_cut_off],
'method': 'animate'} for frame in
self.fig.frames
]}]
self.fig.update_layout(updatemenus=updatemenus, sliders=sliders)
layout_kwargs.update(kwargs)

def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping: int,
label="", **kwargs):
"""Plots state sequence(s)

Plots each state sequence passed in to :attr:`state_sequences` and generates a legend
automatically.

Users can change line style, color and marker using keyword arguments. Any changes
will apply to all ground truths.

Parameters
----------
state_sequences : Collection of :class:`~.StateMutableSequence`
Collection of state sequences which will be plotted. If not a collection,
and instead a single :class:`~.StateMutableSequence` type, the argument is modified
to be a set to allow for iteration.
angle_mapping: int
Specifying the mapping of the angular component of the state space to be plotted.
range_mapping: int
Specifying the mapping of the range component of the state space to be plotted.
label: str
Label for truth data.
\\*\\*kwargs: dict
Additional arguments to be passed to scatter function. Default is
``mode=marker``.
The default unit for the angular component is radians. This can be changed to degrees
with the keyword argument ``thetaunit='degrees'``.
"""

if not isinstance(state_sequences, Collection) \
or isinstance(state_sequences, StateMutableSequence):
state_sequences = {state_sequences}

if range_mapping is None:
raise NotImplementedError(
"Angle vs Time plots are not supported for Animated Polar Plots.")

plotting_kwargs = dict(
mode="markers", legendgroup=label, legendrank=200,
name=label, thetaunit="radians")
merge(plotting_kwargs, kwargs)
add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}
data = [dict() for _ in state_sequences]
for n, state_sequence in enumerate(state_sequences):
data[n].update(
angle=np.zeros(len(state_sequence)),
range=np.zeros(len(state_sequence)),
time=np.array([0 for _ in range(len(state_sequence))], dtype=object),
time_str=np.array([0 for _ in range(len(state_sequence))], dtype=object),
type=np.array([0 for _ in range(len(state_sequence))], dtype=object))
for k, state in enumerate(state_sequence):
data[n]["angle"][k] = state.state_vector[angle_mapping]
data[n]["range"][k] = state.state_vector[range_mapping]
data[n]["time"][k] = state.timestamp
data[n]["time_str"][k] = str(state.timestamp)
data[n]["type"][k] = type(state).__name__

trace_base = len(self.fig.data)
scatter_kwargs = plotting_kwargs.copy()
if add_legend:
scatter_kwargs['showlegend'] = True
add_legend = False
else:
scatter_kwargs['showlegend'] = False
merge(scatter_kwargs, kwargs)
self.fig.add_trace(go.Scatterpolar(scatter_kwargs))

for n, _ in enumerate(state_sequences):
merge(scatter_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)])))
merge(scatter_kwargs, kwargs)
self.fig.add_trace(go.Scatterpolar(scatter_kwargs))

for frame in self.fig.frames:
data_ = list(frame.data)
traces_ = list(frame.traces)

frame_time = datetime.fromisoformat(frame.name)
cutoff_time = frame_time - self.time_window

for n, state_sequence in enumerate(state_sequences):
t_upper = [data[n]["time"] <= frame_time]
t_lower = [data[n]["time"] >= cutoff_time]

mask = np.logical_and(t_upper, t_lower)

state_angle = data[n]["angle"][tuple(mask)]
state_angle = np.append(state_angle, [np.inf])

state_range = data[n]["range"][tuple(mask)]
state_range = np.append(state_range, [np.inf])

times = data[n]["time_str"][tuple(mask)]
data_.append(go.Scatterpolar(r=state_range,
theta=state_angle,
meta=times))
traces_.append(trace_base + n + 1)

frame.data = data_
frame.traces = traces_
35 changes: 22 additions & 13 deletions stonesoup/tests/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \
ConstantVelocity
from stonesoup.plotter import Plotter, Dimension, AnimatedPlotterly, AnimationPlotter, Plotterly, \
PolarPlotterly
PolarPlotterly, AnimatedPolarPlotterly
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.sensor.radar.radar import RadarElevationBearingRange
from stonesoup.types.detection import TrueDetection, Clutter
Expand Down Expand Up @@ -115,17 +115,17 @@ def plotter_class(request):

plotter_class = request.param
assert plotter_class in {Plotter, Plotterly, AnimationPlotter,
PolarPlotterly, AnimatedPlotterly}
PolarPlotterly, AnimatedPlotterly, AnimatedPolarPlotterly}

def _generate_animated_plotterly(*args, **kwargs):
return AnimatedPlotterly(*args, timesteps=timesteps, **kwargs)
return plotter_class(*args, timesteps=timesteps, **kwargs)

def _generate_plotter(*args, **kwargs):
return plotter_class(*args, **kwargs)

if plotter_class in {Plotter, Plotterly, AnimationPlotter, PolarPlotterly}:
yield _generate_plotter
elif plotter_class is AnimatedPlotterly:
elif plotter_class in {AnimatedPlotterly, AnimatedPolarPlotterly}:
yield _generate_animated_plotterly
else:
raise ValueError("Invalid Plotter type.")
Expand All @@ -152,7 +152,8 @@ def test_plot_sensors():

@pytest.mark.parametrize(
"plotter_class",
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True)
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly,
AnimatedPolarPlotterly], indirect=True)
def test_empty_tracks(plotter_class):
plotter = plotter_class()
plotter.plot_tracks(set(), [0, 2])
Expand Down Expand Up @@ -404,7 +405,8 @@ def test_show_plot(labels):

@pytest.mark.parametrize(
"plotter_class",
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True)
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly,
AnimatedPolarPlotterly], indirect=True)
@pytest.mark.parametrize(
"_measurements",
[true_measurements, clutter_measurements, all_measurements,
Expand All @@ -417,7 +419,8 @@ def test_plotters_plot_measurements_2d(plotter_class, _measurements):

@pytest.mark.parametrize(
"plotter_class",
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True)
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly,
AnimatedPolarPlotterly], indirect=True)
def test_plotters_plot_tracks(plotter_class):
plotter = plotter_class()
plotter.plot_tracks(track, [0, 2])
Expand All @@ -429,7 +432,8 @@ def test_plotters_plot_tracks(plotter_class):
Plotterly,
pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)),
AnimatedPlotterly],
AnimatedPlotterly,
pytest.param(AnimatedPolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError))],
indirect=True
)
def test_plotters_plot_track_uncertainty(plotter_class):
Expand All @@ -441,16 +445,18 @@ def test_plotters_plot_track_uncertainty(plotter_class):
@pytest.mark.parametrize(
"plotter_class",
[AnimationPlotter,
PolarPlotterly]
)
PolarPlotterly,
AnimatedPolarPlotterly],
indirect=True)
def test_plotters_plot_track_particle(plotter_class):
plotter = plotter_class()
plotter.plot_tracks(track, [0, 2], particle=True)


@pytest.mark.parametrize(
"plotter_class",
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True)
[Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly,
AnimatedPolarPlotterly], indirect=True)
def test_plotters_plot_truths(plotter_class):
plotter = plotter_class()
plotter.plot_ground_truths(truth, [0, 2])
Expand All @@ -462,15 +468,18 @@ def test_plotters_plot_truths(plotter_class):
Plotterly,
pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)),
AnimatedPlotterly], indirect=True
AnimatedPlotterly,
pytest.param(AnimatedPolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError))],
indirect=True
)
def test_plotters_plot_sensors(plotter_class):
plotter = plotter_class()
plotter.plot_sensors(sensor2d)


@pytest.mark.parametrize("plotter_class",
[Plotterly, PolarPlotterly, AnimatedPlotterly], indirect=True)
[Plotterly, PolarPlotterly, AnimatedPlotterly, PolarPlotterly],
indirect=True)
@pytest.mark.parametrize("_measurements, expected_labels",
[(true_measurements, {'Measurements'}),
(clutter_measurements, {'Measurements<br>(Clutter)'}),
Expand Down