From 2680ac15413d0b347c2b3bd779e872552327d839 Mon Sep 17 00:00:00 2001 From: Felix Bauer Date: Tue, 12 Nov 2024 14:29:19 +0100 Subject: [PATCH] Properly copy DVSLayer when instantiating DynapcnnNetwork. Fix DVS input unit tests. --- .../backend/dynapcnn/nir_graph_extractor.py | 2 +- tests/test_dynapcnn/test_dvs_input.py | 58 +++++++++++++------ 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/sinabs/backend/dynapcnn/nir_graph_extractor.py b/sinabs/backend/dynapcnn/nir_graph_extractor.py index 51d12c17..d942ca33 100644 --- a/sinabs/backend/dynapcnn/nir_graph_extractor.py +++ b/sinabs/backend/dynapcnn/nir_graph_extractor.py @@ -424,7 +424,7 @@ def _handle_dvs_input( # Make a copy of the layer so that the original version is not # changed in place new_dvs_layer = deepcopy(self.dvs_layer) - self.name_2_indx_map[self.dvs_layer_index] = new_dvs_layer + self._indx_2_module_map[self.dvs_layer_index] = new_dvs_layer elif dvs_input: # Insert a DVSLayer node in the graph. new_dvs_layer = self._add_dvs_node(dvs_input_shape=input_shape) diff --git a/tests/test_dynapcnn/test_dvs_input.py b/tests/test_dynapcnn/test_dvs_input.py index 0eae07c3..61d3507c 100644 --- a/tests/test_dynapcnn/test_dvs_input.py +++ b/tests/test_dynapcnn/test_dvs_input.py @@ -1,11 +1,9 @@ """This should test cases of dynapcnn compatible networks with dvs input.""" from itertools import product -from typing import Optional, Tuple +from typing import Optional, Tuple, Union -import numpy as np import pytest -import samna import torch from torch import nn @@ -30,7 +28,7 @@ def verify_dvs_config( origin: Tuple[int, int] = (0, 0), cut: Optional[Tuple[int, int]] = None, destination: Optional[int] = None, - dvs_input: bool = True, + dvs_input: Union[bool, None] = True, flip: Optional[dict] = None, merge_polarities: bool = False, ): @@ -80,9 +78,12 @@ def forward(self, x): class NetPool2D(nn.Module): - def __init__(self, input_layer: bool = False): + def __init__(self, add_input_layer: bool = False): super().__init__() - layers = [] + if add_input_layer: + layers = [DVSLayer(input_shape=INPUT_SHAPE[1:])] + else: + layers = [] layers += [ nn.AvgPool2d(kernel_size=(2, 4)), nn.Conv2d(2, 4, kernel_size=2, stride=2), @@ -128,36 +129,48 @@ def test_dvs_no_pooling(dvs_input): ) -@pytest.mark.parametrize("dvs_input", (False, True)) -def test_dvs_pooling_2d(dvs_input): +args = product((True, False, None), (True, False)) +@pytest.mark.parametrize("dvs_input,add_input_layer", args) +def test_dvs_pooling_2d(dvs_input, add_input_layer): # - ANN and SNN generation - ann = NetPool2D(input_layer=True) + ann = NetPool2D(add_input_layer=add_input_layer) snn = from_model(ann.seq, batch_size=1) snn.eval() # - SPN generation - spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) + if not dvs_input and not add_input_layer: + # No DVS layer is part of the SNN nor being added to it. The pooling layer should cause an exception + with pytest.raises(InvalidGraphStructure): + spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) + return - # When there is pooling, a DVSLayer should also be added if `dvs_input` is False + # If `add_input_layer` is False but `dvs_input` is `True`, a DVS layer will + # be added to the DynapcnnNetwork upon instantiation + spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) assert spn.has_dvs_layer() - # - Make sure missing input shapes cause exception - with pytest.raises(InputConfigurationError): - spn = DynapcnnNetwork(snn, dvs_input=dvs_input) + if not add_input_layer: + # - Make sure missing input shapes cause exception + with pytest.raises(InputConfigurationError): + spn = DynapcnnNetwork(snn, dvs_input=dvs_input) - # - Compare snn and spn outputs - spn_float = DynapcnnNetwork(snn, discretize=False, input_shape=INPUT_SHAPE) + # - Compare snn and spn outputs. - Always add DVS so that pooling layer is properly handled + spn_float = DynapcnnNetwork(snn, dvs_input=True, discretize=False, input_shape=INPUT_SHAPE) snn_out = snn(input_data).squeeze() spn_out = spn_float(input_data).squeeze() assert torch.equal(snn_out.detach(), spn_out) # - Verify DYNAP-CNN config - target_layers = [5] - config = spn.make_config(chip_layers_ordering=target_layers) + # Get index of only DynapcnnLayer to map it to core 5 + cnn_layer_idx = next(spn.dynapcnn_layers.__iter__()) + target_dest = 5 + config = spn.make_config(layer2core_map={cnn_layer_idx: target_dest}) + if dvs_input is None: + dvs_input = not snn.spiking_model[0].disable_pixel_array verify_dvs_config( config, input_shape=INPUT_SHAPE, - destination=target_layers[0], + destination=target_dest, dvs_input=dvs_input, pooling=(2, 4), ) @@ -290,6 +303,13 @@ def test_whether_dvs_mirror_cfg_is_all_switched_off(dvs_input, pool): snn = nn.Sequential(*layer_list) + if pool and not dvs_input: + with pytest.raises(InvalidGraphStructure): + dynapcnn = DynapcnnNetwork( + snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True + ) + return + dynapcnn = DynapcnnNetwork( snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True )