diff --git a/.circleci/config.yml b/.circleci/config.yml index b5ba959b1..8a3b9fa34 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -28,7 +28,8 @@ jobs: python -m venv venv . venv/bin/activate pip install --upgrade pip - pip install -e .[dev,orbital] opencv-python-headless pyehm + pip install -e .[dev,orbital,architectures] opencv-python-headless pyehm + - save_cache: paths: - ./venv @@ -98,6 +99,11 @@ jobs: - checkout - restore_cache: key: dependencies-doc-{{ .Environment.CACHE_VERSION }}-{{ checksum "/home/circleci/.pyenv/version" }}-{{ checksum "setup.cfg" }} + - run: + name: Install OS Dependencies + command: | + sudo apt-get update + sudo apt-get install -y graphviz - run: name: Install Dependencies command: | @@ -105,7 +111,7 @@ jobs: . venv/bin/activate pip install --upgrade pip pip install -r docs/ci-requirements.txt - pip install -e .[dev,orbital] opencv-python-headless + pip install -e .[dev,orbital,architectures] opencv-python-headless - save_cache: paths: - ./venv diff --git a/docs/source/_static/sphinx_gallery/ArchTutorial_1.png b/docs/source/_static/sphinx_gallery/ArchTutorial_1.png new file mode 100644 index 000000000..51ed8ca0c Binary files /dev/null and b/docs/source/_static/sphinx_gallery/ArchTutorial_1.png differ diff --git a/docs/source/_static/sphinx_gallery/ArchTutorial_2.png b/docs/source/_static/sphinx_gallery/ArchTutorial_2.png new file mode 100644 index 000000000..8b5118d14 Binary files /dev/null and b/docs/source/_static/sphinx_gallery/ArchTutorial_2.png differ diff --git a/docs/source/_static/sphinx_gallery/ArchTutorial_3.png b/docs/source/_static/sphinx_gallery/ArchTutorial_3.png new file mode 100644 index 000000000..f7fbc73f3 Binary files /dev/null and b/docs/source/_static/sphinx_gallery/ArchTutorial_3.png differ diff --git a/docs/tutorials/architecture/01_Introduction_to_Architectures.py b/docs/tutorials/architecture/01_Introduction_to_Architectures.py new file mode 100644 index 000000000..399b59134 --- /dev/null +++ b/docs/tutorials/architecture/01_Introduction_to_Architectures.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +=============================================== +1 - Introduction to Architectures in Stone Soup +=============================================== +""" + +# %% +# Introduction +# ------------ +# +# The architecture package in Stone Soup provides functionality to build information and network +# architectures, enabling the user to simulate sensing, propagation and fusion of data. +# Architectures are modelled by defining the nodes in the architecture, and edges that represent +# connections between nodes. +# +# Nodes +# ----- +# +# Nodes represent points in the architecture that collect, process (fuse), or simply forward on +# data. Before advancing, a few definitions are required: +# +# - Relationships between nodes are defined as parent-child. In a directed graph, an edge from +# node A to node B means that data is passed from the child node, A, to the parent node, B. +# +# - The children of node A, denoted :math:`children(A)`, is defined as the set of nodes B, where +# there exists a direct edge from node B to node A (set of nodes that A receives data from). +# +# - The parents of node A, denoted :math:`parents(A)`, is defined as the set of nodes B, where +# there exists a direct edge from node A to node B (set of nodes that A passes data to). +# +# Different types of node can provide different functionality in the architecture. The following +# are available in Stone Soup: +# +# - :class:`.~SensorNode`: makes detections of targets and propagates data onwards through the +# architecture. +# +# - :class:`.~FusionNode`: receives data from child nodes, and fuses to achieve a fused result. +# The fused result can be propagated onwards. +# +# - :class:`.~SensorFusionNode`: has the functionality of both a SensorNode and a FusionNode. +# +# - :class:`.~RepeaterNode`: does not create or fuse data, but only propagates it onwards. +# It is only used in network architectures. +# +# Set up and Node Properties +# ^^^^^^^^^^^^^^^^^^^^^^^^^^ + +from stonesoup.architecture.node import Node + +node_A = Node(label='Node A') +node_B = Node(label='Node B') +node_C = Node(label='Node C') + +# %% +# The :class:`.~Node` base class contains several properties. The `latency` property gives +# functionality to simulate processing latency at the node. The rest of the properties (`label`, +# `position`, `colour`, `shape`, `font_size`, `node_dim`), are used for graph plotting. + +node_A.colour = '#006494' + +node_A.shape = 'hexagon' + +# %% +# :class:`~.SensorNode` and :class:`~.FusionNode` objects have additional properties that must be +# defined. A :class:`~.SensorNode` must be given an additional `sensor` property - this must be a +# :class:`~.Sensor`. A :class:`~.FusionNode` has two additional properties: `tracker` and +# `fusion_queue`.`tracker` must both be :class:`~.Tracker`\s - the main tracker manages the +# fusion at the node, while the `fusion_queue` property is a :class:`~.FusionQueue` by default - +# this manages the inflow of data from child nodes. +# +# Edges +# ----- +# An edge represents a link between two nodes in an architecture. An :class:`~.Edge` contains a +# property `nodes`: a tuple of :class:`~.Node` objects where the first entry in the tuple is +# the child node and the second is the parent. Edges in Stone Soup are directional (data can +# flow only in one direction), with data flowing from child to parent. Edge objects also +# contain a `latency` property to enable simulation of latency caused by sending a message, +# separately to node latency. + +from stonesoup.architecture.edge import Edge + +edge1 = Edge(nodes=(node_B, node_A)) +edge2 = Edge(nodes=(node_C, node_A)) + +# %% +# :class:`~.Edges` is a container class for :class:`~.Edge` objects. :class:`~.Edges` has an +# `edges` property - a list of :class:`~.Edge` objects. An :class:`~.Edges` object is required +# to pass into an :class:`~.Architecture`. + +from stonesoup.architecture.edge import Edges + +edges = Edges(edges=[edge1, edge2]) + + +# %% +# Architecture +# ------------ +# Architecture classes manage the simulation of data propagation across a network. Two +# architecture classes are available in Stone Soup: :class:`~.InformationArchitecture` and +# :class:`~.NetworkArchitecture`. Information architecture simulates how +# information is shared across the network, only considering nodes that create or modify +# information. Network architecture simulates how data is actually +# propagated through a network. All nodes are considered including nodes that don't open or modify +# any data. +# +# A good analogy for the two is receiving a parcel via post. The "information architecture" is +# from sender to receiver, the former who creates the parcel, and the latter who opens it. +# However, between these are many unseen but crucial steps which form the "network architecture". +# In this analogy, the postman never does anything with the parcel besides deliver it, so functions +# like a Stone Soup :class:`~.RepeaterNode`. +# +# +# Architecture classes contain an `edges` property - this must be an :class:`~.Edges` object. +# The `current_time` property of an Architecture instance maintains the current time within the +# simulation. By default, this begins at the current time of the operating system. + +# sphinx_gallery_thumbnail_path = '_static/sphinx_gallery/ArchTutorial_1.png' + +from stonesoup.architecture import InformationArchitecture + +arch = InformationArchitecture(edges=edges) +arch \ No newline at end of file diff --git a/docs/tutorials/architecture/02_Information_and_Network_Architectures.py b/docs/tutorials/architecture/02_Information_and_Network_Architectures.py new file mode 100644 index 000000000..231fc946a --- /dev/null +++ b/docs/tutorials/architecture/02_Information_and_Network_Architectures.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +======================================== +2 - Information vs Network Architectures +======================================== +""" +# %% +# Comparing Information and Network Architectures Using ArchitectureGenerators +# ---------------------------------------------------------------------------- +# +# In this demo, we intend to show that running a simulation over both an information +# architecture and its underlying network architecture yields the same results. +# +# To build this demonstration, we shall carry out the following steps: +# +# 1) Build a ground truth, as a basis for the simulation +# +# 2) Build a base sensor model, and a base tracker +# +# 3) Use the :class:`~.ArchitectureGenerator` classes to generate 2 pairs of +# identical architectures (one of each type), where the network architecture +# is a valid representation of the information architecture. +# +# 4) Run the simulation over both, and compare results. +# +# 5) Remove edges from each of the architectures, and rerun. + +# %% +# Module Imports +# ^^^^^^^^^^^^^^ + +from datetime import datetime, timedelta +from ordered_set import OrderedSet +import numpy as np +import random + +# %% +# 1 - Ground Truth +# ---------------- +# We start this tutorial by generating a set of :class:`~.GroundTruthPath`\s as a basis for a +# tracking simulation. + + +start_time = datetime.now().replace(microsecond=0) +np.random.seed(2024) +random.seed(2024) + +from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ + ConstantVelocity +from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState + +# Generate transition model +transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(0.005), + ConstantVelocity(0.005)]) + +yps = range(0, 100, 10) # y value for prior state +truths = OrderedSet() +ntruths = 3 # number of ground truths in simulation +time_max = 60 # timestamps the simulation is observed over +timesteps = [start_time + timedelta(seconds=k) for k in range(time_max)] + +xdirection = 1 +ydirection = 1 + +# Generate ground truths +for j in range(0, ntruths): + truth = GroundTruthPath([GroundTruthState([0, xdirection, yps[j], ydirection], + timestamp=timesteps[0])], id=f"id{j}") + + for k in range(1, time_max): + truth.append( + GroundTruthState(transition_model.function(truth[k - 1], noise=True, + time_interval=timedelta(seconds=1)), + timestamp=timesteps[k])) + truths.add(truth) + + xdirection *= -1 + if j % 2 == 0: + ydirection *= -1 + +# %% +# 2 - Base Tracker and Sensor Models +# ---------------------------------- +# We can use the :class:`~.ArchitectureGenerator` classes to generate multiple identical +# architectures. These classes take in base tracker and sensor models, which are duplicated and +# applied to each relevant node in the architecture. The base tracker must not have a detector, +# in order for it to be duplicated - the detector will be applied during the architecture +# generation step. +# +# Sensor Model +# ^^^^^^^^^^^^ +# The base sensor model's `position` property is used to calculate a location for sensors in +# the architectures that we will generate. As you'll see in later steps, we can either plot +# all sensors at the same location (`base_sensor.position`), or in a specified range around +# the base sensor's position (`base_sensor.position` +- a specified distance). + + +from stonesoup.types.state import StateVector +from stonesoup.sensor.radar.radar import RadarRotatingBearingRange +from stonesoup.types.angle import Angle + +# Create base sensor +base_sensor = RadarRotatingBearingRange( + position_mapping=(0, 2), + noise_covar=np.array([[0.25*np.radians(0.5) ** 2, 0], + [0, 0.25*1 ** 2]]), + ndim_state=4, + position=np.array([[10], [10]]), + rpm=60, + fov_angle=np.radians(360), + dwell_centre=StateVector([0.0]), + max_range=np.inf, + resolution=Angle(np.radians(30)) +) +base_sensor.timestamp = start_time + +# %% +# Tracker +# ^^^^^^^ +# The base tracker is used here in the same way as the base sensor - it is duplicated and applied +# to each fusion node. In order to duplicate the tracker, its components must all be compatible +# with being deep-copied. This means that we need to remove the fusion queue and reassign it +# after duplication. + + +from stonesoup.predictor.kalman import KalmanPredictor +from stonesoup.updater.kalman import ExtendedKalmanUpdater +from stonesoup.hypothesiser.distance import DistanceHypothesiser +from stonesoup.measures import Mahalanobis +from stonesoup.dataassociator.neighbour import GNNWith2DAssignment +from stonesoup.deleter.time import UpdateTimeStepsDeleter +from stonesoup.types.state import GaussianState +from stonesoup.initiator.simple import MultiMeasurementInitiator +from stonesoup.tracker.simple import MultiTargetTracker +from stonesoup.updater.wrapper import DetectionAndTrackSwitchingUpdater +from stonesoup.updater.chernoff import ChernoffUpdater + +predictor = KalmanPredictor(transition_model) +updater = ExtendedKalmanUpdater(measurement_model=None) +hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(), missed_distance=5) +data_associator = GNNWith2DAssignment(hypothesiser) +deleter = UpdateTimeStepsDeleter(2) +initiator = MultiMeasurementInitiator( + prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 1, 0, 1])), + measurement_model=None, + deleter=deleter, + data_associator=data_associator, + updater=updater, + min_points=4, + ) + +track_updater = ChernoffUpdater(None) +detection_updater = ExtendedKalmanUpdater(None) +detection_track_updater = DetectionAndTrackSwitchingUpdater(None, detection_updater, track_updater) + +base_tracker = MultiTargetTracker( + initiator, deleter, None, data_associator, detection_track_updater) + +# %% +# 3 - Generate Identical Architectures +# ------------------------------------ +# The :class:`~.NetworkArchitecture` class has a property `information_arch`, which contains the +# information architecture representation of the underlying network architecture. This means +# that if we use the :class:`~.NetworkArchitectureGenerator` class to generate a pair of identical +# network architectures, we can extract the information architecture from one. +# +# This will provide us with two completely separate architecture classes: a network architecture, +# and an information architecture representation of the same network architecture. This will +# enable us to run simulations on both without interference between the two. + + +from stonesoup.architecture.generator import NetworkArchitectureGenerator + +gen = NetworkArchitectureGenerator('decentralised', + start_time, + mean_degree=2, + node_ratio=[3, 1, 2], + base_tracker=base_tracker, + base_sensor=base_sensor, + sensor_max_distance=(30, 30), + n_archs=4) +id_net_archs = gen.generate() + +# Network and Information arch pair +network_arch = id_net_archs[0] +information_arch = id_net_archs[1].information_arch + +network_arch + +# %% +information_arch + +# %% +# The two plots above display a network architecture, and corresponding information architecture, +# respectively. Grey nodes in the network architecture represent repeater nodes - these have the +# sole purpose of passing data from one node to another. Comparing the two graphs, while ignoring +# the repeater nodes, should confirm that the two plots are both representations of the same +# system. + +# %% +# 4 - Tracking Simulations +# ------------------------ +# With two identical architectures, we can now run a simulation over both, in an attempt to +# produce identical results. +# +# Run Network Architecture Simulation +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Run the simulation over the network architecture. We then extract some extra information from +# the architecture to add to the plot - location of sensors, and detections. + + +for time in timesteps: + network_arch.measure(truths, noise=True) + network_arch.propagate(time_increment=1) + +# %% +na_sensors = [] +na_dets = set() +for sn in network_arch.sensor_nodes: + na_sensors.append(sn.sensor) + for timestep in sn.data_held['created'].keys(): + for datapiece in sn.data_held['created'][timestep]: + na_dets.add(datapiece.data) + +# %% +# Plot +# ^^^^ + +from stonesoup.plotter import Plotterly + + +def reduce_tracks(tracks): + return { + type(track)([s for s in track.last_timestamp_generator()]) + for track in tracks} + + +plotter = Plotterly() +plotter.plot_ground_truths(truths, [0, 2]) +for node in network_arch.fusion_nodes: + if True: + hexcol = ["#"+''.join([random.choice('ABCDEF0123456789') for i in range(6)])] + plotter.plot_tracks(reduce_tracks(node.tracks), + [0, 2], + track_label=str(node.label), + line=dict(color=hexcol[0]), + uncertainty=True) +plotter.plot_sensors(na_sensors) +plotter.plot_measurements(na_dets, [0, 2]) +plotter.fig + +# %% +# Run Information Architecture Simulation +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Now we run the simulation over the information architecture. As before, we extract some extra +# information from the architecture to add to the plot - location of sensors, and detections. + + +for time in timesteps: + information_arch.measure(truths, noise=True) + information_arch.propagate(time_increment=1) + +# %% +ia_sensors = [] +ia_dets = set() +for sn in information_arch.sensor_nodes: + ia_sensors.append(sn.sensor) + for timestep in sn.data_held['created'].keys(): + for datapiece in sn.data_held['created'][timestep]: + ia_dets.add(datapiece.data) + +# %% +plotter = Plotterly() +plotter.plot_ground_truths(truths, [0, 2]) +for node in information_arch.fusion_nodes: + if True: + hexcol = ["#"+''.join([random.choice('ABCDEF0123456789') for i in range(6)])] + plotter.plot_tracks(reduce_tracks(node.tracks), [0, 2], + track_label=str(node.label), + line=dict(color=hexcol[0]), uncertainty=True) +plotter.plot_sensors(ia_sensors) +plotter.plot_measurements(ia_dets, [0, 2]) +plotter.fig + +# %% +# Comparing Tracks from each Architecture +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# The information architecture we have studied is hierarchical, and while the network +# architecture isn't strictly a hierarchical graph, it does have one central node (Fusion Node 1) +# receiving all information. The code below plots SIAP metrics for the +# tracks maintained at Fusion Node 1 in both architecures. Some variation between the two is +# expected due to the randomness of the measurements, but we aim to show that the results from +# both architectures are near identical. + +top_node = network_arch.top_level_nodes.pop() + +# %% +from stonesoup.metricgenerator.tracktotruthmetrics import SIAPMetrics +from stonesoup.measures import Euclidean +from stonesoup.dataassociator.tracktotrack import TrackToTruth +from stonesoup.metricgenerator.manager import MultiManager + +network_siap = SIAPMetrics(position_measure=Euclidean((0, 2)), + velocity_measure=Euclidean((1, 3)), + generator_name='network_siap', + tracks_key='network_tracks', + truths_key='truths' + ) + +associator = TrackToTruth(association_threshold=30) + + +# %% +network_metric_manager = MultiManager([network_siap], associator) +network_metric_manager.add_data({'network_tracks': top_node.tracks, + 'truths': truths}, overwrite=False) +network_metrics = network_metric_manager.generate_metrics() + +# %% +network_siap_metrics = network_metrics['network_siap'] +network_siap_averages = {network_siap_metrics.get(metric) for metric in network_siap_metrics if + metric.startswith("SIAP") and not metric.endswith(" at times")} + +# %% +top_node = information_arch.top_level_nodes.pop() + +# %% +information_siap = SIAPMetrics(position_measure=Euclidean((0, 2)), + velocity_measure=Euclidean((1, 3)), + generator_name='information_siap', + tracks_key='information_tracks', + truths_key='truths' + ) + +associator = TrackToTruth(association_threshold=30) + +# %% +information_metric_manager = MultiManager([information_siap], associator) +information_metric_manager.add_data({'information_tracks': top_node.tracks, + 'truths': truths}, overwrite=False) +information_metrics = information_metric_manager.generate_metrics() + +# %% +information_siap_metrics = information_metrics['information_siap'] +information_siap_averages = {information_siap_metrics.get(metric) for + metric in information_siap_metrics if + metric.startswith("SIAP") and not metric.endswith(" at times")} + +# %% +from stonesoup.metricgenerator.metrictables import SIAPDiffTableGenerator +SIAPDiffTableGenerator([network_siap_averages, information_siap_averages]).compute_metric() + +# %% +# 5 - Remove edges from each architecture and re-run +# -------------------------------------------------- +# In this section, we take an identical copy of each of the architectures above, and remove an +# edge. We aim to show the following: +# +# * It is possible to remove certain edges from a network architecture without affecting the +# performance of the network. +# * Removing an edge from an information architecture will likely have an effect on performance. +# +# First, we must set up the two architectures, and remove an edge from each. In the network +# architecture, there are multiple routes between some pairs of nodes. This redundency increases +# the resilience of the network when an edge, or node, is taken out of action. In this example, +# we remove edges connecting repeater node r3, in turn, disabling a route from sensor node s0 +# to fusion node f0. As another route from s0 to f0 exists (via repeater node r4), the +# performance of the network should not be effected (assuming unlimited bandwidth). + +# %% +# Network and Information arch pair +network_arch_rm = id_net_archs[2] +information_arch_rm = id_net_archs[3].information_arch + +# %% +rm = [] +for edge in network_arch_rm.edges: + if 'r3' in [node.label for node in edge.nodes]: + rm.append(edge) + +for edge in rm: + network_arch_rm.edges.remove(edge) + +# %% +network_arch_rm + +# %% +# Now we remove an edge from the information architecture. You could choose pretty much any +# edge here, but removing the edge between sf0 and f1 is likely to cause the greatest destruction +# (in the interest of the reader). Removing this edge creates a disconnected graph. The Stone +# Soup architecture module can deal with this with no issues, but for this example we will now +# only consider the connected subgraph containing node f1. + + +rm = [] +for edge in information_arch_rm.edges: + if ('sf0' in [node.label for node in edge.nodes]) and \ + ('f1' in [node.label for node in edge.nodes]): + rm.append(edge) + +for edge in rm: + information_arch_rm.edges.remove(edge) + +# %% +information_arch_rm + +# %% +# We now run the simulation for both architectures and calculate the same SIAP metrics as we +# did before for the original architectures. + + +for time in timesteps: + network_arch_rm.measure(truths, noise=True) + network_arch_rm.propagate(time_increment=1) + information_arch_rm.measure(truths, noise=True) + information_arch_rm.propagate(time_increment=1) + +# %% +top_node = [node for node in network_arch_rm.all_nodes if node.label == 'f1'][0] + +network_rm_siap = SIAPMetrics(position_measure=Euclidean((0, 2)), + velocity_measure=Euclidean((1, 3)), + generator_name='network_rm_siap', + tracks_key='network_rm_tracks', + truths_key='truths' + ) + +network_rm_metric_manager = MultiManager([network_rm_siap], associator) +network_rm_metric_manager.add_data({'network_rm_tracks': top_node.tracks, + 'truths': truths}, overwrite=False) +network_rm_metrics = network_rm_metric_manager.generate_metrics() + +network_rm_siap_metrics = network_rm_metrics['network_rm_siap'] +network_rm_siap_averages = {network_rm_siap_metrics.get(metric) for + metric in network_rm_siap_metrics + if metric.startswith("SIAP") and not metric.endswith(" at times")} + +# %% +top_node = [node for node in information_arch_rm.all_nodes if node.label == 'f1'][0] + +information_rm_siap = SIAPMetrics(position_measure=Euclidean((0, 2)), + velocity_measure=Euclidean((1, 3)), + generator_name='information_rm_siap', + tracks_key='information_rm_tracks', + truths_key='truths' + ) + +information_rm_metric_manager = MultiManager([information_rm_siap], + associator) # associator for generating SIAP metrics +information_rm_metric_manager.add_data({'information_rm_tracks': top_node.tracks, + 'truths': truths}, overwrite=False) +information_rm_metrics = information_rm_metric_manager.generate_metrics() + +information_rm_siap_metrics = information_rm_metrics['information_rm_siap'] +information_rm_siap_averages = {information_rm_siap_metrics.get(metric) for + metric in information_rm_siap_metrics + if metric.startswith("SIAP") and not metric.endswith(" at times")} + +# %% +# Plotting the metrics for the two original architectures, and the metrics for the copies with +# edges removed, should display the result we predicted at the start of this section. + +# %% +SIAPDiffTableGenerator([network_siap_averages, + information_siap_averages, + network_rm_siap_averages, + information_rm_siap_averages], + ['Network', 'Info', 'Network RM', 'Info RM']).compute_metric(); diff --git a/docs/tutorials/architecture/03_Avoiding_Data_Incest.py b/docs/tutorials/architecture/03_Avoiding_Data_Incest.py new file mode 100644 index 000000000..35ce0e44a --- /dev/null +++ b/docs/tutorials/architecture/03_Avoiding_Data_Incest.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +======================== +3 - Avoiding Data Incest +======================== +""" + +# %% +# Introduction +# ------------ +# This tutorial uses the Stone Soup architecture module to provide an example +# of how data incest can occur in a poorly designed network. +# +# In this example, data incest is shown in a simple architecture. A top-level +# fusion node receives data from two sources, which contain information (tracks) +# sourced from two sensors. However, one sensor is overly represented, due to a +# triangle in the information architecture graph. As a consequence, the fusion +# node becomes overconfident, or biased towards the duplicated data. +# +# The aim is to demonstrate this effect by modelling two similar +# information architectures: a centralised (non-hierarchical) architecture, +# and a hierarchical alternative, and looks to compare the fused results +# at the top-level node. +# +# We will follow the following steps: +# +# 1) Define sensors for sensor nodes +# +# 2) Simulate a ground truth, as a basis for the simulation +# +# 3) Create trackers for fusion nodes +# +# 4) Build a non-hierarchical architecture, containing a triangle +# +# 5) Build a hierarchical architecture by removing an edge +# from the non-hierarchical architecture +# +# 6) Compare and contrast. What difference, if any, will the +# hierarchical alternative make? +# + +import random +import copy +import math +import numpy as np +import matplotlib.pyplot as plt +from datetime import datetime, timedelta + +start_time = datetime.now().replace(microsecond=0) +np.random.seed(1990) +random.seed(1990) + +# %% +# 1) Sensors +# ^^^^^^^^^^ +# +# We need two sensors to be assigned to the two sensor nodes. +# Notice they vary only in their position. + +from stonesoup.models.clutter import ClutterModel +from stonesoup.models.measurement.linear import LinearGaussian +from stonesoup.types.state import CovarianceMatrix + +mm = LinearGaussian(ndim_state=4, + mapping=[0, 2], + noise_covar=CovarianceMatrix(np.diag([0.5, 0.5])), + seed=6) + +mm2 = LinearGaussian(ndim_state=4, + mapping=[0, 2], + noise_covar=CovarianceMatrix(np.diag([0.5, 0.5])), + seed=6) + +# %% +from stonesoup.sensor.sensor import SimpleSensor +from stonesoup.models.measurement.base import MeasurementModel +from stonesoup.base import Property + + +class DummySensor(SimpleSensor): + measurement_model: MeasurementModel = Property(doc="TODO") + + def is_detectable(self, state): + return True + + def is_clutter_detectable(self, state): + return True + + +sensor1 = DummySensor(measurement_model=mm, + position=np.array([[10], [-20]]), + clutter_model=ClutterModel(clutter_rate=5, + dist_params=((-100, 100), (-50, 60)), seed=6)) +sensor1.clutter_model.distribution = sensor1.clutter_model.random_state.uniform +sensor2 = DummySensor(measurement_model=mm2, + position=np.array([[10], [20]]), + clutter_model=ClutterModel(clutter_rate=5, + dist_params=((-100, 100), (-50, 60)), seed=6)) +sensor2.clutter_model.distribution = sensor2.clutter_model.random_state.uniform + +# %% +# 2) Ground Truth +# ^^^^^^^^^^^^^^^ + +from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ + ConstantVelocity +from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState +from ordered_set import OrderedSet + +# Generate transition model +transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(0.005), + ConstantVelocity(0.005)]) + +yps = range(0, 100, 10) # y value for prior state +truths = OrderedSet() +ntruths = 3 # number of ground truths in simulation +time_max = 60 # timestamps the simulation is observed over +timesteps = [start_time + timedelta(seconds=k) for k in range(time_max)] + +xdirection = 1 +ydirection = 1 + +# Generate ground truths +for j in range(0, ntruths): + truth = GroundTruthPath([GroundTruthState([0, xdirection, yps[j], ydirection], + timestamp=timesteps[0])], id=f"id{j}") + + for k in range(1, time_max): + truth.append( + GroundTruthState(transition_model.function(truth[k - 1], noise=True, + time_interval=timedelta(seconds=1)), + timestamp=timesteps[k])) + truths.add(truth) + + xdirection *= -1 + if j % 2 == 0: + ydirection *= -1 + +# %% +# 3) Build Trackers +# ^^^^^^^^^^^^^^^^^ +# We use the same configuration of trackers and track-trackers as we did in the +# previous tutorial. +# + +from stonesoup.predictor.kalman import KalmanPredictor +from stonesoup.updater.kalman import ExtendedKalmanUpdater +from stonesoup.hypothesiser.distance import DistanceHypothesiser +from stonesoup.measures import Mahalanobis +from stonesoup.dataassociator.neighbour import GNNWith2DAssignment +from stonesoup.deleter.error import CovarianceBasedDeleter +from stonesoup.types.state import GaussianState +from stonesoup.initiator.simple import MultiMeasurementInitiator +from stonesoup.tracker.simple import MultiTargetTracker +from stonesoup.architecture.edge import FusionQueue + +prior = GaussianState([[0], [1], [0], [1]], np.diag([1, 1, 1, 1])) +predictor = KalmanPredictor(transition_model) +updater = ExtendedKalmanUpdater(measurement_model=None) +hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(), missed_distance=5) +data_associator = GNNWith2DAssignment(hypothesiser) +deleter = CovarianceBasedDeleter(covar_trace_thresh=3) +initiator = MultiMeasurementInitiator( + prior_state=prior, + measurement_model=None, + deleter=deleter, + data_associator=data_associator, + updater=updater, + min_points=5, +) + +tracker = MultiTargetTracker(initiator, deleter, None, data_associator, updater) + +# %% +# Track Tracker +# ^^^^^^^^^^^^^ +# + +from stonesoup.updater.wrapper import DetectionAndTrackSwitchingUpdater +from stonesoup.updater.chernoff import ChernoffUpdater +from stonesoup.feeder.track import Tracks2GaussianDetectionFeeder + +track_updater = ChernoffUpdater(None) +detection_updater = ExtendedKalmanUpdater(None) +detection_track_updater = DetectionAndTrackSwitchingUpdater(None, detection_updater, track_updater) + +fq = FusionQueue() + +track_tracker = MultiTargetTracker( + initiator, deleter, None, data_associator, detection_track_updater) + +# %% +# 4) Non-Hierarchical Architecture +# -------------------------------- +# We start by constructing the non-hierarchical, centralised architecture. +# +# Nodes +# ^^^^^ + +from stonesoup.architecture.node import SensorNode, FusionNode + +sensornode1 = SensorNode(sensor=copy.deepcopy(sensor1), label='Sensor Node 1') +sensornode1.sensor.clutter_model.distribution = \ + sensornode1.sensor.clutter_model.random_state.uniform + +sensornode2 = SensorNode(sensor=copy.deepcopy(sensor2), label='Sensor Node 2') +sensornode2.sensor.clutter_model.distribution = \ + sensornode2.sensor.clutter_model.random_state.uniform + +f1_tracker = copy.deepcopy(track_tracker) +f1_fq = FusionQueue() +f1_tracker.detector = Tracks2GaussianDetectionFeeder(f1_fq) +fusion_node1 = FusionNode(tracker=f1_tracker, fusion_queue=f1_fq, label='Fusion Node 1') + +f2_tracker = copy.deepcopy(track_tracker) +f2_fq = FusionQueue() +f2_tracker.detector = Tracks2GaussianDetectionFeeder(f2_fq) +fusion_node2 = FusionNode(tracker=f2_tracker, fusion_queue=f2_fq, label='Fusion Node 2') + +# %% +# Edges +# ^^^^^ +# Here we define the set of edges for the non-hierarchical (NH) architecture. + +from stonesoup.architecture import InformationArchitecture +from stonesoup.architecture.edge import Edge, Edges + +NH_edges = Edges([Edge((sensornode1, fusion_node1), edge_latency=0), + Edge((sensornode1, fusion_node2), edge_latency=0), + Edge((sensornode2, fusion_node2), edge_latency=0), + Edge((fusion_node2, fusion_node1), edge_latency=0)]) + +# %% +# Create the Non-Hierarchical Architecture +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# The cell below should create and plot the architecture we have built. +# This architecture is at risk of data incest, due to the fact that +# information from sensor node 1 could reach Fusion Node 1 via two routes, +# while appearing to not be from the same source: +# +# * Route 1: Sensor Node 1 (S1) passes its information straight to +# Fusion Node 1 (F1) +# * Route 2: S1 also passes its information to Fusion Node 2 (F2). +# Here it is fused with information from Sensor Node 2 (S2). This +# resulting information is then passed to Fusion Node 1. +# +# Ultimately, F1 is recieving information from S1, and information from +# F2 which is based on the same information from S1. This can cause a +# bias towards the information created at S1. In this example, we would +# expect to see overconfidence in the form of unrealistically small +# uncertainty of the output tracks. + + +NH_architecture = InformationArchitecture(NH_edges, current_time=start_time, + use_arrival_time=True) +NH_architecture + +# %% +# Run the Non-Hierarchical Simulation +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +for time in timesteps: + NH_architecture.measure(truths, noise=True) + NH_architecture.propagate(time_increment=1) + +# %% +# Extract all Detections that arrived at Non-Hierarchical Node C +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +NH_sensors = [] +NH_dets = set() +for sn in NH_architecture.sensor_nodes: + NH_sensors.append(sn.sensor) + for timestep in sn.data_held['created'].keys(): + for datapiece in sn.data_held['created'][timestep]: + NH_dets.add(datapiece.data) + +# %% +# Plot the tracks stored at Non-Hierarchical Node C +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +from stonesoup.plotter import Plotterly + + +def reduce_tracks(tracks): + return { + type(track)([s for s in track.last_timestamp_generator()]) + for track in tracks} + + +plotter = Plotterly() +plotter.plot_ground_truths(truths, [0, 2]) +for node in NH_architecture.fusion_nodes: + hexcol = ["#" + ''.join([random.choice('ABCDEF0123456789') for i in range(6)])] + plotter.plot_tracks(reduce_tracks(node.tracks), [0, 2], track_label=str(node.label), + line=dict(color=hexcol[0]), uncertainty=True) +plotter.plot_sensors(NH_sensors) +plotter.plot_measurements(NH_dets, [0, 2]) +plotter.fig + +# %% +# 5) Hierarchical Architecture +# ---------------------------- +# We now create an alternative architecture. We recreate the same set of +# nodes as before, but with a new edge set, which is a subset of the edge +# set used in the non-hierarchical architecture. +# +# In this architecture, by removing the edge joining sensor node 1 to +# fusion node 2, we prevent data incest by removing the second path which +# data from sensor node 1 can take to reach fusion node 1. + +# %% +# Regenerate nodes identical to those in the non-hierarchical example +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +from stonesoup.architecture.node import SensorNode, FusionNode + +sensornode1B = SensorNode(sensor=copy.deepcopy(sensor1), label='Sensor Node 1') +sensornode1B.sensor.clutter_model.distribution = \ + sensornode1B.sensor.clutter_model.random_state.uniform + +sensornode2B = SensorNode(sensor=copy.deepcopy(sensor2), label='Sensor Node 2') +sensornode2B.sensor.clutter_model.distribution = \ + sensornode2B.sensor.clutter_model.random_state.uniform + +f1_trackerB = copy.deepcopy(track_tracker) +f1_fqB = FusionQueue() +f1_trackerB.detector = Tracks2GaussianDetectionFeeder(f1_fqB) +fusion_node1B = FusionNode(tracker=f1_trackerB, fusion_queue=f1_fqB, label='Fusion Node 1') + +f2_trackerB = copy.deepcopy(track_tracker) +f2_fqB = FusionQueue() +f2_trackerB.detector = Tracks2GaussianDetectionFeeder(f2_fqB) +fusion_node2B = FusionNode(tracker=f2_trackerB, fusion_queue=f2_fqB, label='Fusion Node 2') + +# %% +# Create Edges forming a Hierarchical Architecture +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +H_edges = Edges([Edge((sensornode1B, fusion_node1B), edge_latency=0), + Edge((sensornode2B, fusion_node2B), edge_latency=0), + Edge((fusion_node2B, fusion_node1B), edge_latency=0)]) + +# %% +# Create the Hierarchical Architecture +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# The only difference between the two architectures is the removal +# of the edge from Sensor Node 1 to Fusion Node 2. This change removes +# the second route for information to travel from Sensor Node 1 to +# Fusion Node 1. + +H_architecture = InformationArchitecture(H_edges, current_time=start_time, + use_arrival_time=True) +H_architecture + +# %% +# Run the Hierarchical Simulation +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +for time in timesteps: + H_architecture.measure(truths, noise=True) + H_architecture.propagate(time_increment=1) + +# %% +# Extract all detections that arrived at Hierarchical Node C +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +H_sensors = [] +H_dets = set() +for sn in H_architecture.sensor_nodes: + H_sensors.append(sn.sensor) + for timestep in sn.data_held['created'].keys(): + for datapiece in sn.data_held['created'][timestep]: + H_dets.add(datapiece.data) + +# %% +# Plot the tracks stored at Hierarchical Node C +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# + +plotter = Plotterly() +plotter.plot_ground_truths(truths, [0, 2]) +for node in H_architecture.fusion_nodes: + hexcol = ["#" + ''.join([random.choice('ABCDEF0123456789') for i in range(6)])] + plotter.plot_tracks(reduce_tracks(node.tracks), [0, 2], track_label=str(node.label), + line=dict(color=hexcol[0]), uncertainty=True) +plotter.plot_sensors(H_sensors) +plotter.plot_measurements(H_dets, [0, 2]) +plotter.fig + +# %% +# Metrics +# ------- +# At a glance, the results from the hierarchical architecture look similar +# to the results from the original centralised architecture. We will now +# calculate and plot some metrics to give an insight into the differences. + +# %% +# Trace of Covariance Matrix +# ^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# A consequence of data incest in tracking is overconfidence in track states. +# In this example we would expect to see unrealistically small uncertainty in +# the tracks generated by Fusion Node 1 in the non-hierarchical architecture. +# +# To investigate this, we plot the mean trace of the covariance matrix of track +# states at each time step -- for both architectures. We should expect to see +# that the uncertainty of the non-hierarchical architecture is lower than the +# hierarchical architecture, despite both receiving an identical set of +# measurements. +# + +NH_tracks = [node.tracks for node in + NH_architecture.fusion_nodes if node.label == 'Fusion Node 1'][0] +H_tracks = [node.tracks for node in + H_architecture.fusion_nodes if node.label == 'Fusion Node 1'][0] + +NH_mean_covar_trace = [] +H_mean_covar_trace = [] + +for t in timesteps: + NH_states = sum([[state for state in track.states if state.timestamp == t] for track in + NH_tracks], []) + H_states = sum([[state for state in track.states if state.timestamp == t] for track in + H_tracks], []) + + NH_trace_mean = np.mean([np.trace(s.covar) for s in NH_states]) + H_trace_mean = np.mean([np.trace(s.covar) for s in H_states]) + + NH_mean_covar_trace.append(NH_trace_mean if not math.isnan(NH_trace_mean) else 0) + H_mean_covar_trace.append(H_trace_mean if not math.isnan(H_trace_mean) else 0) + +# %% + +plt.plot(NH_mean_covar_trace, label="Non-Hierarchical") +plt.plot(H_mean_covar_trace, label="Hierarchical") +plt.legend(loc="upper right") +plt.show() + +# As expected, the plot shows that the non-hierarchical architecture has a +# lower mean covariance trace. A naive observer may think this makes it higher +# performing, but we know that in fact it is a sign of overconfidence. \ No newline at end of file diff --git a/docs/tutorials/architecture/README.rst b/docs/tutorials/architecture/README.rst new file mode 100644 index 000000000..05495f2e4 --- /dev/null +++ b/docs/tutorials/architecture/README.rst @@ -0,0 +1,3 @@ +Architecture +------------ +Here are some tutorials which cover the use of fusion architectures in Stone Soup. diff --git a/setup.cfg b/setup.cfg index 1eba8c37d..0869823c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,6 +58,10 @@ orbital = astropy mfa = ortools +architectures = + networkx + graphviz + pydot ehm = pyehm diff --git a/stonesoup/architecture/__init__.py b/stonesoup/architecture/__init__.py new file mode 100644 index 000000000..0d1abe4fa --- /dev/null +++ b/stonesoup/architecture/__init__.py @@ -0,0 +1,651 @@ +from abc import abstractmethod +from datetime import datetime, timedelta +from operator import attrgetter +from typing import List, Tuple, Collection, Set, Union, Dict +from ordered_set import OrderedSet + +import graphviz +import numpy as np +import networkx as nx +import pydot + +from .edge import Edges, DataPiece, Edge +from .node import Node, SensorNode, RepeaterNode, FusionNode +from ._functions import _default_label_gen +from ..base import Base, Property +from ..types.detection import TrueDetection, Clutter +from ..types.groundtruth import GroundTruthPath + + +class Architecture(Base): + """Abstract Architecture Base class. Subclasses must implement the + :meth:`~Architecture.propogate` method. + """ + + edges: Edges = Property( + doc="An Edges object containing all edges. For A to be connected to B we would have an " + "Edge with edge_pair=(A, B) in this object.") + current_time: datetime = Property( + default=datetime.now(), + doc="The time which the instance is at for the purpose of simulation. " + "This is increased by the propagate method. This should be set to the earliest " + "timestep from the ground truth") + name: str = Property( + default=None, + doc="A name for the architecture, to be used to name files and/or title plots. Default is " + "the class name") + force_connected: bool = Property( + default=True, + doc="If True, the undirected version of the graph must be connected, ie. all nodes should " + "be connected via some path. Set this to False to allow an unconnected architecture. " + "Default is True") + use_arrival_time: bool = Property( + default=False, + doc="If True, the timestamp on data passed around the network will not be assigned when " + "it is opened by the fusing node - simulating an architecture where time of recording " + "is not registered by the sensor nodes" + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.name: + self.name = type(self).__name__ + if not self.current_time: + self.current_time = datetime.now() + + self.di_graph = nx.to_networkx_graph(self.edges.edge_list, create_using=nx.DiGraph) + self._viz_graph = None + + if self.force_connected and not self.is_connected and len(self) > 0: + raise ValueError("The graph is not connected. Use force_connected=False, " + "if you wish to override this requirement") + + node_label_gens = {} + labels = {node.label.replace("\n", " ") for node in self.di_graph.nodes if node.label} + for node in self.di_graph.nodes: + if not node.label: + label_gen = node_label_gens.setdefault(type(node), _default_label_gen(type(node))) + while not node.label or node.label.replace("\n", " ") in labels: + node.label = next(label_gen) + self.di_graph.nodes[node].update(self._node_kwargs(node)) + + def recipients(self, node: Node): + """Returns a set of all nodes to which the input node has a direct edge to + + Args: + node (Node): Node of which to return the recipients of. + + Raises: + ValueError: Errors if given node is not in the Architecture. + + Returns: + set: Set of nodes that are recipients of the given node. + """ + if node not in self.all_nodes: + raise ValueError("Node not in this architecture") + recipients = set() + for other in self.all_nodes: + if (node, other) in self.edges.edge_list: + recipients.add(other) + return recipients + + def senders(self, node: Node): + """Returns a set of all nodes to which the input node has a direct edge from + + Args: + node (Node): Node of which to return the senders of. + + Raises: + ValueError: Errors if given node is not in the Architecture. + + Returns: + set: Set of nodes that are senders to the given node. + """ + if node not in self.all_nodes: + raise ValueError("Node not in this architecture") + senders = set() + for other in self.all_nodes: + if (other, node) in self.edges.edge_list: + senders.add(other) + return senders + + @property + def shortest_path_dict(self): + """Returns a dictionary where dict[key1][key2] gives the distance of the shortest path + from node1 to node2 if key1=node1 and key2=node2. If no path exists from node1 to node2, + a KeyError is raised. + + Returns: + dict: Nested dictionary where dict[node1][node2] gives the distance of the shortest + path from node1 to node2. + """ + # Cannot use self.di_graph as it is not adjusted when edges are removed after + # instantiation of architecture. + g = nx.DiGraph() + for edge in self.edges.edge_list: + g.add_edge(edge[0], edge[1]) + path = nx.all_pairs_shortest_path_length(g) + dpath = {x[0]: x[1] for x in path} + return dpath + + @property + def top_level_nodes(self): + """Returns a list of 'top level nodes' - These are nodes with no recipients. E.g. the + single node at the top of a hierarchical architecture. + + Returns: + set: Set of nodes that have no recipients. + """ + top_nodes = set() + for node in self.all_nodes: + if len(self.recipients(node)) == 0: + top_nodes.add(node) + + return top_nodes + + def number_of_leaves(self, node: Node): + """Returns the number of leaf nodes which are connected to the node given as a parameter + by apath from the leaf node to the parameter node. + + Args: + node (Node): Node of which to calculate number of leaf nodes. + + Returns: + int: Number of leaf nodes that are connected to a given node. + """ + node_leaves = set() + non_leaves = 0 + + for leaf_node in self.leaf_nodes: + try: + shortest_path = self.shortest_path_dict[leaf_node][node] + if node != leaf_node or shortest_path != 0: + node_leaves.add(leaf_node) + else: + return 1 + except KeyError: + non_leaves += 1 + + return len(node_leaves) + + @property + def leaf_nodes(self): + """Returns all the nodes in the :class:`Architecture` which have no sender nodes. i.e. + all nodes that do not receive any data from other nodes. + + Returns: + set: Set of all leaf nodes that exist in the Architecture + """ + leaf_nodes = set() + for node in self.all_nodes: + if len(self.senders(node)) == 0: + # This must be a leaf node + leaf_nodes.add(node) + return leaf_nodes + + @abstractmethod + def propagate(self, time_increment: float): + raise NotImplementedError + + @property + def all_nodes(self): + """Returns a set of all Nodes in the :class:`Architecture`. + + Returns: + set: Set of all nodes in the Architecture + """ + return set(self.di_graph.nodes) + + @property + def sensor_nodes(self): + """Returns a set of all SensorNodes in the :class:`Architecture`. + + Returns: + set: Set of nodes in the Architecture that have a Sensor. + """ + sensor_nodes = set() + for node in self.all_nodes: + if isinstance(node, SensorNode): + sensor_nodes.add(node) + return sensor_nodes + + @property + def fusion_nodes(self): + """Returns a set of all FusionNodes in the :class:`Architecture`. + + Returns: + set: Set of nodes in the Architecture that perform data fusion. + """ + fusion = set() + for node in self.all_nodes: + if isinstance(node, FusionNode): + fusion.add(node) + return fusion + + @property + def repeater_nodes(self): + """Returns a set of all RepeaterNodes in the :class:`Architecture`. + + Returns: + set: Set of nodes in the Architecture whose only role is to link two other nodes + together. + """ + + repeater_nodes = set() + for node in self.all_nodes: + if isinstance(node, RepeaterNode): + repeater_nodes.add(node) + return repeater_nodes + + @staticmethod + def _node_kwargs(node, use_position=True): + node_kwargs = { + 'label': node.label, + 'shape': node.shape, + 'color': node.colour, + } + if node.font_size: + node_kwargs['fontsize'] = node.font_size + if node.node_dim: + node_kwargs['width'] = node.node_dim[0] + node_kwargs['height'] = node.node_dim[1] + if use_position and node.position: + if not isinstance(node.position, Tuple): + raise TypeError("Node position, must be Sequence of length 2") + node_kwargs["pos"] = f"{node.position[0]},{node.position[1]}!" + return node_kwargs + + def plot(self, use_positions=False, plot_title=False, + bgcolour="transparent", node_style="filled", font_name='helvetica', plot_style=None): + """Creates a pdf plot of the directed graph and displays it + + :param use_positions: + :param plot_title: If a string is supplied, makes this the title of the plot. If True, uses + the name attribute of the graph to title the plot. If False, no title is used. + Default is False + :param bgcolour: String containing the background colour for the plot. + Default is "transparent". See graphviz attributes for more information. + One alternative is "white" + :param node_style: String containing the node style for the plot. + Default is "filled". See graphviz attributes for more information. + One alternative is "solid". + :param plot_style: String providing a style to be used to plot the graph. Currently only + one option for plot style given by plot_style = 'hierarchical'. + """ + is_hierarchical = self.is_hierarchical or plot_style == 'hierarchical' + if is_hierarchical: + # Find top node and assign location + top_nodes = self.top_level_nodes + if len(top_nodes) == 1: + top_node = top_nodes.pop() + else: + raise ValueError("Graph with more than one top level node provided.") + + # Initialise a layer count + node_layers = [[top_node]] + processed_nodes = {top_node} + while self.all_nodes - processed_nodes: + senders = [ + sender + for node in node_layers[-1] + for sender in sorted(self.senders(node), key=attrgetter('label'))] + if not senders: + break + else: + node_layers.append(senders) + processed_nodes.update(senders) + + strict = nx.number_of_selfloops(self.di_graph) == 0 and not self.di_graph.is_multigraph() + graph = pydot.Dot(graph_name='', strict=strict, graph_type='digraph', rankdir='BT') + if isinstance(plot_title, str): + graph.set_graph_defaults(label=plot_title, labelloc='t') + elif isinstance(plot_title, bool) and plot_title: + graph.set_graph_defaults(label=self.name, labelloc='t') + elif not isinstance(plot_title, bool): + raise ValueError("Plot title must be a string or bool") + graph.set_graph_defaults(bgcolor=bgcolour) + graph.set_node_defaults(fontname=font_name, style=node_style) + + if is_hierarchical: + for n, layer_nodes in enumerate(node_layers): + subgraph = pydot.Subgraph(rank='max' if n == 0 else 'same') + for node in layer_nodes: + new_node = pydot.Node( + node.label.replace("\n", " "), **self._node_kwargs(node, use_positions)) + subgraph.add_node(new_node) + graph.add_subgraph(subgraph) + else: + graph.set_overlap('false') + for node in self.all_nodes: + new_node = pydot.Node( + node.label.replace("\n", " "), **self._node_kwargs(node, use_positions)) + graph.add_node(new_node) + + for edge in self.edges.edge_list: + new_edge = pydot.Edge( + edge[0].label.replace("\n", " "), edge[1].label.replace("\n", " ")) + graph.add_edge(new_edge) + + viz_graph = graphviz.Source( + graph.to_string(), engine='dot' if is_hierarchical else 'neato') + self._viz_graph = viz_graph + return viz_graph + + def _repr_html_(self): + if getattr(self, '_viz_graph', None) is None: + self.plot() + return self._viz_graph._repr_image_svg_xml() + + @property + def density(self): + """Returns the density of the graph, i.e. the proportion of possible edges between nodes + that exist in the graph""" + num_nodes = len(self.all_nodes) + num_edges = len(self.edges) + architecture_density = num_edges / ((num_nodes * (num_nodes - 1)) / 2) + return architecture_density + + @property + def is_hierarchical(self): + """Returns `True` if the :class:`Architecture` is hierarchical, otherwise `False`. Uses + the following logic: An architecture is hierarchical if and only if there exists only + one node with 0 recipients and all other nodes have exactly 1 recipient.""" + top_nodes = self.top_level_nodes + if len(self.top_level_nodes) != 1: + return False + for node in self.all_nodes: + if node not in top_nodes and len(self.recipients(node)) != 1: + return False + return True + + @property + def is_centralised(self): + """ + Returns 'True' if the :class:`Architecture` is centralised, otherwise 'False'. + Uses the following logic: An architecture is centralised if and only if there exists only + one node with 0 recipients, and there exists a path to this node from every other node in + the architecture. + """ + top_nodes = self.top_level_nodes + if len(top_nodes) != 1: + return False + top_node = top_nodes.pop() + for node in self.all_nodes - {top_node}: + try: + _ = self.shortest_path_dict[node][top_node] + except KeyError: + return False + return True + + @property + def is_connected(self): + """Property of Architecture class stating whether the graph is connected or not. + + Returns: + bool: Returns True if graph is connected, otherwise False. + """ + return nx.is_connected(self.to_undirected) + + @property + def to_undirected(self): + """Returns an undirected version of self.digraph + + Returns: + _type_: _description_ + """ + return self.di_graph.to_undirected() + + def __len__(self): + return len(self.di_graph) + + @property + def fully_propagated(self): + """Checks if all data for each node have begun transfer + to its recipients. With zero latency, this should be the case after running propagate""" + for edge in self.edges.edges: + if len(edge.unsent_data) != 0: + return False + elif len(edge.unpassed_data) != 0: + return False + + return True + + +class NonPropagatingArchitecture(Architecture): + """ + A simple Architecture class that does not simulate propagation of any data. Can be used for + performing network operations on an :class:`~.Edges` object. + """ + def propagate(self, time_increment: float): + pass + + +class InformationArchitecture(Architecture): + """The architecture for how information is shared through the network. Node A is " + "connected to Node B if and only if the information A creates by processing and/or " + "sensing is received and opened by B without modification by another node. """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if any([isinstance(node, RepeaterNode) for node in self.all_nodes]): + raise TypeError("Information architecture should not contain any repeater " + "nodes") + + def measure(self, ground_truths: List[GroundTruthPath], noise: Union[bool, np.ndarray] = True, + **kwargs) -> Dict[SensorNode, Set[Union[TrueDetection, Clutter]]]: + """ Similar to the method for :class:`~.SensorSuite`. Updates each node. """ + all_detections = dict() + + # Filter out only the ground truths that have already happened at self.current_time + current_ground_truths = OrderedSet() + for ground_truth_path in ground_truths: + available_gtp = GroundTruthPath(ground_truth_path[:self.current_time + + timedelta(microseconds=1)]) + if len(available_gtp) > 0: + current_ground_truths.add(available_gtp) + + for sensor_node in self.sensor_nodes: + all_detections[sensor_node] = set() + for detection in sensor_node.sensor.measure(current_ground_truths, noise, **kwargs): + all_detections[sensor_node].add(detection) + + for data in all_detections[sensor_node]: + # The sensor acquires its own data instantly + sensor_node.update(data.timestamp, data.timestamp, + DataPiece(sensor_node, sensor_node, data, data.timestamp), + 'created') + + return all_detections + + def propagate(self, time_increment: float, failed_edges: Collection = None): + """Performs the propagation of the measurements through the network""" + # Update each edge with messages received/sent + for edge in self.edges.edges: + # TODO: Future work - Introduce failed edges functionality + # if failed_edges and edge in failed_edges: + # edge._failed(self.current_time, time_increment) + # continue # No data passed along these edges. + + # Initial update of message categories + edge.update_messages(self.current_time, use_arrival_time=self.use_arrival_time) + for data_piece, time_pertaining in edge.unsent_data: + edge.send_message(data_piece, time_pertaining, data_piece.time_arrived) + + # Need to re-run update messages so that messages aren't left as 'pending' + edge.update_messages(self.current_time, use_arrival_time=self.use_arrival_time) + + # Need to re-run update messages so that messages aren't left as 'pending' + edge.update_messages(self.current_time, + use_arrival_time=self.use_arrival_time) + + for fuse_node in self.fusion_nodes: + fuse_node.fuse() + + if self.fully_propagated: + self.current_time += timedelta(seconds=time_increment) + return + else: + self.propagate(time_increment, failed_edges) + + +class NetworkArchitecture(Architecture): + """The architecture for how data is propagated through the network. Node A is connected " + "to Node B if and only if A sends its data through B. """ + information_arch: InformationArchitecture = Property(default=None) + information_architecture_edges: Edges = Property(default=None) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Check whether an InformationArchitecture is provided, if not, see if one can be created + if self.information_arch is None: + + # If info edges are provided, we can deduce an information architecture, otherwise: + if self.information_architecture_edges is None: + + # If repeater nodes are present in the Network architecture, we can deduce an + # Information architecture + if len(self.repeater_nodes) > 0: + self.information_architecture_edges = Edges(inherit_edges(Edges(self.edges))) + self.information_arch = InformationArchitecture( + edges=self.information_architecture_edges, current_time=self.current_time) + else: + self.information_arch = InformationArchitecture(self.edges, self.current_time) + else: + self.information_arch = InformationArchitecture( + edges=self.information_architecture_edges, current_time=self.current_time) + + # Need to reset digraph for info-arch + self.di_graph = nx.to_networkx_graph(self.edges.edge_list, create_using=nx.DiGraph) + # Set attributes such as label, colour, shape, etc. for each node + for node in self.di_graph.nodes: + self.di_graph.nodes[node].update(self._node_kwargs(node)) + + def measure(self, ground_truths: List[GroundTruthPath], noise: Union[bool, np.ndarray] = True, + **kwargs) -> Dict[SensorNode, Set[Union[TrueDetection, Clutter]]]: + """ Similar to the method for :class:`~.SensorSuite`. Updates each node. """ + all_detections = dict() + + # Filter out only the ground truths that have already happened at self.current_time + current_ground_truths = set() + for ground_truth_path in ground_truths: + available_gtp = GroundTruthPath(ground_truth_path[:self.current_time + + timedelta(seconds=1e-6)]) + if len(available_gtp) > 0: + current_ground_truths.add(available_gtp) + + for sensor_node in self.sensor_nodes: + all_detections[sensor_node] = set() + for detection in sensor_node.sensor.measure(current_ground_truths, noise, **kwargs): + all_detections[sensor_node].add(detection) + + for data in all_detections[sensor_node]: + # The sensor acquires its own data instantly + sensor_node.update(data.timestamp, data.timestamp, + DataPiece(sensor_node, sensor_node, data, data.timestamp), + 'created') + + return all_detections + + def propagate(self, time_increment: float, failed_edges: Collection = None): + """Performs the propagation of the measurements through the network""" + # Update each edge with messages received/sent + for edge in self.edges.edges: + # TODO: Future work - Introduce failed edges functionality + # if failed_edges and edge in failed_edges: + # edge._failed(self.current_time, time_increment) + # continue # No data passed along these edges + + # Initial update of message categories + if edge.recipient not in self.information_arch.all_nodes: + edge.update_messages(self.current_time, to_network_node=True, + use_arrival_time=self.use_arrival_time) + else: + edge.update_messages(self.current_time, use_arrival_time=self.use_arrival_time) + + # Send available messages from nodes to the edges + if edge.sender in self.information_arch.all_nodes: + for data_piece, time_pertaining in edge.unsent_data: + edge.send_message(data_piece, time_pertaining, data_piece.time_arrived) + else: + for message in edge.sender.messages_to_pass_on: + if edge.recipient not in message.data_piece.sent_to: + edge.pass_message(message) + + # Need to re-run update messages so that messages aren't left as 'pending' + if edge.recipient not in self.information_arch.all_nodes: + edge.update_messages(self.current_time, to_network_node=True, + use_arrival_time=self.use_arrival_time) + else: + edge.update_messages(self.current_time, use_arrival_time=self.use_arrival_time) + + for fuse_node in self.fusion_nodes: + fuse_node.fuse() + + if self.fully_propagated: + self.current_time += timedelta(seconds=time_increment) + return + else: + self.propagate(time_increment, failed_edges) + + +def inherit_edges(network_architecture): + """ + Utility function that takes a NetworkArchitecture object and infers what the overlaying + InformationArchitecture graph would be. + + :param network_architecture: A NetworkArchitecture object + :return: A list of edges. + """ + + edges = list() + for edge in network_architecture.edges: + edges.append(edge) + temp_arch = NonPropagatingArchitecture(edges=Edges(edges)) + + # Iterate through repeater nodes in the Network Architecture to find edges to remove + for repeaternode in temp_arch.repeater_nodes: + to_replace = list() + to_add = list() + + senders = temp_arch.senders(repeaternode) + recipients = temp_arch.recipients(repeaternode) + + # Find all edges that pass data to the repeater node + for sender in senders: + edges = temp_arch.edges.get((sender, repeaternode)) + to_replace += edges + + # Find all edges that pass data from the repeater node + for recipient in recipients: + edges = temp_arch.edges.get((repeaternode, recipient)) + to_replace += edges + + # Create a new edge from every sender to every recipient + for sender in senders: + for recipient in recipients: + + # Could be possible edges from sender to node, choose path of minimum latency + poss_edges_to = temp_arch.edges.get((sender, repeaternode)) + latency_to = np.inf + for edge in poss_edges_to: + latency_to = edge.edge_latency if edge.edge_latency <= latency_to else \ + latency_to + + # Could be possible edges from node to recipient, choose path of minimum latency + poss_edges_from = temp_arch.edges.get((sender, repeaternode)) + latency_from = np.inf + for edge in poss_edges_from: + latency_from = edge.edge_latency if edge.edge_latency <= latency_from else \ + latency_from + + latency = latency_to + latency_from + repeaternode.latency + edge = Edge(nodes=(sender, recipient), edge_latency=latency) + to_add.append(edge) + + for edge in to_replace: + temp_arch.edges.remove(edge) + for edge in to_add: + temp_arch.edges.add(edge) + return temp_arch.edges diff --git a/stonesoup/architecture/_functions.py b/stonesoup/architecture/_functions.py new file mode 100644 index 000000000..3e63b56c9 --- /dev/null +++ b/stonesoup/architecture/_functions.py @@ -0,0 +1,34 @@ +from itertools import count, product +from string import ascii_uppercase as auc + + +def _dict_set(my_dict, value, key1, key2=None): + """Utility function to add value to my_dict at the specified key(s) + Returns True if the set increased in size, i.e. the value was new to its position""" + if not my_dict: + if key2: + my_dict = {key1: {key2: {value}}} + else: + my_dict = {key1: {value}} + elif key2: + if key1 in my_dict: + if key2 in my_dict[key1]: + old_len = len(my_dict[key1][key2]) + my_dict[key1][key2].add(value) + return len(my_dict[key1][key2]) == old_len + 1, my_dict + else: + my_dict[key1][key2] = {value} + else: + my_dict[key1] = {key2: {value}} + else: + if key1 in my_dict: + old_len = len(my_dict[key1]) + my_dict[key1].add(value) + return len(my_dict[key1]) == old_len + 1, my_dict + else: + my_dict[key1] = {value} + return True, my_dict + + +def _default_label_gen(type_): + return (f"{type_.__name__}\n{''.join(c)}" for n in count(1) for c in product(auc, repeat=n)) diff --git a/stonesoup/architecture/edge.py b/stonesoup/architecture/edge.py new file mode 100644 index 000000000..76916ad00 --- /dev/null +++ b/stonesoup/architecture/edge.py @@ -0,0 +1,326 @@ +import copy +from collections.abc import Collection +from typing import Union, Tuple, List, Set, TYPE_CHECKING +from numbers import Number +from datetime import datetime, timedelta +from queue import Queue + +from ..base import Base, Property +from ..types.time import TimeRange, CompoundTimeRange +from ..types.track import Track +from ..types.detection import Detection +from ..types.hypothesis import Hypothesis +from ._functions import _dict_set + +if TYPE_CHECKING: + from .node import Node + + +class FusionQueue(Queue): + """A queue from which fusion nodes draw data they have yet to fuse + + Iterable, where it blocks attempting to yield items on the queue + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._to_consume = 0 + self._consuming = False + self.received = set() + + def _put(self, *args, **kwargs): + super()._put(*args, **kwargs) + self._to_consume += 1 + + def __iter__(self): + self._consuming = True + while True: + yield super().get() + self._to_consume -= 1 + + @property + def waiting_for_data(self): + return self._consuming and not self._to_consume + + def get(self, *args, **kwargs): + raise NotImplementedError("Getting items from queue must use iteration") + + +class DataPiece(Base): + """A piece of data for use in an architecture. Sent via a :class:`~.Message`, + and stored in a Node's :attr:`data_held`""" + node: "Node" = Property( + doc="The Node this data piece belongs to") + originator: "Node" = Property( + doc="The node which first created this data, ie by sensing or fusing information together." + " If the data is simply passed along the chain, the originator remains unchanged. ") + data: Union[Detection, Track, Hypothesis] = Property( + doc="A Detection, Track, or Hypothesis") + time_arrived: datetime = Property( + doc="The time at which this piece of data was received by the Node, either by Message or " + "by sensing.") + track: Track = Property( + doc="The Track in the event of data being a Hypothesis", + default=None) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sent_to = set() # all Nodes the data_piece has been sent to, to avoid duplicates + + +class Edge(Base): + """Comprised of two connected :class:`~.Node`s""" + nodes: Tuple["Node", "Node"] = Property(doc="A pair of nodes in the form (sender, recipient)") + edge_latency: float = Property(doc="The latency stemming from the edge itself, " + "and not either of the nodes", + default=0.0) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not isinstance(self.edge_latency, Number): + raise TypeError(f"edge_latency should be a float, not a {type(self.edge_latency)}") + self.messages_held = {"pending": {}, # For pending, messages indexed by time sent. + "received": {}} # For received, by time received + self.time_range_failed = CompoundTimeRange() # Times during which this edge was failed + + def send_message(self, data_piece, time_pertaining, time_sent): + """ + Takes a piece of data retrieved from the edge's sender node, and propagates it + along the edge + :param data_piece: DataPiece object pulled from the edge's sender. + :param time_pertaining: The latest time for which the data pertains. For a Detection, this + would be the time of the Detection, or for a Track this is the time of the last State in + the Track + :param time_sent: Time at which the message was sent + """ + if not isinstance(data_piece, DataPiece): + raise TypeError(f"data_piece is type {type(data_piece)}. Expected DataPiece") + message = Message(edge=self, time_pertaining=time_pertaining, time_sent=time_sent, + data_piece=data_piece, destinations={self.recipient}) + _, self.messages_held = _dict_set(self.messages_held, message, 'pending', time_sent) + # ensure message not re-sent + data_piece.sent_to.add(self) + + def pass_message(self, message): + """ + Takes a message from a Node's 'messages_to_pass_on' store and propagates them to the + relevant edges. + :param message: :class:`~.Message` to propagate + """ + message_copy = copy.copy(message) + message_copy.edge = self + if message_copy.destinations == {self.sender} or message.destinations is None: + message_copy.destinations = {self.recipient} + _, self.messages_held = _dict_set(self.messages_held, message_copy, 'pending', + message_copy.time_sent) + # Message not opened by repeater node, remove node from 'sent_to' + message_copy.data_piece.sent_to.add(self) + + def update_messages(self, current_time, to_network_node=False, use_arrival_time=False): + """ + Updates the category of messages stored in edge.messages_held if latency time has passed. + Adds messages that have 'arrived' at recipient to the relevant holding area of the node. + :param use_arrival_time: Bool that is True if arriving data should use arrival time as + it's timestamp + :param current_time: Current time in simulation + :param to_network_node: Bool that is true if recipient node is not in the information + architecture + """ + # Check info type is what we expect + to_remove = set() # Needed as we can't change size of a set during iteration + for time in self.messages_held['pending']: + for message in self.messages_held['pending'][time]: + message.update(current_time) + if message.status == 'received': + # Then the latency has passed and message has been received + # Move message from pending to received messages in edge + to_remove.add((time, message)) + _, self.messages_held = _dict_set(self.messages_held, message, + 'received', message.arrival_time) + + # Assign destination as recipient of edge if no destination provided + if message.destinations is None: + message.destinations = {self.recipient} + + # Update node according to inclusion in Information Architecture + if not to_network_node and message.destinations == {self.recipient}: + # Add data to recipient's data_held + self.recipient.update(message.time_pertaining, + message.arrival_time, + message.data_piece, "unfused", + use_arrival_time=use_arrival_time) + + elif not to_network_node and self.recipient in message.destinations: + # Add data to recipient's data held, and message to messages_to_pass_on + self.recipient.update(message.time_pertaining, + message.arrival_time, + message.data_piece, "unfused", + use_arrival_time=use_arrival_time) + message.destinations = None + self.recipient.messages_to_pass_on.append(message) + + elif to_network_node or self.recipient not in message.destinations: + # Add message to recipient's messages_to_pass_on + message.destinations = None + self.recipient.messages_to_pass_on.append(message) + + for time, message in to_remove: + self.messages_held['pending'][time].remove(message) + if len(self.messages_held['pending'][time]) == 0: + del self.messages_held['pending'][time] + + def failed(self, current_time, delta): + """"Keeps track of when this edge was failed using the time_ranges_failed property. + Delta should be a timedelta instance""" + end_time = current_time + delta + self.time_range_failed.add(TimeRange(current_time, end_time)) + + @property + def sender(self): + return self.nodes[0] + + @property + def recipient(self): + return self.nodes[1] + + @property + def ovr_latency(self): + """Overall latency of this :class:`~.Edge`""" + return self.sender.latency + self.edge_latency + + @property + def unpassed_data(self): + unpassed = [] + for message in self.sender.messages_to_pass_on: + if self not in message.data_piece.sent_to: + unpassed.append(message) + return unpassed + + @property + def unsent_data(self): + """Data modified by the sender that has not been sent to the + recipient.""" + unsent = [] + if isinstance(type(self.sender.data_held), type(None)) or self.sender.data_held is None: + return unsent + else: + for status in ["fused", "created"]: + for time_pertaining in self.sender.data_held[status]: + for data_piece in self.sender.data_held[status][time_pertaining]: + # Data will be sent to any nodes it hasn't been sent to before + if self not in data_piece.sent_to: + unsent.append((data_piece, time_pertaining)) + return unsent + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + return all(getattr(self, name) == getattr(other, name) for name in type(self).properties) + + def __hash__(self): + return hash(tuple(getattr(self, name) for name in type(self).properties)) + + +class Edges(Base, Collection): + """Container class for :class:`~.Edge`""" + edges: List[Edge] = Property(doc="List of Edge objects", default=None) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.edges is None: + self.edges = [] + + def __iter__(self): + return self.edges.__iter__() + + def __contains__(self, item): + return item in self.edges + + def add(self, edge): + self.edges.append(edge) + + def remove(self, edge): + self.edges.remove(edge) + + def get(self, node_pair): + from .node import Node + if not (isinstance(node_pair, Tuple) and + all(isinstance(node, Node) for node in node_pair)): + raise TypeError("Must supply a tuple of nodes") + if not len(node_pair) == 2: + raise ValueError("Incorrect tuple length. Must be of length 2") + edges = list() + for edge in self.edges: + if edge.nodes == node_pair: + edges.append(edge) + return edges + + @property + def edge_list(self): + """Returns a list of tuples in the form (sender, recipient)""" + if not self.edges: + return [] + return [edge.nodes for edge in self.edges] + + def __len__(self): + return len(self.edges) + + +class Message(Base): + """A message, containing a piece of information, that gets propagated between two Nodes. + Messages are opened by nodes that are a recipient of the node that sent the message""" + edge: Edge = Property( + doc="The directed edge containing the sender and receiver of the message") + time_pertaining: datetime = Property( + doc="The latest time for which the data pertains. For a Detection, this would be the time " + "of the Detection, or for a Track this is the time of the last State in the Track. " + "Different from time_sent when data is passed on that was not generated by the " + "sender") + time_sent: datetime = Property( + doc="Time at which the message was sent") + data_piece: DataPiece = Property( + doc="Info that the sent message contains") + destinations: Set["Node"] = Property(doc="Nodes in the information architecture that the " + "message is being sent to", + default=None) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.status = "sending" + + @property + def sender_node(self): + return self.edge.sender + + @property + def recipient_node(self): + return self.edge.recipient + + @property + def arrival_time(self): + # TODO: incorporate failed time ranges here. + # Not essential for a first PR. Could do with merging of PR #664 + return self.time_sent + timedelta(seconds=self.edge.ovr_latency) + + def update(self, current_time): + progress = (current_time - self.time_sent).total_seconds() + if progress < 0: + raise ValueError("Current time cannot be before the Message was sent") + if progress < self.edge.sender.latency: + self.status = "sending" + elif progress < self.edge.ovr_latency: + self.status = "transferring" + else: + self.status = "received" + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, name) == getattr(other, name) + for name in type(self).properties + if name not in ['destinations', 'edge']) + + def __hash__(self): + return hash(tuple(getattr(self, name) + for name in type(self).properties + if name not in ['destinations', 'edge'])) diff --git a/stonesoup/architecture/generator.py b/stonesoup/architecture/generator.py new file mode 100644 index 000000000..f76239efb --- /dev/null +++ b/stonesoup/architecture/generator.py @@ -0,0 +1,285 @@ +import copy +import random + +from datetime import datetime + +import networkx as nx +import numpy as np + +from stonesoup.architecture import SensorNode, FusionNode, Edge, Edges, InformationArchitecture, \ + NetworkArchitecture +from stonesoup.architecture.edge import FusionQueue +from stonesoup.architecture.node import SensorFusionNode, RepeaterNode +from stonesoup.base import Base, Property +from stonesoup.feeder.track import Tracks2GaussianDetectionFeeder +from stonesoup.sensor.sensor import Sensor +from stonesoup.tracker import Tracker + + +class InformationArchitectureGenerator(Base): + """ + Class that can be used to generate InformationArchitecture classes given a set of input + parameters. + """ + arch_type: str = Property( + doc="Type of architecture to be modelled. Currently only 'hierarchical' and " + "'decentralised' are supported.", + default='decentralised') + start_time: datetime = Property( + doc="Start time of simulation to be passed to the Architecture class.", + default=datetime.now()) + node_ratio: tuple = Property( + doc="Tuple containing the number of each type of node, in the order of (sensor nodes, " + "sensor fusion nodes, fusion nodes).", + default=None) + mean_degree: float = Property( + doc="Average (mean) degree of nodes in the network.", + default=None) + base_sensor: Sensor = Property( + doc="Sensor class object that will be duplicated to create multiple sensors. Position of " + "this sensor is used with 'sensor_max_distance' to calculate a position for " + "duplicated sensors.", + default=None) + sensor_max_distance: tuple = Property( + doc="Max distance each sensor can be from base_sensor.position. Should be a tuple of " + "length equal to len(base_sensor.position_mapping)", + default=None) + base_tracker: Tracker = Property( + doc="Tracker class object that will be duplicated to create multiple trackers. " + "Should have detector=None.", + default=None) + iteration_limit: int = Property( + doc="Limit for the number of iterations the generate_edgelist() method can make when " + "attempting to build a suitable graph.", + default=10000) + allow_invalid_graph: bool = Property( + doc="Bool where True allows invalid graphs to be returned without throwing an error. " + "False by default", + default=False) + n_archs: int = Property( + doc="Tuple containing a minimum and maximum value for the number of routes created in the " + "network architecture to represent a single edge in the information architecture.", + default=2) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_nodes = sum(self.node_ratio) + self.n_sensor_nodes = self.node_ratio[0] + self.n_fusion_nodes = self.node_ratio[2] + self.n_sensor_fusion_nodes = self.node_ratio[1] + + self.n_edges = int(np.ceil(self.n_nodes * self.mean_degree * 0.5)) + + if self.sensor_max_distance is None: + self.sensor_max_distance = tuple(np.zeros(len(self.base_sensor.position_mapping))) + if self.arch_type not in ['decentralised', 'hierarchical']: + raise ValueError('arch_style must be "decentralised" or "hierarchical"') + + def generate(self): + edgelist, node_labels = self._generate_edgelist() + + nodes = self._assign_nodes(node_labels) + + archs = list() + for architecture in nodes.keys(): + arch = self._generate_architecture(nodes[architecture], edgelist) + archs.append(arch) + return archs + + def _generate_architecture(self, nodes, edgelist): + edges = [] + for t in edgelist: + edge = Edge((nodes[t[0]], nodes[t[1]])) + edges.append(edge) + + arch_edges = Edges(edges) + + arch = InformationArchitecture(arch_edges, self.start_time) + return arch + + def _assign_nodes(self, node_labels): + + nodes = {} + for architecture in range(self.n_archs): + nodes[architecture] = {} + + for label in node_labels: + + if label.startswith('f'): + for architecture in range(self.n_archs): + t = copy.deepcopy(self.base_tracker) + fq = FusionQueue() + t.detector = Tracks2GaussianDetectionFeeder(fq) + + node = FusionNode(tracker=t, + fusion_queue=fq, + label=label, + latency=0) + + nodes[architecture][label] = node + + elif label.startswith('sf'): + + pos = np.array( + [[p + random.uniform(-d, d)] for p, d in zip(self.base_sensor.position, + self.sensor_max_distance)]) + + for architecture in range(self.n_archs): + s = copy.deepcopy(self.base_sensor) + s.position = pos + + t = copy.deepcopy(self.base_tracker) + fq = FusionQueue() + t.detector = Tracks2GaussianDetectionFeeder(fq) + + node = SensorFusionNode(sensor=s, + tracker=t, + fusion_queue=fq, + label=label, + latency=0) + + nodes[architecture][label] = node + + elif label.startswith('s'): + pos = np.array( + [[p + random.uniform(-d, d)] for p, d in zip(self.base_sensor.position, + self.sensor_max_distance)]) + for architecture in range(self.n_archs): + s = copy.deepcopy(self.base_sensor) + s.position = pos + + node = SensorNode(sensor=s, + label=label, + latency=0) + nodes[architecture][label] = node + + else: + for architecture in range(self.n_archs): + node = RepeaterNode(label=label, + latency=0) + nodes[architecture][label] = node + + return nodes + + def _generate_edgelist(self): + + edges = [] + + nodes = ['f' + str(i) for i in range(self.n_fusion_nodes)] + \ + ['sf' + str(i) for i in range(self.n_sensor_fusion_nodes)] + \ + ['s' + str(i) for i in range(self.n_sensor_nodes)] + + valid = False + + if self.arch_type == 'hierarchical': + while not valid: + edges = [] + n = self.n_fusion_nodes + self.n_sensor_fusion_nodes + + for i, node in enumerate(nodes): + if i == 0: + continue + elif i == 1: + source = nodes[0] + target = nodes[1] + else: + if node.startswith('s') and not node.startswith('sf'): + source = node + target = nodes[random.randint(0, n - 1)] + else: + source = node + target = nodes[random.randint(0, i - 1)] + + # Create edge + edge = (source, target) + edges.append(edge) + + # Logic checks on graph + g = nx.DiGraph(edges) + for f_node in ['f' + str(i) for i in range(self.n_fusion_nodes)]: + if g.in_degree(f_node) == 0: + break + else: + valid = True + + else: + + while not valid: + + for i in range(1, self.n_edges + 1): + source = target = -1 + if i < self.n_nodes: + source = nodes[i] + target = nodes[random.randint( + 0, min(i - 1, len(nodes) - self.n_sensor_nodes - 1))] + + else: + while source == target \ + or (source, target) in edges \ + or (target, source) in edges: + source = nodes[random.randint(0, len(nodes) - 1)] + target = nodes[random.randint(0, len(nodes) - self.n_sensor_nodes - 1)] + + edges.append((source, target)) + + # Logic checks on graph + g = nx.DiGraph(edges) + for f_node in ['f' + str(i) for i in range(self.n_fusion_nodes)]: + if g.in_degree(f_node) == 0: + break + else: + valid = True + + return edges, nodes + + +class NetworkArchitectureGenerator(InformationArchitectureGenerator): + """ + Class that can be used to generate NetworkArchitecture classes given a set of input + parameters. + """ + n_routes: tuple = Property( + doc="Tuple containing a minimum and maximum value for the number of routes created in the " + "network architecture to represent a single edge in the information architecture.", + default=(1, 2)) + + def generate(self): + edgelist, node_labels = self._generate_edgelist() + + edgelist, node_labels = self._add_network(edgelist, node_labels) + + nodes = self._assign_nodes(node_labels) + + archs = list() + for architecture in nodes.keys(): + arch = self._generate_architecture(nodes[architecture], edgelist) + archs.append(arch) + return archs + + def _add_network(self, edgelist, nodes): + network_edgelist = [] + i = 0 + for e in edgelist: + # Choose number of routes between two information architecture nodes + n = self.n_routes[0] if len(self.n_routes) == 1 else \ + random.randint(self.n_routes[0], self.n_routes[1]) + + for route in range(n): + r_lab = 'r' + str(i) + network_edgelist.append((e[0], r_lab)) + network_edgelist.append((r_lab, e[1])) + nodes.append(r_lab) + i += 1 + + return network_edgelist, nodes + + def _generate_architecture(self, nodes, edgelist): + edges = [] + for t in edgelist: + edge = Edge((nodes[t[0]], nodes[t[1]])) + edges.append(edge) + + arch_edges = Edges(edges) + + arch = NetworkArchitecture(arch_edges, self.start_time) + return arch diff --git a/stonesoup/architecture/node.py b/stonesoup/architecture/node.py new file mode 100644 index 000000000..cd77b8881 --- /dev/null +++ b/stonesoup/architecture/node.py @@ -0,0 +1,193 @@ +import copy +import threading +from datetime import datetime +from queue import Queue, Empty +from typing import Tuple + +from ..base import Property, Base +from ..sensor.sensor import Sensor +from ..types.detection import Detection +from ..types.hypothesis import Hypothesis +from ..types.track import Track +from .edge import DataPiece, FusionQueue +from ..tracker.base import Tracker +from ._functions import _dict_set + + +class Node(Base): + """Base Node class. Generally a subclass should be used. Note that most user-defined + properties are for graphical use only, all with default values. """ + latency: float = Property( + doc="Contribution to edge latency stemming from this node. Default is 0.0", + default=0.0) + label: str = Property( + doc="Label to be displayed on graph. Default is to label by class and then " + "differentiate via alphabetical labels", + default=None) + position: Tuple[float] = Property( + default=None, + doc="Cartesian coordinates for node. Determined automatically by default") + colour: str = Property( + default='#909090', + doc='Colour to be displayed on graph. Default is grey') + shape: str = Property( + default='rectangle', + doc='Shape used to display nodes. Default is a rectangle') + font_size: int = Property( + default=None, + doc='Font size for node labels. Default is None') + node_dim: tuple = Property( + default=None, + doc='Width and height of nodes for graph icons. ' + 'Default is None, which will size to label automatically.') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_held = {"fused": {}, "created": {}, "unfused": {}} + self.messages_to_pass_on = [] + + def update(self, time_pertaining, time_arrived, data_piece, category, track=None, + use_arrival_time=False): + """Updates this :class:`~.Node`'s :attr:`~.data_held` using a new data piece. """ + if not (isinstance(time_pertaining, datetime) and isinstance(time_arrived, datetime)): + raise TypeError("Times must be datetime objects") + if not isinstance(data_piece, DataPiece): + raise TypeError(f"data_piece must be a DataPiece. Provided type {type(data_piece)}") + if category not in self.data_held.keys(): + raise ValueError(f"category must be one of {self.data_held.keys()}") + if not track: + if not isinstance(data_piece.data, Detection) and \ + not isinstance(data_piece.data, Track): + raise TypeError(f"Data provided without accompanying Track must be a Detection or " + f"a Track, not a " + f"{type(data_piece.data).__name__}") + new_data_piece = DataPiece(self, data_piece.originator, data_piece.data, time_arrived) + else: + if not isinstance(data_piece.data, Hypothesis): + raise TypeError("Data provided with Track must be a Hypothesis") + new_data_piece = DataPiece(self, data_piece.originator, data_piece.data, + time_arrived, track) + + added, self.data_held[category] = _dict_set(self.data_held[category], + new_data_piece, time_pertaining) + + if use_arrival_time and isinstance(self, FusionNode) and \ + category in ("created", "unfused"): + data = copy.copy(data_piece.data) + data.timestamp = time_arrived + if data not in self.fusion_queue.received: + self.fusion_queue.received.add(data) + self.fusion_queue.put((time_pertaining, {data})) + + elif isinstance(self, FusionNode) and \ + category in ("created", "unfused") and \ + data_piece.data not in self.fusion_queue.received: + self.fusion_queue.received.add(data_piece.data) + self.fusion_queue.put((time_pertaining, {data_piece.data})) + + return added + + +class SensorNode(Node): + """A :class:`~.Node` corresponding to a :class:`~.Sensor`. Fresh data is created here""" + sensor: Sensor = Property(doc="Sensor corresponding to this node") + colour: str = Property( + default='#006eff', + doc='Colour to be displayed on graph. Default is the hex colour code #006eff') + shape: str = Property( + default='oval', + doc='Shape used to display nodes. Default is an oval') + + +class FusionNode(Node): + """A :class:`~.Node` that does not measure new data, but does process data it receives""" + # feeder probably as well + tracker: Tracker = Property( + doc="Tracker used by this Node to fuse together Tracks and Detections") + fusion_queue: FusionQueue = Property( + default=None, + doc="The queue from which this node draws data to be fused. Default is a standard " + "FusionQueue") + tracks: set = Property(default=None, + doc="Set of tracks tracked by the fusion node") + colour: str = Property( + default='#00b53d', + doc='Colour to be displayed on graph. Default is the hex colour code #00b53d') + shape: str = Property( + default='hexagon', + doc='Shape used to display nodes. Default is a hexagon') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tracks = set() # Set of tracks this Node has recorded + if not self.fusion_queue: + if self.tracker.detector: + self.fusion_queue = self.tracker.detector + else: + self.fusion_queue = FusionQueue() + self.tracker.detector = self.fusion_queue + + self._track_queue = Queue() + self._tracking_thread = threading.Thread( + target=self._track_thread, + args=(self.tracker, self.fusion_queue, self._track_queue), + daemon=True) + + def fuse(self): + if not self._tracking_thread.is_alive(): + try: + self._tracking_thread.start() + except RuntimeError: # Previously started + raise RuntimeError(f"Tracking thread in {self.label!r} unexpectedly ended") + + added = False + updated_tracks = set() + while True: + waiting_for_data = self.fusion_queue.waiting_for_data + try: + data = self._track_queue.get(timeout=1e-6) + except Empty: + if not self._tracking_thread.is_alive() or waiting_for_data: + break + else: + _, tracks = data + self.tracks.update(tracks) + updated_tracks = updated_tracks.union(tracks) + + for track in updated_tracks: + data_piece = DataPiece(self, self, copy.copy(track), track.timestamp, True) + added, self.data_held['fused'] = _dict_set( + self.data_held['fused'], data_piece, track.timestamp) + return added + + @staticmethod + def _track_thread(tracker, input_queue, output_queue): + for time, tracks in tracker: + output_queue.put((time, tracks)) + input_queue.task_done() + + +class SensorFusionNode(SensorNode, FusionNode): + """A :class:`~.Node` that is both a :class:`~.Sensor` and also processes data""" + colour: str = Property( + default='#fc9000', + doc='Colour to be displayed on graph. Default is the hex colour code #fc9000') + shape: str = Property( + default='diamond', + doc='Shape used to display nodes. Default is a diamond') + + +class RepeaterNode(Node): + """A :class:`~.Node` which simply passes data along to others, without manipulating the + data itself. Consequently, :class:`~.RepeaterNode`s are only used within a + :class:`~.NetworkArchitecture`""" + colour: str = Property( + default='#909090', + doc='Colour to be displayed on graph. Default is the hex colour code #909090') + shape: str = Property( + default='rectangle', + doc='Shape used to display nodes. Default is a rectangle') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_held = None diff --git a/stonesoup/architecture/tests/__init__.py b/stonesoup/architecture/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stonesoup/architecture/tests/conftest.py b/stonesoup/architecture/tests/conftest.py new file mode 100644 index 000000000..8cbd0be3e --- /dev/null +++ b/stonesoup/architecture/tests/conftest.py @@ -0,0 +1,439 @@ +import pytest +import random +import numpy as np +from ordered_set import OrderedSet +from datetime import datetime, timedelta +import copy + +from ..edge import Edge, DataPiece, Edges +from ..node import Node, RepeaterNode, SensorNode, FusionNode, SensorFusionNode +from ...types.track import Track +from ...sensor.categorical import HMMSensor +from ...models.measurement.categorical import MarkovianMeasurementModel +from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ + ConstantVelocity +from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState +from stonesoup.types.state import StateVector +from stonesoup.sensor.radar.radar import RadarRotatingBearingRange +from stonesoup.types.angle import Angle + +from stonesoup.predictor.kalman import KalmanPredictor +from stonesoup.updater.kalman import ExtendedKalmanUpdater +from stonesoup.hypothesiser.distance import DistanceHypothesiser +from stonesoup.measures import Mahalanobis +from stonesoup.dataassociator.neighbour import GNNWith2DAssignment +from stonesoup.deleter.error import CovarianceBasedDeleter +from stonesoup.types.state import GaussianState +from stonesoup.initiator.simple import MultiMeasurementInitiator +from stonesoup.tracker.simple import MultiTargetTracker +from stonesoup.architecture.edge import FusionQueue +from stonesoup.updater.wrapper import DetectionAndTrackSwitchingUpdater +from stonesoup.updater.chernoff import ChernoffUpdater +from stonesoup.feeder.track import Tracks2GaussianDetectionFeeder +from ...types.hypothesis import Hypothesis + + +@pytest.fixture +def edges(): + edge_a = Edge((Node(label="edge_a sender"), Node(label="edge_a recipient"))) + edge_b = Edge((Node(label="edge_b sender"), Node(label="edge_b recipient"))) + edge_c = Edge((Node(label="edge_c sender"), Node(label="edge_c recipient"))) + return {'a': edge_a, 'b': edge_b, 'c': edge_c} + + +@pytest.fixture +def nodes(): + E = np.array([[0.8, 0.1], # P(small | bike), P(small | car) + [0.19, 0.3], # P(medium | bike), P(medium | car) + [0.01, 0.6]]) # P(large | bike), P(large | car) + + model = MarkovianMeasurementModel(emission_matrix=E, + measurement_categories=['small', 'medium', 'large']) + + hmm_sensor = HMMSensor(measurement_model=model) + + node_a = Node(label="node a") + node_b = Node(label="node b") + sensornode_1 = SensorNode(sensor=hmm_sensor, label="s1", font_size=10, node_dim=(1, 1)) + sensornode_2 = SensorNode(sensor=hmm_sensor) + sensornode_3 = SensorNode(sensor=hmm_sensor, label='s3') + sensornode_4 = SensorNode(sensor=hmm_sensor, label='s4') + sensornode_5 = SensorNode(sensor=hmm_sensor, label='s5') + sensornode_6 = SensorNode(sensor=hmm_sensor, label='s6') + sensornode_7 = SensorNode(sensor=hmm_sensor, label='s7') + sensornode_8 = SensorNode(sensor=hmm_sensor, label='s8') + repeaternode_1 = RepeaterNode(label='r1') + repeaternode_2 = RepeaterNode(label='r2') + repeaternode_3 = RepeaterNode(label='r3') + repeaternode_4 = RepeaterNode(label='r4') + + pnode_1 = SensorNode(sensor=hmm_sensor, label='p1', position=(0, 0)) + pnode_2 = SensorNode(sensor=hmm_sensor, label='p2', position=(-1, -1)) + pnode_3 = SensorNode(sensor=hmm_sensor, label='p3', position=(1, -1)) + + return {"a": node_a, "b": node_b, "s1": sensornode_1, "s2": sensornode_2, "s3": sensornode_3, + "s4": sensornode_4, "s5": sensornode_5, "s6": sensornode_6, "s7": sensornode_7, + "s8": sensornode_8, "r1": repeaternode_1, "r2": repeaternode_2, "r3": repeaternode_3, + "r4": repeaternode_4, "p1": pnode_1, "p2": pnode_2, "p3": pnode_3} + + +@pytest.fixture +def data_pieces(times, nodes): + data_piece_a = DataPiece(node=nodes['a'], originator=nodes['a'], + data=Track([]), time_arrived=times['a']) + data_piece_b = DataPiece(node=nodes['a'], originator=nodes['b'], + data=Track([]), time_arrived=times['b']) + data_piece_fail = DataPiece(node=nodes['a'], originator=nodes['b'], + data="Not a compatible data type", time_arrived=times['b']) + data_piece_hyp = DataPiece(node=nodes['a'], originator=nodes['b'], + data=Hypothesis(), time_arrived=times['b']) + return {'a': data_piece_a, 'b': data_piece_b, 'fail': data_piece_fail, 'hyp': data_piece_hyp} + + +@pytest.fixture +def times(): + time_a = datetime.strptime("23/08/2023 13:36:00", "%d/%m/%Y %H:%M:%S") + time_b = datetime.strptime("23/08/2023 13:37:00", "%d/%m/%Y %H:%M:%S") + start_time = datetime.strptime("25/12/1306 23:47:00", "%d/%m/%Y %H:%M:%S") + return {'a': time_a, 'b': time_b, 'start': start_time} + + +@pytest.fixture +def transition_model(): + transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(0.005), + ConstantVelocity(0.005)]) + return transition_model + + +@pytest.fixture +def timesteps(times): + start_time = times["start"] + time_max = 60 # timestamps the simulation is observed over + timesteps = [start_time + timedelta(seconds=k) for k in range(time_max)] + return timesteps + + +@pytest.fixture +def ground_truths(transition_model, times, timesteps): + yps = range(0, 100, 10) # y value for prior state + truths = OrderedSet() + ntruths = 3 # number of ground truths in simulation + time_max = 60 # timestamps the simulation is observed over + + xdirection = 1 + ydirection = 1 + + # Generate ground truths + for j in range(0, ntruths): + truth = GroundTruthPath([GroundTruthState([0, xdirection, yps[j], ydirection], + timestamp=timesteps[0])], id=f"id{j}") + + for k in range(1, time_max): + truth.append( + GroundTruthState(transition_model.function(truth[k - 1], noise=True, + time_interval=timedelta(seconds=1)), + timestamp=timesteps[k])) + truths.add(truth) + + xdirection *= -1 + if j % 2 == 0: + ydirection *= -1 + + return truths + + +@pytest.fixture +def radar_sensors(times): + start_time = times["start"] + total_no_sensors = 6 + sensor_set = OrderedSet() + for n in range(0, total_no_sensors): + sensor = RadarRotatingBearingRange( + position_mapping=(0, 2), + noise_covar=np.array([[np.radians(0.5) ** 2, 0], + [0, 1 ** 2]]), + ndim_state=4, + position=np.array([[10], [n * 20 - 40]]), + rpm=60, + fov_angle=np.radians(360), + dwell_centre=StateVector([0.0]), + max_range=np.inf, + resolution=Angle(np.radians(30)) + ) + sensor_set.add(sensor) + for sensor in sensor_set: + sensor.timestamp = start_time + + return sensor_set + + +@pytest.fixture +def predictor(transition_model): + predictor = KalmanPredictor(transition_model) + return predictor + + +@pytest.fixture +def updater(): + updater = ExtendedKalmanUpdater(measurement_model=None) + return updater + + +@pytest.fixture +def hypothesiser(predictor, updater): + hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(), + missed_distance=5) + return hypothesiser + + +@pytest.fixture +def data_associator(hypothesiser): + data_associator = GNNWith2DAssignment(hypothesiser) + return data_associator + + +@pytest.fixture +def deleter(hypothesiser): + deleter = CovarianceBasedDeleter(covar_trace_thresh=7) + return deleter + + +@pytest.fixture +def initiator(data_associator, deleter, updater): + initiator = MultiMeasurementInitiator( + prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 1, 0, 1])), + measurement_model=None, + deleter=deleter, + data_associator=data_associator, + updater=updater, + min_points=2, + ) + return initiator + + +@pytest.fixture +def tracker(initiator, deleter, data_associator, updater): + tracker = MultiTargetTracker(initiator, deleter, None, data_associator, updater) + return tracker + + +@pytest.fixture +def track_updater(): + track_updater = ChernoffUpdater(None) + return track_updater + + +@pytest.fixture +def detection_updater(): + detection_updater = ExtendedKalmanUpdater(None) + return detection_updater + + +@pytest.fixture +def detection_track_updater(detection_updater, track_updater): + detection_track_updater = DetectionAndTrackSwitchingUpdater(None, detection_updater, + track_updater) + return detection_track_updater + + +@pytest.fixture +def fusion_queue(): + fq = FusionQueue() + return fq + + +@pytest.fixture +def track_tracker(initiator, deleter, fusion_queue, data_associator, detection_track_updater): + track_tracker = MultiTargetTracker( + initiator, deleter, Tracks2GaussianDetectionFeeder(fusion_queue), data_associator, + detection_track_updater) + return track_tracker + + +@pytest.fixture +def radar_nodes(tracker, track_tracker, radar_sensors, fusion_queue): + sensor_set = radar_sensors + node_A = SensorNode(sensor=sensor_set[0]) + node_B = SensorNode(sensor=sensor_set[2]) + + node_C_tracker = copy.deepcopy(tracker) + node_C_tracker.detector = FusionQueue() + node_C = FusionNode(tracker=node_C_tracker, fusion_queue=node_C_tracker.detector, latency=0) + + node_D = SensorNode(sensor=sensor_set[1]) + node_E = SensorNode(sensor=sensor_set[3]) + + node_F_tracker = copy.deepcopy(tracker) + node_F_tracker.detector = FusionQueue() + node_F = FusionNode(tracker=node_F_tracker, fusion_queue=node_F_tracker.detector, latency=0) + + node_H = SensorNode(sensor=sensor_set[4]) + + node_G = FusionNode(tracker=track_tracker, fusion_queue=fusion_queue, latency=0) + + node_I_tracker = copy.deepcopy(tracker) + node_I_tracker.detector = FusionQueue() + + node_I = SensorFusionNode(sensor=sensor_set[5], tracker=node_I_tracker, + fusion_queue=node_I_tracker.detector) + + return {'a': node_A, 'b': node_B, 'c': node_C, 'd': node_D, 'e': node_E, 'f': node_F, + 'g': node_G, 'h': node_H, 'i': node_I} + + +@pytest.fixture +def edge_lists(nodes, radar_nodes): + hierarchical_edges = Edges([Edge((nodes['s2'], nodes['s1'])), Edge((nodes['s3'], nodes['s1'])), + Edge((nodes['s4'], nodes['s2'])), Edge((nodes['s5'], nodes['s2'])), + Edge((nodes['s6'], nodes['s3'])), Edge((nodes['s7'], nodes['s6']))] + ) + + centralised_edges = Edges( + [Edge((nodes['s2'], nodes['s1'])), Edge((nodes['s3'], nodes['s1'])), + Edge((nodes['s4'], nodes['s2'])), Edge((nodes['s5'], nodes['s2'])), + Edge((nodes['s6'], nodes['s3'])), Edge((nodes['s7'], nodes['s6'])), + Edge((nodes['s7'], nodes['s5'])), Edge((nodes['s5'], nodes['s3']))]) + + simple_edges = Edges([Edge((nodes['s2'], nodes['s1'])), Edge((nodes['s3'], nodes['s1']))]) + + linear_edges = Edges([Edge((nodes['s1'], nodes['s2'])), Edge((nodes['s2'], nodes['s3'])), + Edge((nodes['s3'], nodes['s4'])), + Edge((nodes['s4'], nodes['s5']))]) + + decentralised_edges = Edges( + [Edge((nodes['s2'], nodes['s1'])), Edge((nodes['s3'], nodes['s1'])), + Edge((nodes['s3'], nodes['s4'])), Edge((nodes['s3'], nodes['s5'])), + Edge((nodes['s5'], nodes['s4']))]) + + disconnected_edges = Edges([Edge((nodes['s2'], nodes['s1'])), + Edge((nodes['s4'], nodes['s3']))]) + + k4_edges = Edges( + [Edge((nodes['s1'], nodes['s2'])), Edge((nodes['s1'], nodes['s3'])), + Edge((nodes['s1'], nodes['s4'])), Edge((nodes['s2'], nodes['s3'])), + Edge((nodes['s2'], nodes['s4'])), Edge((nodes['s3'], nodes['s4']))]) + + circular_edges = Edges( + [Edge((nodes['s1'], nodes['s2'])), Edge((nodes['s2'], nodes['s3'])), + Edge((nodes['s3'], nodes['s4'])), Edge((nodes['s4'], nodes['s5'])), + Edge((nodes['s5'], nodes['s1']))]) + + disconnected_loop_edges = Edges( + [Edge((nodes['s2'], nodes['s1'])), Edge((nodes['s4'], nodes['s3'])), + Edge((nodes['s3'], nodes['s4']))]) + + repeater_edges = Edges([Edge((nodes['s2'], nodes['r1'])), Edge((nodes['s4'], nodes['r1']))]) + + radar_edges = Edges([Edge((radar_nodes['a'], radar_nodes['c'])), + Edge((radar_nodes['b'], radar_nodes['c'])), + Edge((radar_nodes['d'], radar_nodes['f'])), + Edge((radar_nodes['e'], radar_nodes['f'])), + Edge((radar_nodes['c'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['f'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['h'], radar_nodes['g']))]) + + sf_radar_edges = Edges([Edge((radar_nodes['a'], radar_nodes['c'])), + Edge((radar_nodes['b'], radar_nodes['c'])), + Edge((radar_nodes['d'], radar_nodes['f'])), + Edge((radar_nodes['e'], radar_nodes['f'])), + Edge((radar_nodes['c'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['f'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['h'], radar_nodes['i'])), + Edge((radar_nodes['i'], radar_nodes['g']))]) + + network_edges = Edges([ + Edge((radar_nodes['a'], nodes['r1']), edge_latency=0.5), + Edge((nodes['r1'], radar_nodes['c']), edge_latency=0.5), + Edge((radar_nodes['b'], radar_nodes['c'])), + Edge((radar_nodes['a'], nodes['r2']), edge_latency=0.5), + Edge((nodes['r2'], radar_nodes['c'])), + Edge((nodes['r1'], nodes['r2'])), + Edge((radar_nodes['d'], radar_nodes['f'])), + Edge((radar_nodes['e'], radar_nodes['f'])), + Edge((radar_nodes['c'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['f'], radar_nodes['g']), edge_latency=0), + Edge((radar_nodes['h'], radar_nodes['g'])) + ]) + + return {"hierarchical_edges": hierarchical_edges, "centralised_edges": centralised_edges, + "simple_edges": simple_edges, "linear_edges": linear_edges, + "decentralised_edges": decentralised_edges, "disconnected_edges": disconnected_edges, + "k4_edges": k4_edges, "circular_edges": circular_edges, + "disconnected_loop_edges": disconnected_loop_edges, "repeater_edges": repeater_edges, + "radar_edges": radar_edges, "sf_radar_edges": sf_radar_edges, + "network_edges": network_edges} + + +@pytest.fixture() +def generator_params(): + start_time = datetime.now().replace(microsecond=0) + np.random.seed(1990) + random.seed(1990) + + from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ + ConstantVelocity + from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState + + # Generate transition model + transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(0.005), + ConstantVelocity(0.005)]) + + yps = range(0, 100, 10) # y value for prior state + truths = OrderedSet() + ntruths = 3 # number of ground truths in simulation + time_max = 60 # timestamps the simulation is observed over + timesteps = [start_time + timedelta(seconds=k) for k in range(time_max)] + + xdirection = 1 + ydirection = 1 + + # Generate ground truths + for j in range(0, ntruths): + truth = GroundTruthPath([GroundTruthState([0, xdirection, yps[j], ydirection], + timestamp=timesteps[0])], id=f"id{j}") + + for k in range(1, time_max): + truth.append( + GroundTruthState(transition_model.function(truth[k - 1], noise=True, + time_interval=timedelta(seconds=1)), + timestamp=timesteps[k])) + truths.add(truth) + + xdirection *= -1 + if j % 2 == 0: + ydirection *= -1 + + base_sensor = RadarRotatingBearingRange( + position_mapping=(0, 2), + noise_covar=np.array([[np.radians(0.5) ** 2, 0], + [0, 1 ** 2]]), + ndim_state=4, + position=np.array([[10], [10]]), + rpm=60, + fov_angle=np.radians(360), + dwell_centre=StateVector([0.0]), + max_range=np.inf, + resolution=Angle(np.radians(30)) + ) + + predictor = KalmanPredictor(transition_model) + updater = ExtendedKalmanUpdater(measurement_model=None) + hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(), + missed_distance=5) + data_associator = GNNWith2DAssignment(hypothesiser) + deleter = CovarianceBasedDeleter(covar_trace_thresh=1) + initiator = MultiMeasurementInitiator( + prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 1, 0, 1])), + measurement_model=None, + deleter=deleter, + data_associator=data_associator, + updater=updater, + min_points=2, + ) + + base_tracker = MultiTargetTracker(initiator, deleter, None, data_associator, updater) + + return {'start_time': start_time, + 'truths': truths, + 'base_sensor': base_sensor, + 'base_tracker': base_tracker} diff --git a/stonesoup/architecture/tests/test_architecture.py b/stonesoup/architecture/tests/test_architecture.py new file mode 100644 index 000000000..6f5d9ca65 --- /dev/null +++ b/stonesoup/architecture/tests/test_architecture.py @@ -0,0 +1,784 @@ +import copy + +import pytest +import datetime +import numpy as np + +from .. import InformationArchitecture, NetworkArchitecture, \ + NonPropagatingArchitecture +from ..edge import Edge, Edges, FusionQueue, Message, DataPiece +from ..generator import NetworkArchitectureGenerator +from ..node import RepeaterNode, SensorNode, FusionNode, Node +from ...types.detection import TrueDetection +from ...types.state import GaussianState +from ...types.track import Track + + +def test_hierarchical_plot(nodes, edge_lists): + + edges = edge_lists["hierarchical_edges"] + sf_radar_edges = edge_lists["sf_radar_edges"] + + arch = InformationArchitecture(edges=edges) + + arch.plot() + + decentralised_edges = edge_lists["decentralised_edges"] + arch = InformationArchitecture(edges=decentralised_edges) + + with pytest.raises(ValueError): + arch.plot(plot_style='hierarchical') + + arch2 = InformationArchitecture(edges=sf_radar_edges) + + arch2.plot() + + +def test_plot_title(nodes, edge_lists): + edges = edge_lists["decentralised_edges"] + + arch = InformationArchitecture(edges=edges) + + # Check that plot function runs when plot_title is given as a str. + arch.plot(plot_title="This is the title of my plot") + + # Check that plot function runs when plot_title is True. + arch.plot(plot_title=True) + + # Check that error is raised when plot_title is not a str or a bool. + x = RepeaterNode() + with pytest.raises(ValueError): + arch.plot(plot_title=x) + + +def test_plot_positions(nodes): + edges1 = Edges([Edge((nodes['p2'], nodes['p1'])), Edge((nodes['p3'], nodes['p1']))]) + + arch1 = InformationArchitecture(edges=edges1) + + arch1.plot(use_positions=True) + + # Assert positions are correct after plot() has run + assert nodes['p1'].position == (0, 0) + assert nodes['p2'].position == (-1, -1) + assert nodes['p3'].position == (1, -1) + + # Change plot positions to non tuple values + nodes['p1'].position = RepeaterNode() + nodes['p2'].position = 'Not a tuple' + nodes['p3'].position = ['Definitely', 'not', 'a', 'tuple'] + + edges2 = Edges([Edge((nodes['p2'], nodes['p1'])), Edge((nodes['p3'], nodes['p1']))]) + + with pytest.raises(TypeError): + InformationArchitecture(edges=edges2) + + +def test_density(edge_lists): + + simple_edges = edge_lists["simple_edges"] + k4_edges = edge_lists["k4_edges"] + + # Graph k3 (complete graph with 3 nodes) has 3 edges + # Simple architecture has 3 nodes and 2 edges: density should be 2/3 + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.density == 2/3 + + # Graph k4 has 6 edges and density 1 + k4_architecture = InformationArchitecture(edges=k4_edges) + assert k4_architecture.density == 1 + + +def test_is_hierarchical(edge_lists): + + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + + # Simple architecture should be hierarchical + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.is_hierarchical + + # Hierarchical architecture should be hierarchical + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.is_hierarchical + + # Centralised architecture should not be hierarchical + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.is_hierarchical is False + + # Linear architecture should be hierarchical + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.is_hierarchical + + # Decentralised architecture should not be hierarchical + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.is_hierarchical is False + + # Disconnected architecture should not be connected + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.is_hierarchical is False + + +def test_is_centralised(edge_lists): + + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + disconnected_loop_edges = edge_lists["disconnected_loop_edges"] + + # Simple architecture should be centralised + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.is_centralised + + # Hierarchical architecture should be centralised + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.is_centralised + + # Centralised architecture should be centralised + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.is_centralised + + # Decentralised architecture should not be centralised + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.is_centralised is False + + # Linear architecture should be centralised + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.is_centralised + + # Disconnected architecture should not be centralised + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.is_centralised is False + + disconnected_loop_architecture = InformationArchitecture(edges=disconnected_loop_edges, + force_connected=False) + assert disconnected_loop_architecture.is_centralised is False + + +def test_is_connected(edge_lists): + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + + # Simple architecture should be connected + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.is_connected + + # Hierarchical architecture should be connected + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.is_connected + + # Centralised architecture should be connected + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.is_connected + + # Decentralised architecture should be connected + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.is_connected + + # Linear architecture should be connected + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.is_connected + + # Disconnected architecture should not be connected + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.is_connected is False + + # Raise error with force_connected=True on a disconnected graph + with pytest.raises(ValueError): + _ = InformationArchitecture(edges=disconnected_edges) + + +def test_recipients(nodes, edge_lists): + centralised_edges = edge_lists["centralised_edges"] + + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.recipients(nodes['s1']) == set() + assert centralised_architecture.recipients(nodes['s2']) == {nodes['s1']} + assert centralised_architecture.recipients(nodes['s3']) == {nodes['s1']} + assert centralised_architecture.recipients(nodes['s4']) == {nodes['s2']} + assert centralised_architecture.recipients(nodes['s5']) == {nodes['s2'], nodes['s3']} + assert centralised_architecture.recipients(nodes['s6']) == {nodes['s3']} + assert centralised_architecture.recipients(nodes['s7']) == {nodes['s6'], nodes['s5']} + + with pytest.raises(ValueError): + centralised_architecture.recipients(nodes['s8']) + + +def test_senders(nodes, edge_lists): + centralised_edges = edge_lists["centralised_edges"] + + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.senders(nodes['s1']) == {nodes['s2'], nodes['s3']} + assert centralised_architecture.senders(nodes['s2']) == {nodes['s4'], nodes['s5']} + assert centralised_architecture.senders(nodes['s3']) == {nodes['s5'], nodes['s6']} + assert centralised_architecture.senders(nodes['s4']) == set() + assert centralised_architecture.senders(nodes['s5']) == {nodes['s7']} + assert centralised_architecture.senders(nodes['s6']) == {nodes['s7']} + assert centralised_architecture.senders(nodes['s7']) == set() + + with pytest.raises(ValueError): + centralised_architecture.senders(nodes['s8']) + + +def test_shortest_path_dict(nodes, edge_lists): + + hierarchical_edges = edge_lists["hierarchical_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + + h_arch = InformationArchitecture(edges=hierarchical_edges) + + assert h_arch.shortest_path_dict[nodes['s7']][nodes['s6']] == 1 + assert h_arch.shortest_path_dict[nodes['s7']][nodes['s3']] == 2 + assert h_arch.shortest_path_dict[nodes['s7']][nodes['s1']] == 3 + assert h_arch.shortest_path_dict[nodes['s7']][nodes['s7']] == 0 + assert h_arch.shortest_path_dict[nodes['s5']][nodes['s2']] == 1 + assert h_arch.shortest_path_dict[nodes['s5']][nodes['s1']] == 2 + + with pytest.raises(KeyError): + _ = h_arch.shortest_path_dict[nodes['s2']][nodes['s3']] + + with pytest.raises(KeyError): + _ = h_arch.shortest_path_dict[nodes['s3']][nodes['s6']] + + disconnected_arch = InformationArchitecture(edges=disconnected_edges, force_connected=False) + + assert disconnected_arch.shortest_path_dict[nodes['s2']][nodes['s1']] == 1 + assert disconnected_arch.shortest_path_dict[nodes['s4']][nodes['s3']] == 1 + + with pytest.raises(KeyError): + _ = disconnected_arch.shortest_path_dict[nodes['s1']][nodes['s4']] + _ = disconnected_arch.shortest_path_dict[nodes['s3']][nodes['s4']] + + +def test_top_level_nodes(nodes, edge_lists): + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + circular_edges = edge_lists["circular_edges"] + disconnected_loop_edges = edge_lists["disconnected_loop_edges"] + + # Simple architecture 1 top node + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.top_level_nodes == {nodes['s1']} + + # Hierarchical architecture should have 1 top node + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.top_level_nodes == {nodes['s1']} + + # Centralised architecture should have 1 top node + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.top_level_nodes == {nodes['s1']} + + # Decentralised architecture should have 2 top nodes + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.top_level_nodes == {nodes['s1'], nodes['s4']} + + # Linear architecture should have 1 top node + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.top_level_nodes == {nodes['s5']} + + # Disconnected architecture should have 2 top nodes + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.top_level_nodes == {nodes['s1'], nodes['s3']} + + # Circular architecture should have no top node + circular_architecture = InformationArchitecture(edges=circular_edges) + assert circular_architecture.top_level_nodes == set() + + disconnected_loop_architecture = InformationArchitecture(edges=disconnected_loop_edges, + force_connected=False) + assert disconnected_loop_architecture.top_level_nodes == {nodes['s1']} + + +def test_number_of_leaves(nodes, edge_lists): + + hierarchical_edges = edge_lists["hierarchical_edges"] + circular_edges = edge_lists["circular_edges"] + + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + + # Check number of leaves for top node and senders of top node + assert hierarchical_architecture.number_of_leaves(nodes['s1']) == 3 + assert hierarchical_architecture.number_of_leaves(nodes['s2']) == 2 + assert hierarchical_architecture.number_of_leaves(nodes['s3']) == 1 + + # Check number of leafs of a leaf node is 1 despite having no senders + assert hierarchical_architecture.number_of_leaves(nodes['s7']) == 1 + + circular_architecture = InformationArchitecture(edges=circular_edges) + + # Check any node in a circular architecture has no leaves + assert circular_architecture.number_of_leaves(nodes['s1']) == 0 + assert circular_architecture.number_of_leaves(nodes['s2']) == 0 + assert circular_architecture.number_of_leaves(nodes['s3']) == 0 + assert circular_architecture.number_of_leaves(nodes['s4']) == 0 + assert circular_architecture.number_of_leaves(nodes['s5']) == 0 + + # Test loop case + r1 = Node() + + edges = Edges([Edge((r1, r1))]) + arch = InformationArchitecture(edges) + + assert arch.number_of_leaves(r1) == 0 + + +def test_leaf_nodes(nodes, edge_lists): + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + circular_edges = edge_lists["circular_edges"] + + # Simple architecture should have 2 leaf nodes + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.leaf_nodes == {nodes['s2'], nodes['s3']} + + # Hierarchical architecture should have 3 leaf nodes + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.leaf_nodes == {nodes['s4'], nodes['s5'], nodes['s7']} + + # Centralised architecture should have 2 leaf nodes + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.leaf_nodes == {nodes['s4'], nodes['s7']} + + # Decentralised architecture should have 2 leaf nodes + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.leaf_nodes == {nodes['s3'], nodes['s2']} + + # Linear architecture should have 1 leaf node + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.leaf_nodes == {nodes['s1']} + + # Disconnected architecture should have 2 leaf nodes + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.leaf_nodes == {nodes['s2'], nodes['s4']} + + # Circular architecture should have no leaf nodes + circular_architecture = InformationArchitecture(edges=circular_edges) + assert circular_architecture.top_level_nodes == set() + + +def test_all_nodes(nodes, edge_lists): + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + circular_edges = edge_lists["circular_edges"] + + # Simple architecture should have 3 nodes + simple_architecture = InformationArchitecture(edges=simple_edges) + assert simple_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3']} + + # Hierarchical architecture should have 7 nodes + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert hierarchical_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], + nodes['s4'], nodes['s5'], nodes['s6'], + nodes['s7']} + + # Centralised architecture should have 7 nodes + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert centralised_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], + nodes['s4'], nodes['s5'], nodes['s6'], + nodes['s7']} + + # Decentralised architecture should have 5 nodes + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert decentralised_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], + nodes['s4'], nodes['s5']} + + # Linear architecture should have 5 nodes + linear_architecture = InformationArchitecture(edges=linear_edges) + assert linear_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], nodes['s4'], + nodes['s5']} + + # Disconnected architecture should have 4 nodes + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert disconnected_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], + nodes['s4']} + + # Circular architecture should have 4 nodes + circular_architecture = InformationArchitecture(edges=circular_edges) + assert circular_architecture.all_nodes == {nodes['s1'], nodes['s2'], nodes['s3'], nodes['s4'], + nodes['s5']} + + +def test_sensor_nodes(edge_lists, ground_truths, radar_nodes): + radar_edges = edge_lists["radar_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + + network = InformationArchitecture(edges=radar_edges) + + assert network.sensor_nodes == {radar_nodes['a'], radar_nodes['b'], radar_nodes['d'], + radar_nodes['e'], radar_nodes['h']} + + h_arch = InformationArchitecture(edges=hierarchical_edges) + + assert h_arch.sensor_nodes == h_arch.all_nodes + assert len(h_arch.sensor_nodes) == 7 + + +def test_fusion_nodes(edge_lists, ground_truths, radar_nodes): + radar_edges = edge_lists["radar_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + + network = InformationArchitecture(edges=radar_edges) + + assert network.fusion_nodes == {radar_nodes['c'], radar_nodes['f'], radar_nodes['g']} + + h_arch = InformationArchitecture(edges=hierarchical_edges) + + assert h_arch.fusion_nodes == set() + + +def test_len(edge_lists): + simple_edges = edge_lists["simple_edges"] + hierarchical_edges = edge_lists["hierarchical_edges"] + centralised_edges = edge_lists["centralised_edges"] + linear_edges = edge_lists["linear_edges"] + decentralised_edges = edge_lists["decentralised_edges"] + disconnected_edges = edge_lists["disconnected_edges"] + + # Simple architecture should be connected + simple_architecture = InformationArchitecture(edges=simple_edges) + assert len(simple_architecture) == len(simple_architecture.all_nodes) + + # Hierarchical architecture should be connected + hierarchical_architecture = InformationArchitecture(edges=hierarchical_edges) + assert len(hierarchical_architecture) == len(hierarchical_architecture.all_nodes) + + # Centralised architecture should be connected + centralised_architecture = InformationArchitecture(edges=centralised_edges) + assert len(centralised_architecture) == len(centralised_architecture.all_nodes) + + # Decentralised architecture should be connected + decentralised_architecture = InformationArchitecture(edges=decentralised_edges) + assert len(decentralised_architecture) == len(decentralised_architecture.all_nodes) + + # Linear architecture should be connected + linear_architecture = InformationArchitecture(edges=linear_edges) + assert len(linear_architecture) == len(linear_architecture.all_nodes) + + # Disconnected architecture should not be connected + disconnected_architecture = InformationArchitecture(edges=disconnected_edges, + force_connected=False) + assert len(disconnected_architecture) == len(disconnected_architecture.all_nodes) + + +def test_information_arch_measure(edge_lists, ground_truths, times): + edges = edge_lists["radar_edges"] + start_time = times['start'] + + network = InformationArchitecture(edges=edges) + all_detections = network.measure(ground_truths=ground_truths, current_time=start_time) + + # Check all_detections is a dictionary + assert isinstance(all_detections, dict) + + # Check that number all_detections contains data for all sensor nodes + assert all_detections.keys() == network.sensor_nodes + + # Check that correct number of detections recorded for each sensor node is equal to the number + # of targets + for sensornode in network.sensor_nodes: + # Check that a detection is made for all 3 targets + assert len(all_detections[sensornode]) == 3 + assert isinstance(all_detections[sensornode], set) + for detection in all_detections[sensornode]: + assert isinstance(detection, TrueDetection) + + for node in network.sensor_nodes: + # Check that each sensor node has data held for the detection of all 3 targets + assert len(node.data_held['created'][datetime.datetime(1306, 12, 25, 23, 47, 59)]) == 3 + + +def test_information_arch_measure_no_noise(edge_lists, ground_truths, times): + edges = edge_lists["radar_edges"] + start_time = times['start'] + network = InformationArchitecture(edges=edges) + all_detections = network.measure(ground_truths=ground_truths, current_time=start_time, + noise=False) + + assert isinstance(all_detections, dict) + assert all_detections.keys() == network.sensor_nodes + for sensornode in network.sensor_nodes: + assert len(all_detections[sensornode]) == 3 + assert isinstance(all_detections[sensornode], set) + for detection in all_detections[sensornode]: + assert isinstance(detection, TrueDetection) + + +def test_information_arch_measure_no_detections(edge_lists, ground_truths, times): + edges = edge_lists["radar_edges"] + start_time = times['start'] + network = InformationArchitecture(edges=edges, current_time=None) + all_detections = network.measure(ground_truths=[], current_time=start_time) + + assert isinstance(all_detections, dict) + assert all_detections.keys() == network.sensor_nodes + + # There should exist a key for each sensor node containing an empty list + for sensornode in network.sensor_nodes: + assert len(all_detections[sensornode]) == 0 + assert isinstance(all_detections[sensornode], set) + + +def test_information_arch_measure_no_time(edge_lists, ground_truths): + edges = edge_lists["radar_edges"] + network = InformationArchitecture(edges=edges) + all_detections = network.measure(ground_truths=ground_truths) + + assert isinstance(all_detections, dict) + assert all_detections.keys() == network.sensor_nodes + for sensornode in network.sensor_nodes: + assert len(all_detections[sensornode]) == 3 + assert isinstance(all_detections[sensornode], set) + for detection in all_detections[sensornode]: + assert isinstance(detection, TrueDetection) + + +def test_fully_propagated(edge_lists, times, ground_truths): + edges = edge_lists["radar_edges"] + start_time = times['start'] + + network = InformationArchitecture(edges=edges, current_time=start_time) + network.measure(ground_truths=ground_truths, noise=True) + + for node in network.sensor_nodes: + # Check that each sensor node has data held for the detection of all 3 targets + for key in node.data_held['created'].keys(): + assert len(node.data_held['created'][key]) == 3 + + # Network should not be fully propagated + assert network.fully_propagated is False + + network.propagate(time_increment=1) + + # Network should now be fully propagated + assert network.fully_propagated + + +def test_information_arch_propagate(edge_lists, ground_truths, times): + edges = edge_lists["radar_edges"] + start_time = times['start'] + network = InformationArchitecture(edges=edges, current_time=start_time) + + network.measure(ground_truths=ground_truths, noise=True) + network.propagate(time_increment=1) + + assert network.fully_propagated + + +def test_architecture_init(edge_lists, times): + time = times['start'] + edges = edge_lists["decentralised_edges"] + arch = InformationArchitecture(edges=edges, name='Name of Architecture', current_time=time) + + assert arch.name == 'Name of Architecture' + assert arch.current_time == time + + +def test_information_arch_init(edge_lists): + edges = edge_lists["repeater_edges"] + + # Network contains a repeater node, InformationArchitecture should raise a type error. + with pytest.raises(TypeError): + _ = InformationArchitecture(edges=edges) + + +def test_network_arch(radar_sensors, ground_truths, tracker, track_tracker, times): + start_time = times['start'] + sensor_set = radar_sensors + fq = FusionQueue() + + node_A = SensorNode(sensor=sensor_set[0], label='SensorNode A') + node_B = SensorNode(sensor=sensor_set[2], label='SensorNode B') + + node_C_tracker = copy.deepcopy(tracker) + node_C_tracker.detector = FusionQueue() + node_C = FusionNode(tracker=node_C_tracker, fusion_queue=node_C_tracker.detector, latency=0, + label='FusionNode C') + + ## + node_D = SensorNode(sensor=sensor_set[1], label='SensorNode D') + node_E = SensorNode(sensor=sensor_set[3], label='SensorNode E') + + node_F_tracker = copy.deepcopy(tracker) + node_F_tracker.detector = FusionQueue() + node_F = FusionNode(tracker=node_F_tracker, fusion_queue=node_F_tracker.detector, latency=0) + + node_H = SensorNode(sensor=sensor_set[4]) + + node_G = FusionNode(tracker=track_tracker, fusion_queue=fq, latency=0) + + repeaternode1 = RepeaterNode(label='RepeaterNode 1') + repeaternode2 = RepeaterNode(label='RepeaterNode 2') + + network_arch = NetworkArchitecture( + edges=Edges([Edge((node_A, repeaternode1), edge_latency=0.5), + Edge((repeaternode1, node_C), edge_latency=0.5), + Edge((node_B, node_C)), + Edge((node_A, repeaternode2), edge_latency=0.5), + Edge((repeaternode2, node_C)), + Edge((repeaternode1, repeaternode2)), + Edge((node_D, node_F)), Edge((node_E, node_F)), + Edge((node_C, node_G), edge_latency=0), + Edge((node_F, node_G), edge_latency=0), + Edge((node_H, node_G)) + ]), + current_time=start_time) + + # Check all Nodes are present in the Network Architecture + assert node_A in network_arch.all_nodes + assert node_B in network_arch.all_nodes + assert node_C in network_arch.all_nodes + assert node_D in network_arch.all_nodes + assert node_E in network_arch.all_nodes + assert node_F in network_arch.all_nodes + assert node_G in network_arch.all_nodes + assert node_H in network_arch.all_nodes + assert repeaternode1 in network_arch.all_nodes + assert repeaternode2 in network_arch.all_nodes + assert len(network_arch.all_nodes) == 10 + + # Check Repeater Nodes are not present in the inherited Information Architecture + assert repeaternode1 not in network_arch.information_arch.all_nodes + assert repeaternode2 not in network_arch.information_arch.all_nodes + assert len(network_arch.information_arch.all_nodes) == 8 + + # Check correct number of edges + assert len(network_arch.edges) == 11 + assert len(network_arch.information_arch.edges) == 8 + + # Check time is correct + assert network_arch.current_time == network_arch.information_arch.current_time == start_time + + # Test node 'get' methods work + assert network_arch.repeater_nodes == {repeaternode1, repeaternode2} + assert network_arch.sensor_nodes == {node_A, node_B, node_D, node_E, node_H} + assert network_arch.fusion_nodes == {node_C, node_F, node_G} + + assert network_arch.information_arch.repeater_nodes == set() + assert network_arch.information_arch.sensor_nodes == {node_A, node_B, node_D, node_E, node_H} + assert network_arch.information_arch.fusion_nodes == {node_C, node_F, node_G} + + +def test_network_arch_instantiation_methods(radar_nodes, times): + time = times['start'] + + nodeA = radar_nodes['a'] + nodeB = radar_nodes['c'] + nodeR = RepeaterNode() + + info_edges = Edges([Edge((nodeA, nodeB))]) + network_edges = Edges([Edge((nodeA, nodeR)), Edge((nodeR, nodeB))]) + + # Method 1: Provide InformationArchitecture to NetworkArchitecture + i_arch = InformationArchitecture(edges=info_edges, current_time=time) + + net_arch1 = NetworkArchitecture(edges=network_edges, information_arch=i_arch) + + assert net_arch1.information_arch.edges == info_edges + assert net_arch1.edges == network_edges + + # Method 2: Provide set of information architecture edges to NetworkArchitecture + net_arch2 = NetworkArchitecture(edges=network_edges, + information_architecture_edges=info_edges) + assert net_arch2.information_arch.edges == info_edges + assert net_arch2.edges == network_edges + + # Method 3: Identical Information and Network Architectures + net_arch3 = NetworkArchitecture(edges=info_edges) + assert net_arch3.information_arch.edges == info_edges + assert net_arch3.edges == info_edges + + +def test_net_arch_fully_propagated(generator_params, ground_truths): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = NetworkArchitectureGenerator(arch_type='hierarchical', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 2, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=1, + sensor_max_distance=(10, 10)) + + arch = gen.generate()[0] + + # Pre-test checks on generated architecture + assert isinstance(arch, NetworkArchitecture) + assert len(arch.sensor_nodes) == sum(gen.node_ratio[:2]) + assert len(arch.fusion_nodes) == sum(gen.node_ratio[1:]) + + arch.measure(ground_truths=ground_truths, noise=True) + + for node in arch.sensor_nodes: + # Check that each sensor node has data held for the detection of all 3 targets + for key in node.data_held['created'].keys(): + print(key) + assert len(node.data_held['created'][key]) == 3 + + edge = {edge for edge in arch.edges if + isinstance(edge.nodes[0], RepeaterNode) and + isinstance(edge.nodes[1], FusionNode)}.pop() + + message = Message( + edge, + datetime.datetime(2016, 1, 2, 3, 4, 5), + start_time, + DataPiece( + edge.sender, + edge.sender, + Track([GaussianState([1, 2, 3, 4], np.diag([1, 1, 1, 1]), + datetime.datetime(2016, 1, 2, 3, 4, 5))]), + datetime.datetime(2016, 1, 2, 3, 4, 5), + ), + ) + + edge.sender.messages_to_pass_on.append(message) + + # Network should not be fully propagated + assert not arch.fully_propagated + + arch.propagate(time_increment=1) + + # Network should now be fully propagated + assert arch.fully_propagated + + +def test_non_propagating_arch(edge_lists, times): + edges = edge_lists['hierarchical_edges'] + start_time = times['start'] + + np_arch = NonPropagatingArchitecture(edges, start_time) + + assert np_arch.current_time == start_time + assert np_arch.edges == edges diff --git a/stonesoup/architecture/tests/test_edge.py b/stonesoup/architecture/tests/test_edge.py new file mode 100644 index 000000000..eae7b4748 --- /dev/null +++ b/stonesoup/architecture/tests/test_edge.py @@ -0,0 +1,352 @@ +import datetime + +import pytest + +from .. import RepeaterNode, Node +from ..edge import Edges, Edge, DataPiece, Message, FusionQueue +from ...types.track import Track +from ...types.time import CompoundTimeRange, TimeRange +from .._functions import _dict_set + +from datetime import timedelta + + +def test_data_piece(nodes, times): + with pytest.raises(TypeError): + _ = DataPiece() + + data_piece = DataPiece(node=nodes['a'], originator=nodes['a'], + data=Track([]), time_arrived=times['a']) + assert data_piece.sent_to == set() + assert data_piece.track is None + + +def test_edge_init(nodes, times, data_pieces): + with pytest.raises(TypeError): + Edge() + with pytest.raises(TypeError): + Edge(nodes['a'], nodes['b']) # e.g. forgetting to put nodes inside a tuple + edge = Edge((nodes['a'], nodes['b'])) + assert edge.edge_latency == 0.0 + assert edge.sender == nodes['a'] + assert edge.recipient == nodes['b'] + assert edge.nodes == (nodes['a'], nodes['b']) + assert all(len(edge.messages_held[status]) == 0 for status in ['pending', 'received']) + + assert edge.unsent_data == [] + nodes['a'].data_held['fused'][times['a']] = [data_pieces['a'], data_pieces['b']] + assert (data_pieces['a'], times['a']) in edge.unsent_data + assert (data_pieces['b'], times['a']) in edge.unsent_data + assert len(edge.unsent_data) == 2 + + assert edge.ovr_latency == 0.0 + nodes['a'].latency = 1.0 + nodes['b'].latency = 2.0 + assert edge.ovr_latency == 1.0 + + assert (edge == 10) is False + + +def test_send_update_message(edges, times, data_pieces): + edge = edges['a'] + assert len(edge.messages_held['pending']) == 0 + + message = Message(edge, times['a'], times['a'], data_pieces['a']) + edge.send_message(data_pieces['a'], times['a'], times['a']) + + with pytest.raises(TypeError): + edge.send_message('not_a_data_piece', times['a'], times['a']) + + assert len(edge.messages_held['pending']) == 1 + assert times['a'] in edge.messages_held['pending'] + assert len(edge.messages_held['pending'][times['a']]) == 1 + assert message in edge.messages_held['pending'][times['a']] + assert len(edge.messages_held['received']) == 0 + # times_b is 1 min later + edge.update_messages(current_time=times['b']) + + assert len(edge.messages_held['received']) == 1 + assert len(edge.messages_held['pending']) == 0 + assert message in edge.messages_held['received'][times['a']] + + +def test_update_messages(): + + # Test scenario where message has not yet reached recipient + A = Node(label='A') + B = Node(label='B') + edge = Edge((A, B), edge_latency=0.5) + + time_created = datetime.datetime.now() + time_sent = time_created + data = DataPiece(A, A, Track([]), time_created) + + # Add message to edge + edge.send_message(data, time_created, time_sent) + assert len(edge.messages_held['pending']) == 1 + + # Message should not have arrived yet + edge.update_messages(time_created) + assert len(edge.messages_held['pending']) == 1 + + # Try again a secomd later + edge.update_messages(time_created + timedelta(seconds=1)) + assert len(edge.messages_held['pending']) == 0 + + # Test scenario when message has no destinations + message = Message(edge, time_created, time_sent, data, destinations=None) + _, edge.messages_held = _dict_set(edge.messages_held, + message, + 'pending', + message.arrival_time) + + assert message in edge.messages_held['pending'][message.arrival_time] + + edge.update_messages(time_created + datetime.timedelta(seconds=2)) + assert len(edge.messages_held['pending']) == 0 + + +def test_failed(edges, times): + edge = edges['a'] + assert edge.time_range_failed == CompoundTimeRange() + edge.failed(times['a'], timedelta(seconds=5)) + new_time_range = TimeRange(times['a'], times['a'] + timedelta(seconds=5)) + assert edge.time_range_failed == CompoundTimeRange([new_time_range]) + + +def test_edges(edges, nodes): + edges_list = Edges([edges['a'], edges['b']]) + assert edges_list.edges == [edges['a'], edges['b']] + edges_list.add(edges['c']) + assert edges['c'] in edges_list + edges_list.add(Edge((nodes['a'], nodes['b']))) + assert len(edges_list) == 4 + assert (nodes['a'], nodes['b']) in edges_list.edge_list + assert (nodes['a'], nodes['b']) in edges_list.edge_list + + empty_edges = Edges() + assert len(empty_edges) == 0 + assert empty_edges.edge_list == [] + assert empty_edges.edges == [] + assert [edge for edge in empty_edges] == [] + empty_edges.add(Edge((nodes['a'], nodes['b']))) + assert [edge for edge in empty_edges] == [Edge((nodes['a'], nodes['b']))] + + +def test_message(edges, data_pieces, times): + with pytest.raises(TypeError): + Message() + edge = edges['a'] + message = Message(edge=edge, time_pertaining=times['a'], time_sent=times['b'], + data_piece=data_pieces['a']) + assert message.sender_node == edge.sender + assert message.recipient_node == edge.recipient + edge.edge_latency = 5.0 + edge.sender.latency = 1.0 + assert message.arrival_time == times['b'] + timedelta(seconds=6.0) + assert message.status == 'sending' + with pytest.raises(ValueError): + message.update(times['a']) + + message.update(times['b']) + assert message.status == 'sending' + + message.update(times['b'] + timedelta(seconds=3)) + assert message.status == 'transferring' + + message.update(times['b'] + timedelta(seconds=8)) + assert message.status == 'received' + + +def test_fusion_queue(): + q = FusionQueue() + iter_q = iter(q) + assert q._to_consume == 0 + assert not q.waiting_for_data + assert not q._consuming + q.put("item") + q.put("another item") + + with pytest.raises(NotImplementedError): + q.get("anything") + + assert q._to_consume == 2 + a = next(iter_q) + assert a == "item" + assert q._to_consume == 2 + b = next(iter_q) + assert b == "another item" + assert q._to_consume == 1 + + +def test_message_destinations(times, radar_nodes): + start_time = times['start'] + node1 = RepeaterNode(label='n1') + node2 = radar_nodes['a'] + node2.label = 'n2' + node3 = RepeaterNode(label='n3') + edge1 = Edge((node1, node2)) + edge2 = Edge((node1, node3)) + + # Create a message without defining a destination + message1 = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, Track([]), + datetime.datetime(2016, 1, 2, 3, 4, 5))) + + # Create a message with node 2 as a destination + message2 = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, Track([]), + datetime.datetime(2016, 1, 2, 3, 4, 5)), + destinations={node2}) + + # Create a message with as a defined destination that isn't node 2 + message3 = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, Track([]), + datetime.datetime(2016, 1, 2, 3, 4, 5)), + destinations={node3}) + + # Create message that has node2 and node3 as a destination + message4 = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, Track([]), + datetime.datetime(2016, 1, 2, 3, 4, 5)), + destinations={node2, node3}) + + # Add messages to node1.messages_to_pass_on and check that unpassed_data() catches it + node1.messages_to_pass_on = [message1, message2, message3, message4] + assert edge1.unpassed_data == [message1, message2, message3, message4] + assert edge2.unpassed_data == [message1, message2, message3, message4] + + # Pass data to edges + for edge in [edge1, edge2]: + for message in edge.unpassed_data: + edge.pass_message(message) + + # Check that no 'unsent' data remains + assert edge1.unsent_data == [] + assert edge2.unsent_data == [] + + # Check that all messages are sent to both edges + assert len(edge1.messages_held['pending'][start_time]) == 4 + assert len(edge2.messages_held['pending'][start_time]) == 4 + + # Check node2 and node3 have no messages to pass on + assert node2.messages_to_pass_on == [] + assert node3.messages_to_pass_on == [] + + # Update both edges + edge1.update_messages(start_time+datetime.timedelta(minutes=1), to_network_node=False) + edge2.update_messages(start_time + datetime.timedelta(minutes=1), to_network_node=True) + + # Check node2.messages_to_pass_on contains message3 that does not have node 2 as a destination + assert len(node2.messages_to_pass_on) == 2 + # Check node3.messages_to_pass_on contains all messages as it is not in information arch + assert len(node3.messages_to_pass_on) == 4 + + # Check that node2 has opened message1 and message3 that were intended to be processed by node3 + data_held = [] + for time in node2.data_held['unfused'].keys(): + data_held += node2.data_held['unfused'][time] + assert len(data_held) == 3 + + +def test_unpassed_data(times): + start_time = times['start'] + node1 = RepeaterNode() + node2 = RepeaterNode() + edge = Edge((node1, node2)) + + # Create a message without defining a destination (send to all) + message = Message(edge, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, 'test_data', + datetime.datetime(2016, 1, 2, 3, 4, 5))) + + # Add message to node.messages_to_pass_on and check that unpassed_data catches it + node1.messages_to_pass_on.append(message) + assert edge.unpassed_data == [message] + + # Pass message on and check that unpassed_data no longer flags it as unsent + edge.pass_message(message) + assert edge.unpassed_data == [] + + +def test_add(): + node1 = RepeaterNode() + node2 = RepeaterNode() + node3 = RepeaterNode() + + edge1 = Edge((node1, node2)) + edge2 = Edge((node1, node2)) + edge3 = Edge((node2, node3)) + edge4 = Edge((node1, node3)) + + edges = Edges([edge1, edge2, edge3]) + + # Check edges.edges returns all edges + assert edges.edges == [edge1, edge2, edge3] + + # Add an edge and check the change is reflected in edges.edges + edges.add(edge4) + assert edges.edges == [edge1, edge2, edge3, edge4] + + +def test_remove(): + node1 = RepeaterNode() + node2 = RepeaterNode() + node3 = RepeaterNode() + + edge1 = Edge((node1, node2)) + edge2 = Edge((node1, node2)) + edge3 = Edge((node2, node3)) + + edges = Edges([edge1, edge2, edge3]) + + # Check edges.edges returns all edges + assert edges.edges == [edge1, edge2, edge3] + + # Remove an edge and check the change is reflected in edges.edges + edges.remove(edge1) + assert edges.edges == [edge2, edge3] + + +def test_get(): + node1 = RepeaterNode() + node2 = RepeaterNode() + node3 = RepeaterNode() + + edge1 = Edge((node1, node2)) + edge2 = Edge((node1, node2)) + edge3 = Edge((node2, node3)) + + edges = Edges([edge1, edge2, edge3]) + + assert edges.get((node1, node2)) == [edge1, edge2] + assert edges.get((node2, node3)) == [edge3] + assert edges.get((node3, node2)) == [] + assert edges.get((node1, node3)) == [] + + with pytest.raises(ValueError): + edges.get(node_pair=(node1, node2, node3)) + + with pytest.raises(TypeError): + edges.get(node_pair=[2, node3]) + + +def test_pass_message(times): + start_time = times['start'] + node1 = RepeaterNode() + node2 = RepeaterNode() + edge = Edge((node1, node2)) + message = Message(edge, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time, + DataPiece(node1, node1, 'test_data', datetime.datetime(2016, 1, 2, 3, 4, 5))) + + node1.messages_to_pass_on.append(message) + + assert node1.messages_to_pass_on == [message] + + edge.pass_message(message) + assert node1.messages_to_pass_on == [message] + assert node2.messages_to_pass_on == [] + assert message in edge.messages_held['pending'][start_time] + assert edge.unpassed_data == [] + + assert (message == 10) is False diff --git a/stonesoup/architecture/tests/test_functions.py b/stonesoup/architecture/tests/test_functions.py new file mode 100644 index 000000000..d6206b2b7 --- /dev/null +++ b/stonesoup/architecture/tests/test_functions.py @@ -0,0 +1,51 @@ + +from .._functions import _dict_set, _default_label_gen +from ..node import RepeaterNode + + +def test_dict_set(): + d = dict() + assert d == {} + + inc, d = _dict_set(d, "c", "cow") + assert inc + assert d == {"cow": {"c"}} + inc, d = _dict_set(d, "o", "cow") + assert inc + assert d == {"cow": {"c", "o"}} + inc, d = _dict_set(d, "c", "cow") + assert not inc + assert d == {"cow": {"c", "o"}} + + d2 = dict() + assert d2 == {} + + inc, d2 = _dict_set(d2, "africa", "lion", "yes") + assert inc + assert d2 == {"lion": {"yes": {"africa"}}} + + inc, d2 = _dict_set(d2, "europe", "polar bear", "no") + assert inc + assert d2 == {"lion": {"yes": {"africa"}}, "polar bear": {"no": {"europe"}}} + + inc, d2 = _dict_set(d2, "europe", "lion", "no") + assert inc + assert d2 == {"lion": {"yes": {"africa"}, "no": {"europe"}}, "polar bear": {"no": {"europe"}}} + + inc, d2 = _dict_set(d2, "north america", "lion", "no") + assert inc + assert d2 == {"lion": {"yes": {"africa"}, "no": {"europe", "north america"}}, + "polar bear": {"no": {"europe"}}} + + +def test_default_label(nodes): + node = nodes['a'] + label = next(_default_label_gen(type(node))) + assert label == 'Node\nA' + + repeater = RepeaterNode() + gen = _default_label_gen(type(repeater)) + label = [next(gen) for i in range(26)][-1] # A-Z 26 chars + assert label.split("\n")[-1] == 'Z' + label = next(gen) + assert label == 'RepeaterNode\nAA' diff --git a/stonesoup/architecture/tests/test_generator.py b/stonesoup/architecture/tests/test_generator.py new file mode 100644 index 000000000..b936d5eda --- /dev/null +++ b/stonesoup/architecture/tests/test_generator.py @@ -0,0 +1,244 @@ +import numpy as np +import pytest + +from stonesoup.architecture.edge import FusionQueue +from stonesoup.architecture.generator import InformationArchitectureGenerator, \ + NetworkArchitectureGenerator +from stonesoup.sensor.sensor import Sensor +from stonesoup.tracker import Tracker + + +def test_info_arch_gen_init(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = InformationArchitectureGenerator(start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor) + + # Test default values + assert gen.arch_type == 'decentralised' + assert gen.iteration_limit == 10000 + assert gen.allow_invalid_graph is False + assert gen.n_archs == 2 + + # Test variables set in __init__() + assert gen.n_sensor_nodes == 3 + assert gen.n_sensor_fusion_nodes == 1 + assert gen.n_fusion_nodes == 1 + + assert gen.sensor_max_distance == (0, 0) + + with pytest.raises(ValueError): + InformationArchitectureGenerator(arch_type='not_valid', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor) + + +def test_info_generate_hierarchical(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = InformationArchitectureGenerator(arch_type='hierarchical', + start_time=start_time, + mean_degree=2, + node_ratio=[2, 2, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=3) + + archs = gen.generate() + + for arch in archs: + + # Check correct number of nodes + assert len(arch.all_nodes) == sum([2, 2, 1]) + + # Check node types + assert len(arch.fusion_nodes) == sum([2, 1]) + assert len(arch.sensor_nodes) == sum([2, 2]) + + assert len(arch.edges) == sum([2, 2, 1]) - 1 + + for node in arch.fusion_nodes: + # Check each fusion node has a tracker + assert isinstance(node.tracker, Tracker) + # Check each tracker has a FusionQueue + assert isinstance(node.tracker.detector.reader, FusionQueue) + + for node in arch.sensor_nodes: + # Check each sensor node has a Sensor + assert isinstance(node.sensor, Sensor) + + +def test_info_generate_decentralised(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + mean_deg = 2.5 + + gen = InformationArchitectureGenerator(arch_type='decentralised', + start_time=start_time, + mean_degree=mean_deg, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=2) + + archs = gen.generate() + + for arch in archs: + + # Check correct number of nodes + assert len(arch.all_nodes) == sum([3, 1, 1]) + + # Check node types + assert len(arch.fusion_nodes) == sum([1, 1]) + assert len(arch.sensor_nodes) == sum([3, 1]) + + assert len(arch.edges) == np.ceil(sum([3, 1, 1]) * mean_deg * 0.5) + + for node in arch.fusion_nodes: + # Check each fusion node has a tracker + assert isinstance(node.tracker, Tracker) + # Check each tracker has a FusionQueue + assert isinstance(node.tracker.detector.reader, FusionQueue) + + for node in arch.sensor_nodes: + # Check each sensor node has a Sensor + assert isinstance(node.sensor, Sensor) + + +def test_info_generate_invalid(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + mean_deg = 2.5 + + with pytest.raises(ValueError): + InformationArchitectureGenerator(arch_type='invalid', + start_time=start_time, + mean_degree=mean_deg, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=2) + + +def test_net_arch_gen_init(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = NetworkArchitectureGenerator(start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor) + + # Test default values + assert gen.arch_type == 'decentralised' + assert gen.iteration_limit == 10000 + assert gen.allow_invalid_graph is False + assert gen.n_archs == 2 + + # Test variables set in __init__() + assert gen.n_sensor_nodes == 3 + assert gen.n_sensor_fusion_nodes == 1 + assert gen.n_fusion_nodes == 1 + + assert gen.sensor_max_distance == (0, 0) + + with pytest.raises(ValueError): + gen = NetworkArchitectureGenerator(arch_type='not_valid', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor) + gen.generate() + + +def test_net_generate_hierarchical(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = NetworkArchitectureGenerator(arch_type='hierarchical', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=3) + + archs = gen.generate() + + for arch in archs: + + # Check correct number of nodes + assert len(arch.all_nodes) == sum([3, 1, 1]) + len(arch.repeater_nodes) + + # Check node types + assert len(arch.fusion_nodes) == sum([1, 1]) + assert len(arch.sensor_nodes) == sum([3, 1]) + + assert len(arch.edges) == 2 * len(arch.repeater_nodes) + + for node in arch.fusion_nodes: + # Check each fusion node has a tracker + assert isinstance(node.tracker, Tracker) + # Check each tracker has a FusionQueue + assert isinstance(node.tracker.detector.reader, FusionQueue) + + for node in arch.sensor_nodes: + # Check each sensor node has a Sensor + assert isinstance(node.sensor, Sensor) + + +def test_net_generate_decentralised(generator_params): + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + + gen = NetworkArchitectureGenerator(arch_type='decentralised', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=3) + + archs = gen.generate() + + for arch in archs: + + # Check correct number of nodes + assert len(arch.all_nodes) == sum([3, 1, 1]) + len(arch.repeater_nodes) + + # Check node types + assert len(arch.fusion_nodes) == sum([1, 1]) + assert len(arch.sensor_nodes) == sum([3, 1]) + for edge in arch.edges: + print((edge.nodes[0].label, edge.nodes[1].label)) + print(len(arch.repeater_nodes)) + assert len(arch.edges) == 2 * len(arch.repeater_nodes) + + for node in arch.fusion_nodes: + # Check each fusion node has a tracker + assert isinstance(node.tracker, Tracker) + # Check each tracker has a FusionQueue + assert isinstance(node.tracker.detector.reader, FusionQueue) + + for node in arch.sensor_nodes: + # Check each sensor node has a Sensor + assert isinstance(node.sensor, Sensor) diff --git a/stonesoup/architecture/tests/test_node.py b/stonesoup/architecture/tests/test_node.py new file mode 100644 index 000000000..ed7038c83 --- /dev/null +++ b/stonesoup/architecture/tests/test_node.py @@ -0,0 +1,204 @@ +import pytest + +import copy +from datetime import datetime +import numpy as np + +from ..node import Node, SensorNode, FusionNode, SensorFusionNode, RepeaterNode +from ..edge import FusionQueue, DataPiece +from ..generator import NetworkArchitectureGenerator +from ...types.hypothesis import Hypothesis +from ...types.track import Track +from ...types.detection import Detection +from ... types.groundtruth import GroundTruthPath +from ...types.state import State, StateVector + + +def test_node(data_pieces, times, nodes): + node = Node() + assert node.latency == 0.0 + assert node.font_size is None + assert len(node.data_held) == 3 + assert node.data_held == {"fused": {}, "created": {}, "unfused": {}} + + node.update(times['a'], times['b'], data_pieces['a'], "fused") + new_data_piece = node.data_held['fused'][times['a']].pop() + assert new_data_piece.originator == nodes['a'] + assert isinstance(new_data_piece.data, Track) and len(new_data_piece.data) == 0 + assert new_data_piece.time_arrived == times['b'] + + with pytest.raises(TypeError): + node.update(times['b'], times['a'], data_pieces['hyp'], "created") + node.update(times['b'], times['a'], data_pieces['hyp'], "created", + track=Track([State(state_vector=StateVector([1]))])) + new_data_piece2 = node.data_held['created'][times['b']].pop() + assert new_data_piece2.originator == nodes['b'] + assert isinstance(new_data_piece2.data, Hypothesis) + assert new_data_piece2.time_arrived == times['a'] + + with pytest.raises(TypeError): + node.update(times['a'], + times['b'], + data_pieces['fail'], + "fused", + track=Track([]), + use_arrival_time=False) + + +def test_sensor_node(nodes): + with pytest.raises(TypeError): + SensorNode() + + sensor = nodes['s1'].sensor + snode = SensorNode(sensor=sensor) + assert snode.sensor == sensor + assert snode.colour == '#006eff' + assert snode.shape == 'oval' + + +def test_fusion_node(tracker): + fnode_tracker = copy.deepcopy(tracker) + fnode_tracker.detector = FusionQueue() + fnode = FusionNode(tracker=fnode_tracker, fusion_queue=fnode_tracker.detector, latency=0) + + assert fnode.colour == '#00b53d' + assert fnode.shape == 'hexagon' + assert fnode.tracks == set() + + with pytest.raises(TypeError): + FusionNode() + + fnode.fuse() # Works. Thorough testing left to test_architecture.py + + # Test FusionNode instantiation with no fusion_queue + fnode2_tracker = copy.deepcopy(tracker) + fnode2_tracker.detector = FusionQueue() + fnode2 = FusionNode(fnode2_tracker) + + assert fnode2.fusion_queue == fnode2_tracker.detector + + # Test FusionNode instantiation with no fusion_queue or tracker.detector + fnode3_tracker = copy.deepcopy(tracker) + fnode3_tracker.detector = None + fnode3 = FusionNode(fnode3_tracker) + + assert isinstance(fnode3.fusion_queue, FusionQueue) + + +def test_sf_node(tracker, nodes): + with pytest.raises(TypeError): + SensorFusionNode() + sfnode_tracker = copy.deepcopy(tracker) + sfnode_tracker.detector = FusionQueue() + sfnode = SensorFusionNode(tracker=sfnode_tracker, fusion_queue=sfnode_tracker.detector, + latency=0, sensor=nodes['s1'].sensor) + + assert sfnode.colour == '#fc9000' + assert sfnode.shape == 'diamond' + + assert sfnode.tracks == set() + + +def test_repeater_node(): + rnode = RepeaterNode() + + assert rnode.colour == '#909090' + assert rnode.shape == 'rectangle' + + +def test_update(tracker): + A = Node() + B = Node() + C = Node() + + dt0 = "This ain't no datetime object" + dt1 = datetime.now() + + t_data = DataPiece(A, A, Track([]), dt1) + d_data = DataPiece(A, A, Detection(state_vector=StateVector(np.random.rand(4, 1)), + timestamp=dt1), dt1) + h_data = DataPiece(A, A, Hypothesis(), dt1) + + # Test invalid time inputs + with pytest.raises(TypeError): + A.update(dt0, dt0, 'faux DataPiece', 'created') + + # Test invalid data_piece + with pytest.raises(TypeError): + A.update(dt1, dt1, 'faux DataPiece', 'created') + + # Test invalid category + with pytest.raises(ValueError): + A.update(dt1, dt1, t_data, 'forged') + + # Test non-detection-or-track-datapiece with Track=False + with pytest.raises(TypeError): + A.update(dt1, dt1, h_data, 'created') + + # Test non-hypothesis-datapiece with Track=True + with pytest.raises(TypeError): + A.update(dt1, dt1, d_data, 'created', track=True) + + # For track DataPiece, test new DataPiece is created and placed in data_held + A.update(dt1, dt1, t_data, 'created') + new_data_piece = A.data_held['created'][dt1].pop() + + assert t_data.originator == new_data_piece.originator + assert t_data.data == new_data_piece.data + assert t_data.time_arrived == new_data_piece.time_arrived + + # For detection DataPiece, test new DataPiece is created and placed in data_held + B.update(dt1, dt1, d_data, 'created') + new_data_piece = B.data_held['created'][dt1].pop() + + assert d_data.originator == new_data_piece.originator + assert d_data.data == new_data_piece.data + assert d_data.time_arrived == new_data_piece.time_arrived + + # For hypothesis DataPiece, test new DataPiece is created and placed in data_held + C.update(dt1, dt1, h_data, 'created', track=True) + new_data_piece = C.data_held['created'][dt1].pop() + + assert h_data.originator == new_data_piece.originator + assert h_data.data == new_data_piece.data + assert h_data.time_arrived == new_data_piece.time_arrived + + # Test placing data into fusion queue - use_arrival_time=False + D = FusionNode(tracker=tracker) + D.update(dt1, dt1, d_data, 'created', use_arrival_time=False) + assert d_data.data in D.fusion_queue.received + + # Test placing data into fusion queue - use_arrival_time=True + D = FusionNode(tracker=tracker) + D.update(dt1, dt1, d_data, 'created', use_arrival_time=True) + copied_data = D.fusion_queue.received.pop() + assert sum(copied_data.state_vector - d_data.data.state_vector) == 0 + assert copied_data.measurement_model == d_data.data.measurement_model + assert copied_data.metadata == d_data.data.metadata + + +def test_fuse(generator_params, ground_truths, timesteps): + # Full data fusion simulation + start_time = generator_params['start_time'] + base_sensor = generator_params['base_sensor'] + base_tracker = generator_params['base_tracker'] + gen = NetworkArchitectureGenerator(arch_type='hierarchical', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 2, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=1, + sensor_max_distance=(10, 10)) + + arch = gen.generate()[0] + + assert all([isinstance(gt, GroundTruthPath) for gt in ground_truths]) + + for time in timesteps: + arch.measure(ground_truths, noise=True) + arch.propagate(time_increment=1) + + for node in arch.fusion_nodes: + for track in node.tracks: + assert isinstance(track, Track) diff --git a/stonesoup/feeder/__init__.py b/stonesoup/feeder/__init__.py index a8f072897..757ba9e7f 100644 --- a/stonesoup/feeder/__init__.py +++ b/stonesoup/feeder/__init__.py @@ -4,6 +4,6 @@ framework, and feed them into the tracking algorithms. These can then optionally be used to drop detections, deliver detections out of sequence, synchronise out of sequence detections, etc. """ -from .base import Feeder +from .base import Feeder, DetectionFeeder, GroundTruthFeeder -__all__ = ['Feeder'] +__all__ = ['Feeder', 'DetectionFeeder', 'GroundTruthFeeder'] diff --git a/stonesoup/feeder/track.py b/stonesoup/feeder/track.py index 37f2aebea..77aff240b 100644 --- a/stonesoup/feeder/track.py +++ b/stonesoup/feeder/track.py @@ -1,13 +1,14 @@ import numpy as np -from stonesoup.types.detection import GaussianDetection -from stonesoup.feeder.base import DetectionFeeder -from stonesoup.models.measurement.linear import LinearGaussian +from . import DetectionFeeder from ..buffered_generator import BufferedGenerator +from ..models.measurement.linear import LinearGaussian +from ..types.detection import GaussianDetection, Detection +from ..types.track import Track class Tracks2GaussianDetectionFeeder(DetectionFeeder): - ''' + """ Feeder consumes Track objects and outputs GaussianDetection objects. At each time step, the :attr:`Reader` feeds in a set of live tracks. The feeder takes the most @@ -15,24 +16,31 @@ class Tracks2GaussianDetectionFeeder(DetectionFeeder): :class:`~.GaussianDetection` objects. Each detection is given a :class:`~.LinearGaussian` measurement model whose covariance is equal to the state covariance. The feeder assumes that the tracks are all live, that is each track has a state at the most recent time step. - ''' + """ @BufferedGenerator.generator_method def data_gen(self): for time, tracks in self.reader: detections = set() for track in tracks: - dim = len(track.state.state_vector) - metadata = track.metadata.copy() - metadata['track_id'] = track.id - detections.add( - GaussianDetection.from_state( - track.state, - state_vector=track.mean, - covar=track.covar, - measurement_model=LinearGaussian( - dim, list(range(dim)), np.asarray(track.covar)), - metadata=metadata, - target_type=GaussianDetection) - ) + + if isinstance(track, Track): + dim = len(track.state.state_vector) + metadata = track.metadata.copy() + metadata['track_id'] = track.id + detections.add( + GaussianDetection.from_state( + track.state, + state_vector=track.mean, + covar=track.covar, + measurement_model=LinearGaussian( + dim, list(range(dim)), np.asarray(track.covar)), + metadata=metadata, + target_type=GaussianDetection) + ) + else: + if not isinstance(track, (Detection, Track)): + raise TypeError(f"track is of type {type(track)}. Expected Track or " + f"Detection") + detections.add(track) yield time, detections diff --git a/stonesoup/metricgenerator/metrictables.py b/stonesoup/metricgenerator/metrictables.py index c5add46f6..91933b337 100644 --- a/stonesoup/metricgenerator/metrictables.py +++ b/stonesoup/metricgenerator/metrictables.py @@ -6,7 +6,7 @@ from matplotlib import pyplot as plt from .base import MetricTableGenerator, MetricGenerator -from ..base import Property +from ..base import Property, Base class RedGreenTableGenerator(MetricTableGenerator): @@ -135,3 +135,112 @@ def set_default_descriptions(self): "SIAP ID Correctness": "Fraction of true objects with correct ID assignment", "SIAP ID Ambiguity": "Fraction of true objects with ambiguous ID assignment" } + + +class SIAPDiffTableGenerator(Base): + """ + Given two sets of metric generators, the SiapDiffTableGenerator returns a table displaying the + difference between two sets of metrics. Allows quick comparison of two sets of metrics. + """ + metrics: Collection[Collection[MetricGenerator]] = Property(doc="Set of metrics to put in the " + "table") + + metrics_labels: Collection[str] = Property(doc='List of titles for ', + default=None) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.targets = None + self.descriptions = None + self.set_default_targets() + self.set_default_descriptions() + if self.metrics_labels is None: + self.metrics_labels = ['Metrics ' + str(i) + ' Value' + for i in range(1, len(self.metrics) + 1)] + + def compute_metric(self, **kwargs): + """Generate table method + + Returns a matplotlib Table of metrics for two sets of values. Table contains metric + descriptions, target values and a coloured value cell for each set of metrics. + The colour of each value cell represents how the pair of values of the metric compare to + eachother, with the better value showing in green. Table also contains a 'Diff' value + displaying the difference between the pair of metric values.""" + + white = (1, 1, 1) + cellText = [["Metric", "Description", "Target"] + list(self.metrics_labels) + ["Max Diff"]] + cellColors = [[white] * (4 + len(self.metrics))] + + sorted_metrics = [sorted(metric_gens, key=attrgetter('title')) + for metric_gens in self.metrics] + + for row_num in range(len(sorted_metrics[0])): + row_metrics = [m[row_num] for m in sorted_metrics] + + metric_name = row_metrics[0].title + description = self.descriptions[metric_name] + target = self.targets[metric_name] + + values = [metric.value for metric in row_metrics] + + diff = round(max(values) - min(values), ndigits=3) + + row_vals = [metric_name, description, target] + \ + ["{:.2f}".format(value) for value in values] + [diff] + + cellText.append(row_vals) + + colours = [] + for i, value in enumerate(values): + other_values = values[:i] + values[i+1:] + if all(num <= 0 for num in [abs(value - target) - abs(v - target) + for v in other_values]): + colours.append((0, 1, 0, 0.5)) + elif all(num >= 0 for num in [abs(value - target) - abs(v - target) + for v in other_values]): + colours.append((1, 0, 0, 0.5)) + else: + colours.append((1, 1, 0, 0.5)) + + cellColors.append([white, white, white] + colours + [white]) + + # "Plot" table + scale = (1, 3) + fig = plt.figure(figsize=(len(cellText)*scale[0] + 1, len(cellText[0])*scale[1]/2)) + ax = fig.add_subplot(1, 1, 1) + ax.axis('off') + table = matplotlib.table.table(ax, cellText, cellColors, loc='center') + table.auto_set_column_width([0, 1, 2, 3]) + table.scale(*scale) + + return table + + def set_default_targets(self): + self.targets = { + "SIAP Completeness": 1, + "SIAP Ambiguity": 1, + "SIAP Spuriousness": 0, + "SIAP Position Accuracy": 0, + "SIAP Velocity Accuracy": 0, + "SIAP Rate of Track Number Change": 0, + "SIAP Longest Track Segment": 1, + "SIAP ID Completeness": 1, + "SIAP ID Correctness": 1, + "SIAP ID Ambiguity": 0 + } + + def set_default_descriptions(self): + self.descriptions = { + "SIAP Completeness": "Fraction of true objects being tracked", + "SIAP Ambiguity": "Number of tracks assigned to a true object", + "SIAP Spuriousness": "Fraction of tracks that are unassigned to a true object", + "SIAP Position Accuracy": "Positional error of associated tracks to their respective " + "truths", + "SIAP Velocity Accuracy": "Velocity error of associated tracks to their respective " + "truths", + "SIAP Rate of Track Number Change": "Rate of number of track changes per truth", + "SIAP Longest Track Segment": "Duration of longest associated track segment per truth", + "SIAP ID Completeness": "Fraction of true objects with an assigned ID", + "SIAP ID Correctness": "Fraction of true objects with correct ID assignment", + "SIAP ID Ambiguity": "Fraction of true objects with ambiguous ID assignment" + } diff --git a/stonesoup/types/hypothesis.py b/stonesoup/types/hypothesis.py index 12834bd14..35fc5cff8 100644 --- a/stonesoup/types/hypothesis.py +++ b/stonesoup/types/hypothesis.py @@ -101,6 +101,10 @@ def weight(self): class SingleProbabilityHypothesis(ProbabilityHypothesis, SingleHypothesis): """Single Measurement Probability scored hypothesis subclass.""" + def __hash__(self): + return hash((self.probability, self.prediction, self.measurement, + self.measurement_prediction)) + class JointHypothesis(Type, UserDict): """Joint Hypothesis base type diff --git a/stonesoup/updater/wrapper.py b/stonesoup/updater/wrapper.py new file mode 100644 index 000000000..4f2db6555 --- /dev/null +++ b/stonesoup/updater/wrapper.py @@ -0,0 +1,24 @@ +from . import Updater +from ..base import Property +from ..types.detection import GaussianDetection + + +class DetectionAndTrackSwitchingUpdater(Updater): + """Updater which sorts between Detections and Tracks""" + + detection_updater: Updater = Property() + track_updater: Updater = Property() + + def predict_measurement(self, state_prediction, measurement_model=None, **kwargs): + if measurement_model.ndim == state_prediction.ndim: + return self.track_updater.predict_measurement( + state_prediction, measurement_model, **kwargs) + else: + return self.detection_updater.predict_measurement( + state_prediction, measurement_model, **kwargs) + + def update(self, hypothesis, **kwargs): + if isinstance(hypothesis.measurement, GaussianDetection): + return self.track_updater.update(hypothesis, **kwargs) + else: + return self.detection_updater.update(hypothesis, **kwargs)