Skip to content

Commit

Permalink
Merge pull request #589 from ekhunter123/simple_initiator_measmodel
Browse files Browse the repository at this point in the history
Make measurement model optional in measurement-based initiators
  • Loading branch information
sdhiscocks authored Feb 17, 2022
2 parents efb1ba2 + b6a52b1 commit 14db74f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
22 changes: 18 additions & 4 deletions stonesoup/initiator/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class SinglePointInitiator(GaussianInitiator):
"""

prior_state: GaussianState = Property(doc="Prior state information")
measurement_model: MeasurementModel = Property(doc="Measurement model")
measurement_model: MeasurementModel = Property(
default=None,
doc="Measurement model. Can be left as None if all detections have a "
"valid measurement model.")

def initiate(self, detections, timestamp, **kwargs):
"""Initiates tracks given unassociated measurements
Expand Down Expand Up @@ -64,6 +67,8 @@ class SimpleMeasurementInitiator(GaussianInitiator):
This initiator utilises the :class:`~.MeasurementModel` matrix to convert
:class:`~.Detection` state vector and model covariance into state space.
It either takes the :class:`~.MeasurementModel` from the given detection
or uses the :attr:`measurement_model`.
Utilises the ReversibleModel inverse function to convert
non-linear spherical co-ordinates into Cartesian x/y co-ordinates
Expand All @@ -77,7 +82,10 @@ class SimpleMeasurementInitiator(GaussianInitiator):
decompositions.
"""
prior_state: GaussianState = Property(doc="Prior state information")
measurement_model: MeasurementModel = Property(doc="Measurement model")
measurement_model: MeasurementModel = Property(
default=None,
doc="Measurement model. Can be left as None if all detections have a "
"valid measurement model.")
skip_non_reversible: bool = Property(default=False)
diag_load: float = Property(default=0.0, doc="Positive float value for diagonal loading")

Expand All @@ -94,7 +102,10 @@ def initiate(self, detections, timestamp, **kwargs):
if detection.measurement_model is not None:
measurement_model = detection.measurement_model
else:
measurement_model = self.measurement_model
if self.measurement_model is None:
raise ValueError("No measurement model specified")
else:
measurement_model = self.measurement_model

if isinstance(measurement_model, LinearModel):
model_matrix = measurement_model.matrix()
Expand Down Expand Up @@ -155,12 +166,15 @@ class MultiMeasurementInitiator(GaussianInitiator):
Does cause slight delay in initiation to tracker."""

prior_state: GaussianState = Property(doc="Prior state information")
measurement_model: MeasurementModel = Property(doc="Measurement model")
deleter: Deleter = Property(doc="Deleter used to delete the track.")
data_associator: DataAssociator = Property(
doc="Association algorithm to pair predictions to detections.")
updater: Updater = Property(
doc="Updater used to update the track object to the new state.")
measurement_model: MeasurementModel = Property(
default=None,
doc="Measurement model. Can be left as None if all detections have a "
"valid measurement model.")
min_points: int = Property(
default=2, doc="Minimum number of track points required to confirm a track.")
updates_only: bool = Property(
Expand Down
24 changes: 22 additions & 2 deletions stonesoup/initiator/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ...hypothesiser.distance import DistanceHypothesiser
from ...dataassociator.neighbour import NearestNeighbour
from ...measures import Mahalanobis
from ...types.detection import Detection
from ...types.detection import Detection, TrueDetection
from ...types.hypothesis import SingleHypothesis
from ...types.prediction import Prediction
from ...types.state import GaussianState
Expand Down Expand Up @@ -294,7 +294,8 @@ def test_multi_measurement(updates_only):

measurement_initiator = MultiMeasurementInitiator(
GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
measurement_model, deleter, data_associator, updater, updates_only=updates_only)
deleter, data_associator, updater,
measurement_model=measurement_model, updates_only=updates_only)

timestamp = datetime.datetime.now()
first_detections = {Detection(np.array([[5], [2]]), timestamp),
Expand All @@ -318,6 +319,25 @@ def test_multi_measurement(updates_only):
assert len(measurement_initiator.holding_tracks) == 0


@pytest.mark.parametrize("initiator", [
SinglePointInitiator(
GaussianState(np.array([[0]]), np.array([[100]]))
),
SimpleMeasurementInitiator(
GaussianState(np.array([[0]]), np.array([[100]]))
),
], ids=['SinglePoint', 'LinearMeasurement'])
def test_measurement_model(initiator):
timestamp = datetime.datetime.now()
dummy_detection = TrueDetection(np.array([0, 0]), timestamp)
# The SinglePointInitiator will raise an error when the ExtendedKalmanUpdater
# is called and neither the detection nor the initiator has a measurement
# model. The SimpleMeasurementInitiator will raise an error in the if/else
# blocks.
with pytest.raises(ValueError):
_ = initiator.initiate({dummy_detection}, timestamp)


@pytest.mark.parametrize("gaussian_initiator", [
SinglePointInitiator(
GaussianState(np.array([[0]]), np.array([[100]])),
Expand Down

0 comments on commit 14db74f

Please sign in to comment.