From ffd7fc1db6bfaf36dfc3b917c667ac7bd0fcfb96 Mon Sep 17 00:00:00 2001 From: spike Date: Mon, 9 Sep 2024 09:18:04 +0100 Subject: [PATCH] Fix to test_message_destinations. Add test_update_messages() to test extra scenarios. --- stonesoup/architecture/__init__.py | 2 - stonesoup/architecture/generator.py | 6 +-- stonesoup/architecture/tests/test_edge.py | 53 +++++++++++++------ .../architecture/tests/test_generator.py | 28 +++++----- stonesoup/architecture/tests/test_node.py | 8 ++- 5 files changed, 60 insertions(+), 37 deletions(-) diff --git a/stonesoup/architecture/__init__.py b/stonesoup/architecture/__init__.py index f55493a17..8518fd4af 100644 --- a/stonesoup/architecture/__init__.py +++ b/stonesoup/architecture/__init__.py @@ -513,8 +513,6 @@ def __init__(self, *args, **kwargs): # Need to reset digraph for info-arch self.di_graph = nx.to_networkx_graph(self.edges.edge_list, create_using=nx.DiGraph) # Set attributes such as label, colour, shape, etc. for each node - node_label_gens = {} - labels = {node.label.replace("\n", " ") for node in self.di_graph.nodes if node.label} for node in self.di_graph.nodes: self.di_graph.nodes[node].update(self._node_kwargs(node)) diff --git a/stonesoup/architecture/generator.py b/stonesoup/architecture/generator.py index 6df48f851..f76239efb 100644 --- a/stonesoup/architecture/generator.py +++ b/stonesoup/architecture/generator.py @@ -202,7 +202,7 @@ def _generate_edgelist(self): else: valid = True - elif self.arch_type == 'decentralised': + else: while not valid: @@ -230,10 +230,6 @@ def _generate_edgelist(self): else: valid = True - else: - raise ValueError(f"Invalid architecture type of {self.arch_type}. arch_type must be " - "one of: 'hierarchical' or 'decentralised'") - return edges, nodes diff --git a/stonesoup/architecture/tests/test_edge.py b/stonesoup/architecture/tests/test_edge.py index 8d91eb8e2..0a5638af3 100644 --- a/stonesoup/architecture/tests/test_edge.py +++ b/stonesoup/architecture/tests/test_edge.py @@ -2,10 +2,11 @@ import pytest -from .. import RepeaterNode +from .. import RepeaterNode, Node from ..edge import Edges, Edge, DataPiece, Message, FusionQueue from ...types.track import Track from ...types.time import CompoundTimeRange, TimeRange +from .._functions import _dict_set from datetime import timedelta @@ -69,6 +70,42 @@ def test_send_update_message(edges, times, data_pieces): assert message in edge.messages_held['received'][times['a']] +def test_update_messages(): + + # Test scenario where message has not yet reached recipient + A = Node(label='A') + B = Node(label='B') + edge = Edge((A, B), edge_latency=0.5) + + time_created = datetime.datetime.now() + time_sent = time_created + data = DataPiece(A, A, Track([]), time_created) + + # Add message to edge + edge.send_message(data, time_created, time_sent) + assert len(edge.messages_held['pending']) == 1 + + # Message should not have arrived yet + edge.update_messages(time_created) + assert len(edge.messages_held['pending']) == 1 + + # Try again a secomd later + edge.update_messages(time_created + timedelta(seconds=1)) + assert len(edge.messages_held['pending']) == 0 + + # Test scenario when message has no destinations + message = Message(edge, time_created, time_sent, data, destinations=None) + _, edge.messages_held = _dict_set(edge.messages_held, + message, + 'pending', + message.arrival_time) + + assert message in edge.messages_held['pending'][message.arrival_time] + + edge.update_messages(time_created + datetime.timedelta(seconds=2)) + assert len(edge.messages_held['pending']) == 0 + + def test_failed(edges, times): edge = edges['a'] assert edge.time_range_failed == CompoundTimeRange() @@ -176,11 +213,6 @@ def test_message_destinations(times, radar_nodes): DataPiece(node1, node1, Track([]), datetime.datetime(2016, 1, 2, 3, 4, 5)), destinations={node2, node3}) - - # Another message like 1, but this will not be put through Edge.pass_message - message1b = Message(edge1, datetime.datetime(2016, 1, 2, 3, 4, 5), start_time+timedelta(seconds=1), - DataPiece(node1, node1, Track([]), - datetime.datetime(2016, 1, 2, 3, 4, 5))) # Add messages to node1.messages_to_pass_on and check that unpassed_data() catches it node1.messages_to_pass_on = [message1, message2, message3, message4] @@ -204,17 +236,10 @@ def test_message_destinations(times, radar_nodes): assert node2.messages_to_pass_on == [] assert node3.messages_to_pass_on == [] - # Add message without destination to edge1 - edge1.unpassed_data.append(message1b) - assert message1b.destinations is None - # Update both edges edge1.update_messages(start_time+datetime.timedelta(minutes=1), to_network_node=False) edge2.update_messages(start_time + datetime.timedelta(minutes=1), to_network_node=True) - m1b = {m for m in edge1.messages_held['pending'][start_time+timedelta(seconds=1)]}.pop() - assert m1b == {edge1.recipient} - # Check node2.messages_to_pass_on contains message3 that does not have node 2 as a destination assert len(node2.messages_to_pass_on) == 2 # Check node3.messages_to_pass_on contains all messages as it is not in information arch @@ -227,8 +252,6 @@ def test_message_destinations(times, radar_nodes): assert len(data_held) == 3 - - def test_unpassed_data(times): start_time = times['start'] node1 = RepeaterNode() diff --git a/stonesoup/architecture/tests/test_generator.py b/stonesoup/architecture/tests/test_generator.py index 17cc03b59..b936d5eda 100644 --- a/stonesoup/architecture/tests/test_generator.py +++ b/stonesoup/architecture/tests/test_generator.py @@ -125,13 +125,14 @@ def test_info_generate_invalid(generator_params): mean_deg = 2.5 with pytest.raises(ValueError): - gen = InformationArchitectureGenerator(arch_type='invalid', - start_time=start_time, - mean_degree=mean_deg, - node_ratio=[3, 1, 1], - base_tracker=base_tracker, - base_sensor=base_sensor, - n_archs=2) + InformationArchitectureGenerator(arch_type='invalid', + start_time=start_time, + mean_degree=mean_deg, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor, + n_archs=2) + def test_net_arch_gen_init(generator_params): start_time = generator_params['start_time'] @@ -158,12 +159,13 @@ def test_net_arch_gen_init(generator_params): assert gen.sensor_max_distance == (0, 0) with pytest.raises(ValueError): - NetworkArchitectureGenerator(arch_type='not_valid', - start_time=start_time, - mean_degree=2, - node_ratio=[3, 1, 1], - base_tracker=base_tracker, - base_sensor=base_sensor) + gen = NetworkArchitectureGenerator(arch_type='not_valid', + start_time=start_time, + mean_degree=2, + node_ratio=[3, 1, 1], + base_tracker=base_tracker, + base_sensor=base_sensor) + gen.generate() def test_net_generate_hierarchical(generator_params): diff --git a/stonesoup/architecture/tests/test_node.py b/stonesoup/architecture/tests/test_node.py index 083ec07a3..0dbf52317 100644 --- a/stonesoup/architecture/tests/test_node.py +++ b/stonesoup/architecture/tests/test_node.py @@ -32,8 +32,12 @@ def test_node(data_pieces, times, nodes): assert new_data_piece2.time_arrived == times['a'] with pytest.raises(TypeError): - node.update(times['a'], times['b'], data_pieces['fail'], "fused", track=Track([]), - use_arrival_time=False) + node.update(times['a'], + times['b'], + data_pieces['fail'], + "fused", + track=Track([]), + use_arrival_time=False) def test_sensor_node(nodes):