-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #664 dstl/association_assignment
One to one Assignment for Association
- Loading branch information
Showing
15 changed files
with
976 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import copy | ||
|
||
import numpy as np | ||
from scipy.optimize import linear_sum_assignment | ||
|
||
from ..types.association import AssociationSet | ||
|
||
|
||
def multidimensional_deconfliction(association_set): | ||
"""Solves the Multidimensional Assignment Problem (MAP) | ||
The assignment problem becomes more complex when time is added as a dimension. | ||
This basic solution finds all the conflicts in an association set and then creates a | ||
matrix of sums of conflicts in seconds, which is then passed to linear_sum_assignment to | ||
solve as a simple 2D assignment problem. | ||
Therefore, each object will only ever be assigned to one other | ||
at any one time. In the case of an association that only partially overlaps, the time range | ||
of the "weaker" one (the one eliminated by assign2D) will be trimmed | ||
until there is no conflict. | ||
Due to the possibility of more than two conflicting associations at the same time, | ||
this algorithm is recursive, but it is not expected many (if any) recursions will be required | ||
for most uses. | ||
Parameters | ||
---------- | ||
association_set: The :class:`AssociationSet` to de-conflict | ||
Returns | ||
------- | ||
: :class:`AssociationSet` | ||
The association set without contradictory associations | ||
""" | ||
if check_if_no_conflicts(association_set): | ||
return copy.copy(association_set) | ||
|
||
objects = list(association_set.object_set) | ||
length = len(objects) | ||
totals = np.zeros((length, length)) # Time objects i and j are associated for in total | ||
|
||
for association in association_set.associations: | ||
if len(association.objects) != 2: | ||
raise ValueError("Supplied set must only contain pairs of associated objects") | ||
i, j = (objects.index(object_) for object_ in association.objects) | ||
totals[i, j] = association.time_range.duration.total_seconds() | ||
totals = np.maximum(totals, totals.transpose()) # make symmetric | ||
|
||
totals = np.rint(totals).astype(int) | ||
np.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself | ||
solved_2d = linear_sum_assignment(totals, maximize=True)[1] | ||
cleaned_set = AssociationSet() | ||
association_set_reduced = copy.copy(association_set) | ||
for i, j in enumerate(solved_2d): | ||
if i == j: | ||
# Can't associate with self | ||
continue | ||
try: | ||
assoc = next(iter(association_set_reduced.associations_including_objects( | ||
{objects[i], objects[j]}))) # There should only be 1 association in this set | ||
except StopIteration: | ||
# We took the association out previously in the loop | ||
continue | ||
if all(assoc.duration > clean_assoc.duration or not conflicts(assoc, clean_assoc) | ||
for clean_assoc in cleaned_set): | ||
cleaned_set.add(copy.copy(assoc)) | ||
association_set_reduced.remove(assoc) | ||
|
||
if len(cleaned_set) == 0: | ||
raise ValueError("Problem unsolvable using this method") | ||
|
||
if len(association_set_reduced) == 0: | ||
if check_if_no_conflicts(cleaned_set): | ||
raise RuntimeError("Conflicts still present in cleaned set") | ||
# If no conflicts after this iteration and all objects return | ||
return cleaned_set | ||
else: | ||
# Recursive step | ||
runners_up = multidimensional_deconfliction(association_set_reduced).associations | ||
|
||
for runner_up in runners_up: | ||
runner_up_remaining_time = runner_up.time_range | ||
for winner in cleaned_set: | ||
if conflicts(runner_up, winner): | ||
runner_up_remaining_time -= winner.time_range | ||
if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0: | ||
runner_up_copy = copy.copy(runner_up) | ||
runner_up_copy.time_range = runner_up_remaining_time | ||
cleaned_set.add(runner_up_copy) | ||
return cleaned_set | ||
|
||
|
||
def conflicts(assoc1, assoc2): | ||
if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ | ||
len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ | ||
(assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0 and \ | ||
assoc1 != assoc2: | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
def check_if_no_conflicts(association_set): | ||
for assoc1 in range(0, len(association_set)): | ||
for assoc2 in range(assoc1, len(association_set)): | ||
if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]): | ||
return False | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from ...types.association import AssociationSet, TimeRangeAssociation | ||
from ...types.time import TimeRange, CompoundTimeRange | ||
from ...types.track import Track | ||
from .._assignment import multidimensional_deconfliction | ||
import datetime | ||
import pytest | ||
|
||
|
||
def is_assoc_in_assoc_set(assoc, assoc_set): | ||
return any(assoc.time_range == set_assoc.time_range and | ||
assoc.objects == set_assoc.objects for set_assoc in assoc_set) | ||
|
||
|
||
def test_multi_deconfliction(): | ||
test = AssociationSet() | ||
tested = multidimensional_deconfliction(test) | ||
assert test.associations == tested.associations | ||
tracks = [Track(id=0), Track(id=1), Track(id=2), Track(id=3)] | ||
times = [datetime.datetime(year=2022, month=6, day=1, hour=0), | ||
datetime.datetime(year=2022, month=6, day=1, hour=1), | ||
datetime.datetime(year=2022, month=6, day=1, hour=5), | ||
datetime.datetime(year=2022, month=6, day=2, hour=5), | ||
datetime.datetime(year=2022, month=6, day=1, hour=9)] | ||
ranges = [TimeRange(times[0], times[1]), | ||
TimeRange(times[1], times[2]), | ||
TimeRange(times[2], times[3]), | ||
TimeRange(times[0], times[4]), | ||
TimeRange(times[2], times[4])] | ||
|
||
assoc1 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=ranges[0]) | ||
assoc2 = TimeRangeAssociation({tracks[2], tracks[3]}, time_range=ranges[0]) | ||
assoc3 = TimeRangeAssociation({tracks[0], tracks[3]}, | ||
time_range=CompoundTimeRange([ranges[0], ranges[4]])) | ||
assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, | ||
time_range=CompoundTimeRange([ranges[1], ranges[4]])) | ||
a4_clone = TimeRangeAssociation({tracks[0], tracks[1]}, | ||
time_range=CompoundTimeRange([ranges[1], ranges[4]])) | ||
# Will fail as there is only one track, rather than two | ||
assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) | ||
with pytest.raises(ValueError): | ||
multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2})) | ||
|
||
# Objects do not conflict, so should return input association set | ||
test2 = AssociationSet({assoc1, assoc2}) | ||
assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2} | ||
|
||
# Objects do conflict, so remove the shorter time range | ||
test3 = AssociationSet({assoc1, assoc3}) | ||
# Should entirely remove assoc1 | ||
tested3 = multidimensional_deconfliction(test3) | ||
assert len(tested3) == 1 | ||
test_assoc3 = next(iter(tested3.associations)) | ||
for var in vars(test_assoc3): | ||
assert getattr(test_assoc3, var) == getattr(assoc3, var) | ||
|
||
test4 = AssociationSet({assoc1, assoc2, assoc3, assoc4}) | ||
# assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain | ||
tested4 = multidimensional_deconfliction(test4) | ||
assert len(tested4) == 2 | ||
assert is_assoc_in_assoc_set(assoc2, tested4) | ||
merged = tested4.associations_including_objects({tracks[0], tracks[1]}) | ||
assert len(merged) == 1 | ||
merged = next(iter(merged.associations)) | ||
assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) | ||
|
||
test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, a4_clone}) | ||
# Very similar to above, but we add a duplicate assoc4 - should have no effect on the result. | ||
tested5 = multidimensional_deconfliction(test5) | ||
assert len(tested5) == 2 | ||
assert is_assoc_in_assoc_set(assoc2, tested5) | ||
merged = tested5.associations_including_objects({tracks[0], tracks[1]}) | ||
assert len(merged) == 1 | ||
merged = next(iter(merged.associations)) | ||
assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.