Skip to content

Commit

Permalink
Properly copy DVSLayer when instantiating DynapcnnNetwork. Fix DVS in…
Browse files Browse the repository at this point in the history
…put unit tests.
  • Loading branch information
bauerfe committed Nov 12, 2024
1 parent 8897d24 commit 2680ac1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
2 changes: 1 addition & 1 deletion sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def _handle_dvs_input(
# Make a copy of the layer so that the original version is not
# changed in place
new_dvs_layer = deepcopy(self.dvs_layer)
self.name_2_indx_map[self.dvs_layer_index] = new_dvs_layer
self._indx_2_module_map[self.dvs_layer_index] = new_dvs_layer
elif dvs_input:
# Insert a DVSLayer node in the graph.
new_dvs_layer = self._add_dvs_node(dvs_input_shape=input_shape)
Expand Down
58 changes: 39 additions & 19 deletions tests/test_dynapcnn/test_dvs_input.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""This should test cases of dynapcnn compatible networks with dvs input."""

from itertools import product
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import numpy as np
import pytest
import samna
import torch
from torch import nn

Expand All @@ -30,7 +28,7 @@ def verify_dvs_config(
origin: Tuple[int, int] = (0, 0),
cut: Optional[Tuple[int, int]] = None,
destination: Optional[int] = None,
dvs_input: bool = True,
dvs_input: Union[bool, None] = True,
flip: Optional[dict] = None,
merge_polarities: bool = False,
):
Expand Down Expand Up @@ -80,9 +78,12 @@ def forward(self, x):


class NetPool2D(nn.Module):
def __init__(self, input_layer: bool = False):
def __init__(self, add_input_layer: bool = False):
super().__init__()
layers = []
if add_input_layer:
layers = [DVSLayer(input_shape=INPUT_SHAPE[1:])]
else:
layers = []
layers += [
nn.AvgPool2d(kernel_size=(2, 4)),
nn.Conv2d(2, 4, kernel_size=2, stride=2),
Expand Down Expand Up @@ -128,36 +129,48 @@ def test_dvs_no_pooling(dvs_input):
)


@pytest.mark.parametrize("dvs_input", (False, True))
def test_dvs_pooling_2d(dvs_input):
args = product((True, False, None), (True, False))
@pytest.mark.parametrize("dvs_input,add_input_layer", args)
def test_dvs_pooling_2d(dvs_input, add_input_layer):
# - ANN and SNN generation
ann = NetPool2D(input_layer=True)
ann = NetPool2D(add_input_layer=add_input_layer)
snn = from_model(ann.seq, batch_size=1)
snn.eval()

# - SPN generation
spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE)
if not dvs_input and not add_input_layer:
# No DVS layer is part of the SNN nor being added to it. The pooling layer should cause an exception
with pytest.raises(InvalidGraphStructure):
spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE)
return

# When there is pooling, a DVSLayer should also be added if `dvs_input` is False
# If `add_input_layer` is False but `dvs_input` is `True`, a DVS layer will
# be added to the DynapcnnNetwork upon instantiation
spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE)
assert spn.has_dvs_layer()

# - Make sure missing input shapes cause exception
with pytest.raises(InputConfigurationError):
spn = DynapcnnNetwork(snn, dvs_input=dvs_input)
if not add_input_layer:
# - Make sure missing input shapes cause exception
with pytest.raises(InputConfigurationError):
spn = DynapcnnNetwork(snn, dvs_input=dvs_input)

# - Compare snn and spn outputs
spn_float = DynapcnnNetwork(snn, discretize=False, input_shape=INPUT_SHAPE)
# - Compare snn and spn outputs. - Always add DVS so that pooling layer is properly handled
spn_float = DynapcnnNetwork(snn, dvs_input=True, discretize=False, input_shape=INPUT_SHAPE)
snn_out = snn(input_data).squeeze()
spn_out = spn_float(input_data).squeeze()
assert torch.equal(snn_out.detach(), spn_out)

# - Verify DYNAP-CNN config
target_layers = [5]
config = spn.make_config(chip_layers_ordering=target_layers)
# Get index of only DynapcnnLayer to map it to core 5
cnn_layer_idx = next(spn.dynapcnn_layers.__iter__())
target_dest = 5
config = spn.make_config(layer2core_map={cnn_layer_idx: target_dest})
if dvs_input is None:
dvs_input = not snn.spiking_model[0].disable_pixel_array
verify_dvs_config(
config,
input_shape=INPUT_SHAPE,
destination=target_layers[0],
destination=target_dest,
dvs_input=dvs_input,
pooling=(2, 4),
)
Expand Down Expand Up @@ -290,6 +303,13 @@ def test_whether_dvs_mirror_cfg_is_all_switched_off(dvs_input, pool):

snn = nn.Sequential(*layer_list)

if pool and not dvs_input:
with pytest.raises(InvalidGraphStructure):
dynapcnn = DynapcnnNetwork(
snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True
)
return

dynapcnn = DynapcnnNetwork(
snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True
)
Expand Down

0 comments on commit 2680ac1

Please sign in to comment.