Skip to content

Commit

Permalink
Implement suggested comments and optimise user interface of regularis…
Browse files Browse the repository at this point in the history
…e method
  • Loading branch information
timothy-glover committed Aug 21, 2023
1 parent 62d43e6 commit 465b167
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 73 deletions.
1 change: 1 addition & 0 deletions stonesoup/predictor/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def predict(self, prior, timestamp=None, **kwargs):
**kwargs)

return Prediction.from_state(prior,
parent=prior,
state_vector=new_state_vector,
timestamp=timestamp,
transition_model=self.transition_model)
Expand Down
27 changes: 15 additions & 12 deletions stonesoup/regulariser/particle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import numpy as np
from scipy.stats import multivariate_normal, uniform
from typing import Sequence

from .base import Regulariser
from ..functions import cholesky_eps
Expand Down Expand Up @@ -29,20 +30,18 @@ class MCMCRegulariser(Regulariser):
.. [2] Ristic, Branko & Arulampalam, Sanjeev & Gordon, Neil, Beyond the Kalman Filter:
Particle Filters for Target Tracking Applications, Artech House, 2004. """

transition_model: TransitionModel = Property(doc="Transition model used for prediction")
transition_model: TransitionModel = Property(doc="Transition model used for prediction",
default=None)

def regularise(self, prior, posterior, detections):
def regularise(self, prior, posterior):
"""Regularise the particles
Parameters
----------
prior : :class:`~.ParticleState` type
prior particle distribution.
posterior : :class:`~.ParticleState` type
posterior particle distribution
detections : set of :class:`~.Detection`
set of detections containing clutter,
true detections or both
posterior particle distribution.
Returns
-------
Expand All @@ -60,14 +59,18 @@ def regularise(self, prior, posterior, detections):
moved_particles = copy.copy(posterior)
transitioned_prior = copy.copy(prior)

if self.transition_model is not None:
hypotheses = posterior.hypothesis if isinstance(posterior.hypothesis, Sequence) \
else [posterior.hypothesis]

transition_model = hypotheses[0].prediction.transition_model or self.transition_model
if transition_model is not None:
time_interval = posterior.timestamp - prior.timestamp
new_state_vector = self.transition_model.function(prior,
noise=False,
time_interval=time_interval)
transitioned_prior.state_vector = new_state_vector
transitioned_prior.state_vector = \
transition_model.function(prior, noise=False, time_interval=time_interval)

detections = {hypothesis.measurement for hypothesis in hypotheses if hypothesis}

if detections is not None:
if detections:
ndim = prior.state_vector.shape[0]
nparticles = len(posterior)

Expand Down
73 changes: 50 additions & 23 deletions stonesoup/regulariser/tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,29 @@
from ...models.measurement.linear import LinearGaussian
from ...models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity
from ...types.detection import Detection
from ...types.update import ParticleStateUpdate
from ...types.update import Update, ParticleStateUpdate
from ..particle import MCMCRegulariser


def test_regulariser():
transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])])

@pytest.mark.parametrize(
"transition_model, model_flag",
[
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
False, # model_flag
),
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
True, # model_flag
),
(
None, # transition_model
False, # model_flag
)
],
ids=["with_transition_model_init", "without_transition_model_init", "no_transition_model"]
)
def test_regulariser(transition_model, model_flag):
particles = ParticleState(state_vector=None, particle_list=[Particle(np.array([[10], [10]]),
1 / 9),
Particle(np.array([[10], [20]]),
Expand All @@ -36,25 +52,38 @@ def test_regulariser():
1 / 9),
])
timestamp = datetime.datetime.now()
new_state_vector = transition_model.function(particles,
noise=True,
time_interval=datetime.timedelta(seconds=1))
if transition_model is not None:
new_state_vector = transition_model.function(particles,
noise=True,
time_interval=datetime.timedelta(seconds=1))
else:
new_state_vector = particles.state_vector

prediction = ParticleStatePrediction(new_state_vector,
timestamp=timestamp,
transition_model=transition_model)
meas_pred = ParticleMeasurementPrediction(prediction, timestamp=timestamp)

measurement_model = LinearGaussian(ndim_state=2, mapping=(0, 1), noise_covar=np.eye(2))
measurement = [Detection(state_vector=np.array([[5], [7]]),
timestamp=timestamp, measurement_model=measurement_model)]
state_update = ParticleStateUpdate(None, SingleHypothesis(prediction=prediction,
measurement=measurement,
measurement_prediction=meas_pred),
particle_list=particles.particle_list,
timestamp=timestamp+datetime.timedelta(seconds=1))
regulariser = MCMCRegulariser(transition_model=transition_model)
measurement = Detection(state_vector=np.array([[5], [7]]),
timestamp=timestamp, measurement_model=measurement_model)
hypothesis = SingleHypothesis(prediction=prediction,
measurement=measurement,
measurement_prediction=None)

state_update = Update.from_state(state=prediction,
hypothesis=hypothesis,
timestamp=timestamp+datetime.timedelta(seconds=1))
# A PredictedParticleState is used here as the point at which the regulariser is implemented
# in the updater is before the updated state has taken the updated state type.
state_update.weight = np.array([1/6, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48])

if model_flag:
regulariser = MCMCRegulariser()
else:
regulariser = MCMCRegulariser(transition_model=transition_model)

# state check
new_particles = regulariser.regularise(prediction, state_update, measurement)
new_particles = regulariser.regularise(prediction, state_update)
# Check the shape of the new state vector
assert new_particles.state_vector.shape == state_update.state_vector.shape
# Check weights are unchanged
Expand All @@ -65,13 +94,11 @@ def test_regulariser():
# list check3
with pytest.raises(TypeError) as e:
new_particles = regulariser.regularise(particles.particle_list,
state_update,
measurement)
state_update)
assert "Only ParticleState type is supported!" in str(e.value)
with pytest.raises(Exception) as e:
new_particles = regulariser.regularise(particles,
state_update.particle_list,
measurement)
state_update.particle_list)
assert "Only ParticleState type is supported!" in str(e.value)


Expand Down Expand Up @@ -103,9 +130,9 @@ def test_no_measurement():
measurement=None,
measurement_prediction=meas_pred),
particle_list=particles.particle_list, timestamp=timestamp)
regulariser = MCMCRegulariser(transition_model=None)
regulariser = MCMCRegulariser()

new_particles = regulariser.regularise(particles, state_update, detections=None)
new_particles = regulariser.regularise(particles, state_update)

# Check the shape of the new state vector
assert new_particles.state_vector.shape == state_update.state_vector.shape
Expand Down
12 changes: 6 additions & 6 deletions stonesoup/types/multihypothesis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sized, Iterable, Container
from typing import Sequence
from collections.abc import Sequence
import typing

from .detection import MissedDetection
from .numeric import Probability
Expand All @@ -10,13 +10,13 @@
from ..types.prediction import Prediction


class MultipleHypothesis(Type, Sized, Iterable, Container):
class MultipleHypothesis(Type, Sequence):
"""Multiple Hypothesis base type
A Multiple Hypothesis is a container to store a collection of hypotheses.
"""

single_hypotheses: Sequence[SingleHypothesis] = Property(
single_hypotheses: typing.Sequence[SingleHypothesis] = Property(
default=None,
doc="The initial list of :class:`~.SingleHypothesis`. Default `None` "
"which initialises with empty list.")
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_missed_detection_probability(self):
return None


class MultipleCompositeHypothesis(Type, Sized, Iterable, Container):
class MultipleCompositeHypothesis(Type, Sequence):
"""Multiple composite hypothesis type
A Multiple Composite Hypothesis is a container to store a collection of composite hypotheses.
Expand All @@ -128,7 +128,7 @@ class MultipleCompositeHypothesis(Type, Sized, Iterable, Container):
redefined.
"""

single_hypotheses: Sequence[CompositeHypothesis] = Property(
single_hypotheses: typing.Sequence[CompositeHypothesis] = Property(
default=None,
doc="The initial list of :class:`~.CompositeHypothesis`. Default `None` which initialises "
"with empty list.")
Expand Down
43 changes: 20 additions & 23 deletions stonesoup/updater/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def update(self, hypothesis, **kwargs):
: :class:`~.ParticleState`
The state posterior
"""
predicted_state = copy.copy(hypothesis.prediction)

predicted_state = Update.from_state(
state=hypothesis.prediction,
hypothesis=hypothesis,
timestamp=hypothesis.prediction.timestamp
)

if hypothesis.measurement.measurement_model is None:
measurement_model = self.measurement_model
Expand All @@ -66,17 +71,11 @@ def update(self, hypothesis, **kwargs):
predicted_state = self.resampler.resample(predicted_state)

if self.regulariser is not None:
predicted_state = self.regulariser.regularise(predicted_state.parent,
predicted_state,
[hypothesis.measurement])
prior = hypothesis.prediction.parent
predicted_state = self.regulariser.regularise(prior,
predicted_state)

return Update.from_state(
state=hypothesis.prediction,
state_vector=predicted_state.state_vector,
log_weight=predicted_state.log_weight,
hypothesis=hypothesis,
timestamp=hypothesis.measurement.timestamp,
)
return predicted_state

@lru_cache()
def predict_measurement(self, state_prediction, measurement_model=None,
Expand Down Expand Up @@ -419,8 +418,12 @@ def update(self, hypotheses, **kwargs):
# copy prediction
prediction = hypotheses.single_hypotheses[0].prediction

updated_state = copy.copy(prediction)

# updated_state = copy.copy(prediction)
updated_state = Update.from_state(
state=prediction,
hypothesis=hypotheses,
timestamp=prediction.timestamp
)
if any(hypotheses):
detections = [single_hypothesis.measurement
for single_hypothesis in hypotheses.single_hypotheses]
Expand Down Expand Up @@ -468,16 +471,10 @@ def update(self, hypotheses, **kwargs):
if any(hypotheses):
# Regularisation
if self.regulariser is not None:
regularised_parts = self.regulariser.regularise(updated_state.parent,
updated_state,
detections)
updated_state.state_vector = regularised_parts.state_vector

return Update.from_state(
updated_state,
timestamp=updated_state.timestamp,
hypothesis=hypotheses,
)
updated_state = self.regulariser.regularise(updated_state.parent,
updated_state)

return updated_state

@staticmethod
def _log_space_product(A, B):
Expand Down
38 changes: 29 additions & 9 deletions stonesoup/updater/tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,28 @@ def test_bernoulli_particle():
assert update.existence_probability is not None


def test_regularised_particle():
@pytest.mark.parametrize("transition_model, model_flag", [
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
False # model_flag
),
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
True # model_flag
)
], ids=["with_transition_model_init", "without_transition_model_init"]
)
def test_regularised_particle(transition_model, model_flag):

transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])])
measurement_model = LinearGaussian(
ndim_state=2, mapping=[0], noise_covar=np.array([[10]]))

updater = ParticleUpdater(regulariser=MCMCRegulariser(transition_model=transition_model),
measurement_model=measurement_model)
if model_flag:
updater = ParticleUpdater(regulariser=MCMCRegulariser(),
measurement_model=measurement_model)
else:
updater = ParticleUpdater(regulariser=MCMCRegulariser(transition_model=transition_model),
measurement_model=measurement_model)
# Measurement model
timestamp = datetime.datetime.now()
particles = [Particle([[10], [10]], 1 / 9),
Expand All @@ -198,11 +212,17 @@ def test_regularised_particle():
predicted_state = transition_model.function(particles,
noise=True,
time_interval=datetime.timedelta(seconds=1))
prediction = ParticleStatePrediction(predicted_state,
weight=np.array([1/9]*9),
timestamp=timestamp,
transition_model=transition_model,
parent=particles)
if not model_flag:
prediction = ParticleStatePrediction(predicted_state,
weight=np.array([1/9]*9),
timestamp=timestamp,
parent=particles)
else:
prediction = ParticleStatePrediction(predicted_state,
weight=np.array([1 / 9] * 9),
timestamp=timestamp,
transition_model=transition_model,
parent=particles)

measurement = Detection([[40.0]], timestamp=timestamp, measurement_model=measurement_model)
eval_measurement_prediction = ParticleMeasurementPrediction(
Expand Down

0 comments on commit 465b167

Please sign in to comment.