From 1cb7b5ca4bf1632f3e882a56b83c61eea06f7e75 Mon Sep 17 00:00:00 2001 From: Felix Bauer Date: Tue, 12 Nov 2024 14:48:48 +0100 Subject: [PATCH] Reintroduce missing methods of DynapcnnNetwork: `reset_states`, `zero_grad` --- sinabs/backend/dynapcnn/dynapcnn_layer.py | 2 +- sinabs/backend/dynapcnn/dynapcnn_network.py | 70 +++++++++++++++++++-- 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/sinabs/backend/dynapcnn/dynapcnn_layer.py b/sinabs/backend/dynapcnn/dynapcnn_layer.py index ea573701..76057854 100644 --- a/sinabs/backend/dynapcnn/dynapcnn_layer.py +++ b/sinabs/backend/dynapcnn/dynapcnn_layer.py @@ -188,7 +188,7 @@ def forward(self, x) -> List[torch.Tensor]: def zero_grad(self, set_to_none: bool = False) -> None: """Call `zero_grad` method of spiking layer""" - return self._spk.zero_grad(set_to_none) + return self.spk.zero_grad(set_to_none) def get_neuron_shape(self) -> Tuple[int, int, int]: """Return the output shape of the neuron layer. diff --git a/sinabs/backend/dynapcnn/dynapcnn_network.py b/sinabs/backend/dynapcnn/dynapcnn_network.py index 2d1ef082..e27bafa7 100644 --- a/sinabs/backend/dynapcnn/dynapcnn_network.py +++ b/sinabs/backend/dynapcnn/dynapcnn_network.py @@ -53,10 +53,20 @@ def __init__( - batch_size (optional int): If `None`, will try to infer the batch size from the model. If int value is provided, it has to match the actual batch size of the model. - dvs_input (bool): optional (default as `None`). Wether or not dynapcnn receive - input from its DVS camera. If a `DVSLayer` is part of `snn` and `dvs_input` is - false, the DVS sensor will be configured but its output will not be sent as input - to the chip. If `dvs_input` is `True` and `snn` does not contain a `DVSLayer`, - it will be added. + input from its DVS camera. + If a `DVSLayer` is part of `snn`... + ... and `dvs_input` is `False`, its `disable_pixel_array` attribute + will be set `True`. This means the DVS sensor will be configured + upon deployment but its output will not be sent as input + ... and `dvs_input` is `None`, the `disable_pixel_array` attribute + of the layer will not be changed. + ... and `dvs_input` is `True`, `disable_pixel_array` will be set + `False`, so that the DVS sensor data is sent to the network. + If no `DVSLayer` is part of `snn`... + ... and `dvs_input` is `False` or `None`, no `DVSLayer` will be added + and the DVS sensor will not be configured upon deployment. + ... and `dvs_input` is `True`, a `DVSLayer` instance will be added + to the network, with `disable_pixel_array` set to `False`. - discretize (bool): If `True`, discretize the parameters and thresholds. This is needed for uploading weights to dynapcnn. Set to `False` only for testing purposes. - weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to @@ -514,6 +524,58 @@ def has_dvs_layer(self) -> bool: """ return self.dvs_layer is not None + def zero_grad(self, set_to_none: bool = False) -> None: + """ Call `zero_grad` method of each DynapCNN layer + + Parameters + ---------- + - set_to_none (bool): This argument is passed directly to the + `zero_grad` method of each DynapCNN layer + """ + for lyr in self.dynapcnn_layers.values(): + lyr.zero_grad(set_to_none) + + def reset_states(self, randomize=False): + """Reset the states of the network. + + Parameters + ---------- + - randomize (bool): If `False` (default), will set all states to 0. + Otherwise will set to random values. + + Notes + ----- + - Setting `randomize` to `True` is only supported for models that have + not yet been deployed on a SynSense device. + """ + if hasattr(self, "device") and isinstance(self.device, str): # pragma: no cover + device_name, _ = parse_device_id(self.device) + # Reset states on SynSense device + if device_name in ChipFactory.supported_devices: + config_builder = ChipFactory(self.device).get_config_builder() + # Set all the vmem states in the samna config to zero + config_builder.reset_states(self.samna_config, randomize=randomize) + self.samna_device.get_model().apply_configuration(self.samna_config) + # wait for the config to be written + time.sleep(1) + # Note: The below shouldn't be necessary ideally + # Erase all vmem memory + if not randomize: + if hasattr(self, "samna_input_graph"): + self.samna_input_graph.stop() + for lyr_idx in self.chip_layers_ordering: + config_builder.set_all_v_mem_to_zeros( + self.samna_device, lyr_idx + ) + time.sleep(0.1) + self.samna_input_graph.start() + return + + # Reset states of `DynapcnnLayer` instances + for layer in self.sequence: + if isinstance(layer, DynapcnnLayer): + layer.spk_layer.reset_states(randomize=randomize) + ####################################################### Private Methods ####################################################### def _make_config(