From 8a4823c7a721a0fe66742af3cb1ad9ae79ba27c0 Mon Sep 17 00:00:00 2001 From: tc269127 Date: Tue, 10 Jan 2023 12:16:30 +0000 Subject: [PATCH 1/2] [PythonAPI] Adding input and output names options to the wraping function --- python/pytorch_to_n2d2/pytorch_interface.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/pytorch_to_n2d2/pytorch_interface.py b/python/pytorch_to_n2d2/pytorch_interface.py index af6df233..332e5920 100755 --- a/python/pytorch_to_n2d2/pytorch_interface.py +++ b/python/pytorch_to_n2d2/pytorch_interface.py @@ -331,6 +331,8 @@ def __exit__(self, exc_type, exc_value, traceback): def wrap(torch_model:torch.nn.Module, input_size: Union[list, tuple], opset_version:int=11, + in_names:list=None, + out_names:list=None, verbose:bool=False) -> Block: """Function generating a ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`. The torch_model is exported to N2D2 via ONNX. @@ -341,6 +343,10 @@ def wrap(torch_model:torch.nn.Module, :type input_size: ``list`` :param opset_version: Opset version used to generate the intermediate ONNX file, default=11 :type opset_version: int, optional + :param in_names: Specify specific names for the network inputs + :type in_names: list, optional + :param out_names: Specify specific names for the network outputs + :type in_names: list, optional :param verbose: Enable the verbose output of torch onnx export, default=False :type verbose: bool, optional :return: A custom ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`. @@ -361,6 +367,8 @@ def wrap(torch_model:torch.nn.Module, dummy_in, raw_model_path, verbose=verbose, + input_names=in_names, + output_names=out_names, export_params=True, opset_version=opset_version, training=torch.onnx.TrainingMode.TRAINING, From 56ffa3f9891a8276ffbbb4f076321e7e6381b7d9 Mon Sep 17 00:00:00 2001 From: cmoineau Date: Wed, 11 Jan 2023 07:51:29 +0000 Subject: [PATCH 2/2] [Docs] Update STM32 export and pruning documentation. --- docs/export/CPP_STM32.rst | 2 -- docs/quant/pruning.rst | 14 +++++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/export/CPP_STM32.rst b/docs/export/CPP_STM32.rst index 347818ca..f56e9d71 100644 --- a/docs/export/CPP_STM32.rst +++ b/docs/export/CPP_STM32.rst @@ -1,8 +1,6 @@ Export: C++/STM32 ================= -**N2D2-IP only: available upon request.** - Export type: ``CPP_STM32`` C++ export for STM32. diff --git a/docs/quant/pruning.rst b/docs/quant/pruning.rst index 204593f7..2e342eab 100644 --- a/docs/quant/pruning.rst +++ b/docs/quant/pruning.rst @@ -15,16 +15,16 @@ Example with Python :members: :inherited-members: -Example of code to use the *PruneCell* in your scripts: +Example of code to use the :py:class:`n2d2.quantizer.PruneCell` in your scripts: .. code-block:: python for cell in model: - ### Add Pruning ### - if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc): - cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct") + ### Add Pruning ### + if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc): + cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct") -Some explanations with the differents options of the *PruneCell*: +Some explanations with the differents options of the :py:class:`n2d2.quantizer.PruneCell` : Pruning mode ^^^^^^^^^^^^ @@ -42,7 +42,7 @@ For example, to update each two epochs, write: n2d2.quantizer.PruneCell(prune_mode="Gradual", threshold=0.3, stepsize=2*DATASET_SIZE) -Where *DATASET_SIZE* is the size of the dataset you are using. +Where ``DATASET_SIZE`` is the size of the dataset you are using. Pruning filler ^^^^^^^^^^^^^^ @@ -53,7 +53,7 @@ Pruning filler - IterNonStruct: all weights below than the ``delta`` factor are pruned. If this is not enough to reach ``threshold``, all the weights below 2 "delta" are pruned and so on... -**Important**: With *PruneCell*, ``quant_mode`` and ``range`` are not used. +**Important**: With :py:class:`n2d2.quantizer.PruneCell`, ``quant_mode`` and ``range`` are not used. Example with INI file