Skip to content

Commit

Permalink
Fix to test_message_destinations. Add test_update_messages() to test …
Browse files Browse the repository at this point in the history
…extra scenarios.
  • Loading branch information
spike-dstl committed Sep 9, 2024
1 parent b511737 commit ffd7fc1
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 37 deletions.
2 changes: 0 additions & 2 deletions stonesoup/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,6 @@ def __init__(self, *args, **kwargs):
# Need to reset digraph for info-arch
self.di_graph = nx.to_networkx_graph(self.edges.edge_list, create_using=nx.DiGraph)
# Set attributes such as label, colour, shape, etc. for each node
node_label_gens = {}
labels = {node.label.replace("\n", " ") for node in self.di_graph.nodes if node.label}
for node in self.di_graph.nodes:
self.di_graph.nodes[node].update(self._node_kwargs(node))

Expand Down
6 changes: 1 addition & 5 deletions stonesoup/architecture/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _generate_edgelist(self):
else:
valid = True

elif self.arch_type == 'decentralised':
else:

while not valid:

Expand Down Expand Up @@ -230,10 +230,6 @@ def _generate_edgelist(self):
else:
valid = True

else:
raise ValueError(f"Invalid architecture type of {self.arch_type}. arch_type must be "
"one of: 'hierarchical' or 'decentralised'")

return edges, nodes


Expand Down
53 changes: 38 additions & 15 deletions stonesoup/architecture/tests/test_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import pytest

from .. import RepeaterNode
from .. import RepeaterNode, Node
from ..edge import Edges, Edge, DataPiece, Message, FusionQueue
from ...types.track import Track
from ...types.time import CompoundTimeRange, TimeRange
from .._functions import _dict_set

from datetime import timedelta

Expand Down Expand Up @@ -69,6 +70,42 @@ def test_send_update_message(edges, times, data_pieces):
assert message in edge.messages_held['received'][times['a']]


def test_update_messages():

# Test scenario where message has not yet reached recipient
A = Node(label='A')
B = Node(label='B')
edge = Edge((A, B), edge_latency=0.5)

time_created = datetime.datetime.now()
time_sent = time_created
data = DataPiece(A, A, Track([]), time_created)

# Add message to edge
edge.send_message(data, time_created, time_sent)
assert len(edge.messages_held['pending']) == 1

# Message should not have arrived yet
edge.update_messages(time_created)
assert len(edge.messages_held['pending']) == 1

# Try again a secomd later
edge.update_messages(time_created + timedelta(seconds=1))
assert len(edge.messages_held['pending']) == 0

# Test scenario when message has no destinations
message = Message(edge, time_created, time_sent, data, destinations=None)
_, edge.messages_held = _dict_set(edge.messages_held,
message,
'pending',
message.arrival_time)

assert message in edge.messages_held['pending'][message.arrival_time]

edge.update_messages(time_created + datetime.timedelta(seconds=2))
assert len(edge.messages_held['pending']) == 0


def test_failed(edges, times):
edge = edges['a']
assert edge.time_range_failed == CompoundTimeRange()
Expand Down Expand Up @@ -176,11 +213,6 @@ def test_message_destinations(times, radar_nodes):
DataPiece(node1, node1, Track([]),
datetime.datetime(2016, 1, 2, 3, 4, 5)),
destinations={node2, node3})

# Another message like 1, but this will not be put through Edge.pass_message
message1b = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time+timedelta(seconds=1),
DataPiece(node1, node1, Track([]),
datetime.datetime(2016, 1, 2, 3, 4, 5)))

# Add messages to node1.messages_to_pass_on and check that unpassed_data() catches it
node1.messages_to_pass_on = [message1, message2, message3, message4]
Expand All @@ -204,17 +236,10 @@ def test_message_destinations(times, radar_nodes):
assert node2.messages_to_pass_on == []
assert node3.messages_to_pass_on == []

# Add message without destination to edge1
edge1.unpassed_data.append(message1b)
assert message1b.destinations is None

# Update both edges
edge1.update_messages(start_time+datetime.timedelta(minutes=1), to_network_node=False)
edge2.update_messages(start_time + datetime.timedelta(minutes=1), to_network_node=True)

m1b = {m for m in edge1.messages_held['pending'][start_time+timedelta(seconds=1)]}.pop()
assert m1b == {edge1.recipient}

# Check node2.messages_to_pass_on contains message3 that does not have node 2 as a destination
assert len(node2.messages_to_pass_on) == 2
# Check node3.messages_to_pass_on contains all messages as it is not in information arch
Expand All @@ -227,8 +252,6 @@ def test_message_destinations(times, radar_nodes):
assert len(data_held) == 3




def test_unpassed_data(times):
start_time = times['start']
node1 = RepeaterNode()
Expand Down
28 changes: 15 additions & 13 deletions stonesoup/architecture/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@ def test_info_generate_invalid(generator_params):
mean_deg = 2.5

with pytest.raises(ValueError):
gen = InformationArchitectureGenerator(arch_type='invalid',
start_time=start_time,
mean_degree=mean_deg,
node_ratio=[3, 1, 1],
base_tracker=base_tracker,
base_sensor=base_sensor,
n_archs=2)
InformationArchitectureGenerator(arch_type='invalid',
start_time=start_time,
mean_degree=mean_deg,
node_ratio=[3, 1, 1],
base_tracker=base_tracker,
base_sensor=base_sensor,
n_archs=2)


def test_net_arch_gen_init(generator_params):
start_time = generator_params['start_time']
Expand All @@ -158,12 +159,13 @@ def test_net_arch_gen_init(generator_params):
assert gen.sensor_max_distance == (0, 0)

with pytest.raises(ValueError):
NetworkArchitectureGenerator(arch_type='not_valid',
start_time=start_time,
mean_degree=2,
node_ratio=[3, 1, 1],
base_tracker=base_tracker,
base_sensor=base_sensor)
gen = NetworkArchitectureGenerator(arch_type='not_valid',
start_time=start_time,
mean_degree=2,
node_ratio=[3, 1, 1],
base_tracker=base_tracker,
base_sensor=base_sensor)
gen.generate()


def test_net_generate_hierarchical(generator_params):
Expand Down
8 changes: 6 additions & 2 deletions stonesoup/architecture/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ def test_node(data_pieces, times, nodes):
assert new_data_piece2.time_arrived == times['a']

with pytest.raises(TypeError):
node.update(times['a'], times['b'], data_pieces['fail'], "fused", track=Track([]),
use_arrival_time=False)
node.update(times['a'],
times['b'],
data_pieces['fail'],
"fused",
track=Track([]),
use_arrival_time=False)


def test_sensor_node(nodes):
Expand Down

0 comments on commit ffd7fc1

Please sign in to comment.