Skip to content

Commit

Permalink
Merge pull request #664 dstl/association_assignment
Browse files Browse the repository at this point in the history
One to one Assignment for Association
  • Loading branch information
sdhiscocks committed Aug 10, 2023
2 parents 7f4b754 + 5045a47 commit 8c4fcf2
Show file tree
Hide file tree
Showing 15 changed files with 976 additions and 160 deletions.
108 changes: 108 additions & 0 deletions stonesoup/dataassociator/_assignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import copy

import numpy as np
from scipy.optimize import linear_sum_assignment

from ..types.association import AssociationSet


def multidimensional_deconfliction(association_set):
"""Solves the Multidimensional Assignment Problem (MAP)
The assignment problem becomes more complex when time is added as a dimension.
This basic solution finds all the conflicts in an association set and then creates a
matrix of sums of conflicts in seconds, which is then passed to linear_sum_assignment to
solve as a simple 2D assignment problem.
Therefore, each object will only ever be assigned to one other
at any one time. In the case of an association that only partially overlaps, the time range
of the "weaker" one (the one eliminated by assign2D) will be trimmed
until there is no conflict.
Due to the possibility of more than two conflicting associations at the same time,
this algorithm is recursive, but it is not expected many (if any) recursions will be required
for most uses.
Parameters
----------
association_set: The :class:`AssociationSet` to de-conflict
Returns
-------
: :class:`AssociationSet`
The association set without contradictory associations
"""
if check_if_no_conflicts(association_set):
return copy.copy(association_set)

objects = list(association_set.object_set)
length = len(objects)
totals = np.zeros((length, length)) # Time objects i and j are associated for in total

for association in association_set.associations:
if len(association.objects) != 2:
raise ValueError("Supplied set must only contain pairs of associated objects")
i, j = (objects.index(object_) for object_ in association.objects)
totals[i, j] = association.time_range.duration.total_seconds()
totals = np.maximum(totals, totals.transpose()) # make symmetric

totals = np.rint(totals).astype(int)
np.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself
solved_2d = linear_sum_assignment(totals, maximize=True)[1]
cleaned_set = AssociationSet()
association_set_reduced = copy.copy(association_set)
for i, j in enumerate(solved_2d):
if i == j:
# Can't associate with self
continue
try:
assoc = next(iter(association_set_reduced.associations_including_objects(
{objects[i], objects[j]}))) # There should only be 1 association in this set
except StopIteration:
# We took the association out previously in the loop
continue
if all(assoc.duration > clean_assoc.duration or not conflicts(assoc, clean_assoc)
for clean_assoc in cleaned_set):
cleaned_set.add(copy.copy(assoc))
association_set_reduced.remove(assoc)

if len(cleaned_set) == 0:
raise ValueError("Problem unsolvable using this method")

if len(association_set_reduced) == 0:
if check_if_no_conflicts(cleaned_set):
raise RuntimeError("Conflicts still present in cleaned set")
# If no conflicts after this iteration and all objects return
return cleaned_set
else:
# Recursive step
runners_up = multidimensional_deconfliction(association_set_reduced).associations

for runner_up in runners_up:
runner_up_remaining_time = runner_up.time_range
for winner in cleaned_set:
if conflicts(runner_up, winner):
runner_up_remaining_time -= winner.time_range
if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0:
runner_up_copy = copy.copy(runner_up)
runner_up_copy.time_range = runner_up_remaining_time
cleaned_set.add(runner_up_copy)
return cleaned_set


def conflicts(assoc1, assoc2):
if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \
len(assoc1.objects.intersection(assoc2.objects)) > 0 and \
(assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0 and \
assoc1 != assoc2:
return True
else:
return False


def check_if_no_conflicts(association_set):
for assoc1 in range(0, len(association_set)):
for assoc2 in range(assoc1, len(association_set)):
if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]):
return False
return True
74 changes: 74 additions & 0 deletions stonesoup/dataassociator/tests/test_assignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from ...types.association import AssociationSet, TimeRangeAssociation
from ...types.time import TimeRange, CompoundTimeRange
from ...types.track import Track
from .._assignment import multidimensional_deconfliction
import datetime
import pytest


def is_assoc_in_assoc_set(assoc, assoc_set):
return any(assoc.time_range == set_assoc.time_range and
assoc.objects == set_assoc.objects for set_assoc in assoc_set)


def test_multi_deconfliction():
test = AssociationSet()
tested = multidimensional_deconfliction(test)
assert test.associations == tested.associations
tracks = [Track(id=0), Track(id=1), Track(id=2), Track(id=3)]
times = [datetime.datetime(year=2022, month=6, day=1, hour=0),
datetime.datetime(year=2022, month=6, day=1, hour=1),
datetime.datetime(year=2022, month=6, day=1, hour=5),
datetime.datetime(year=2022, month=6, day=2, hour=5),
datetime.datetime(year=2022, month=6, day=1, hour=9)]
ranges = [TimeRange(times[0], times[1]),
TimeRange(times[1], times[2]),
TimeRange(times[2], times[3]),
TimeRange(times[0], times[4]),
TimeRange(times[2], times[4])]

assoc1 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=ranges[0])
assoc2 = TimeRangeAssociation({tracks[2], tracks[3]}, time_range=ranges[0])
assoc3 = TimeRangeAssociation({tracks[0], tracks[3]},
time_range=CompoundTimeRange([ranges[0], ranges[4]]))
assoc4 = TimeRangeAssociation({tracks[0], tracks[1]},
time_range=CompoundTimeRange([ranges[1], ranges[4]]))
a4_clone = TimeRangeAssociation({tracks[0], tracks[1]},
time_range=CompoundTimeRange([ranges[1], ranges[4]]))
# Will fail as there is only one track, rather than two
assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0])
with pytest.raises(ValueError):
multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2}))

# Objects do not conflict, so should return input association set
test2 = AssociationSet({assoc1, assoc2})
assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2}

# Objects do conflict, so remove the shorter time range
test3 = AssociationSet({assoc1, assoc3})
# Should entirely remove assoc1
tested3 = multidimensional_deconfliction(test3)
assert len(tested3) == 1
test_assoc3 = next(iter(tested3.associations))
for var in vars(test_assoc3):
assert getattr(test_assoc3, var) == getattr(assoc3, var)

test4 = AssociationSet({assoc1, assoc2, assoc3, assoc4})
# assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain
tested4 = multidimensional_deconfliction(test4)
assert len(tested4) == 2
assert is_assoc_in_assoc_set(assoc2, tested4)
merged = tested4.associations_including_objects({tracks[0], tracks[1]})
assert len(merged) == 1
merged = next(iter(merged.associations))
assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]])

test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, a4_clone})
# Very similar to above, but we add a duplicate assoc4 - should have no effect on the result.
tested5 = multidimensional_deconfliction(test5)
assert len(tested5) == 2
assert is_assoc_in_assoc_set(assoc2, tested5)
merged = tested5.associations_including_objects({tracks[0], tracks[1]})
assert len(merged) == 1
merged = next(iter(merged.associations))
assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]])
19 changes: 19 additions & 0 deletions stonesoup/dataassociator/tests/test_tracktotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def test_euclidiantracktotrack(tracks):

association_set_4 = complete_associator.associate_tracks({tracks[0]}, {tracks[5]})

complete_associator_one2one = TrackToTrackCounting(
association_threshold=10,
consec_pairs_confirm=3,
consec_misses_end=2,
use_positional_only=False)
start_time = datetime.datetime(2019, 1, 1, 14, 0, 0)

association_set_one2one = complete_associator_one2one.associate_tracks(
{tracks[0], tracks[2]}, {tracks[1], tracks[3], tracks[4]})

assert len(association_set_1.associations) == 1
assoc1 = list(association_set_1.associations)[0]
assert set(assoc1.objects) == {tracks[0], tracks[1]}
Expand Down Expand Up @@ -121,6 +131,15 @@ def test_euclidiantracktotrack(tracks):
assert assoc4.time_range.end_timestamp \
== start_time + datetime.timedelta(seconds=7)

assert len(association_set_one2one) == 1
assoc5 = list(association_set_one2one)[0]
# assoc5 should be equal to assoc1
assert set(assoc5.objects) == {tracks[0], tracks[1]}
assert assoc5.time_range.start_timestamp \
== start_time + datetime.timedelta(seconds=1)
assert assoc5.time_range.end_timestamp \
== start_time + datetime.timedelta(seconds=6)


def test_euclidiantracktotruth(tracks):
associator = TrackToTruth(
Expand Down
11 changes: 10 additions & 1 deletion stonesoup/dataassociator/tracktotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..types.groundtruth import GroundTruthPath
from ..types.track import Track
from ..types.time import TimeRange
from ._assignment import multidimensional_deconfliction


class TrackToTrackCounting(TrackToTrackAssociator):
Expand Down Expand Up @@ -77,6 +78,11 @@ class TrackToTrackCounting(TrackToTrackAssociator):
"position components compared to others (such as velocity). "
"Default is 0.6"
)
one_to_one: bool = Property(
default=False,
doc="If True, it is ensured no two associations ever contain the same track "
"at the same time"
)

def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]):
"""Associate two sets of tracks together.
Expand Down Expand Up @@ -180,7 +186,10 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]):
(track1, track2),
TimeRange(start_timestamp, end_timestamp)))

return AssociationSet(associations)
if self.one_to_one:
return multidimensional_deconfliction(AssociationSet(associations))
else:
return AssociationSet(associations)


class TrackToTruth(TrackToTrackAssociator):
Expand Down
12 changes: 6 additions & 6 deletions stonesoup/metricgenerator/basicmetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ def compute_metric(self, manager, *args, **kwargs):
title='Number of targets',
value=len(manager.groundtruth_paths),
time_range=TimeRange(
start_timestamp=min(timestamps),
end_timestamp=max(timestamps)),
start=min(timestamps),
end=max(timestamps)),
generator=self))

metrics.append(TimeRangeMetric(
title='Number of tracks',
value=len(manager.tracks),
time_range=TimeRange(
start_timestamp=min(timestamps),
end_timestamp=max(timestamps)),
start=min(timestamps),
end=max(timestamps)),
generator=self))

metrics.append(TimeRangeMetric(
title='Track-to-target ratio',
value=len(manager.tracks) / len(manager.groundtruth_paths),
time_range=TimeRange(
start_timestamp=min(timestamps),
end_timestamp=max(timestamps)),
start=min(timestamps),
end=max(timestamps)),
generator=self))

return metrics
12 changes: 6 additions & 6 deletions stonesoup/metricgenerator/tests/test_basicmetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ def test_basicmetrics():
correct_metrics = {TimeRangeMetric(title='Number of targets',
value=3,
time_range=TimeRange(
start_timestamp=start_time,
end_timestamp=start_time +
start=start_time,
end=start_time +
datetime.timedelta(seconds=4)),
generator=generator),
TimeRangeMetric(title='Number of tracks',
value=4,
time_range=TimeRange(
start_timestamp=start_time,
end_timestamp=start_time +
start=start_time,
end=start_time +
datetime.timedelta(seconds=4)),
generator=generator),
TimeRangeMetric(title='Track-to-target ratio',
value=4 / 3,
time_range=TimeRange(
start_timestamp=start_time,
end_timestamp=start_time +
start=start_time,
end=start_time +
datetime.timedelta(seconds=4)),
generator=generator)}
for metric_name in ["Number of targets",
Expand Down
12 changes: 7 additions & 5 deletions stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea

# Test longest_track_time_on_truth
assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[0]) == 2
assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 1
# Associations 1 and 2 (starting from 0) will join together
# because of the AssociationSet._simplify method, so this will be 2
assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 2
assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[2]) == 1

# Test compute_metric
Expand All @@ -88,8 +90,8 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea

for metric in metrics:
assert isinstance(metric, TimeRangeMetric)
assert metric.time_range.start_timestamp == timestamps[0]
assert metric.time_range.end_timestamp == timestamps[3]
assert metric.time_range.start == timestamps[0]
assert metric.time_range.end == timestamps[3]
assert metric.generator == siap_generator

if metric.title.endswith(" at times"):
Expand Down Expand Up @@ -173,8 +175,8 @@ def test_id_siap(trial_manager, trial_truths, trial_tracks, trial_associations,

for metric in metrics:
assert isinstance(metric, TimeRangeMetric)
assert metric.time_range.start_timestamp == timestamps[0]
assert metric.time_range.end_timestamp == timestamps[3]
assert metric.time_range.start == timestamps[0]
assert metric.time_range.end == timestamps[3]
assert metric.generator == siap_generator

if metric.title.endswith(" at times"):
Expand Down
Loading

0 comments on commit 8c4fcf2

Please sign in to comment.