diff --git a/docs/export/CPP_STM32.rst b/docs/export/CPP_STM32.rst index 347818ca..f56e9d71 100644 --- a/docs/export/CPP_STM32.rst +++ b/docs/export/CPP_STM32.rst @@ -1,8 +1,6 @@ Export: C++/STM32 ================= -**N2D2-IP only: available upon request.** - Export type: ``CPP_STM32`` C++ export for STM32. diff --git a/docs/quant/pruning.rst b/docs/quant/pruning.rst index 204593f7..2e342eab 100644 --- a/docs/quant/pruning.rst +++ b/docs/quant/pruning.rst @@ -15,16 +15,16 @@ Example with Python :members: :inherited-members: -Example of code to use the *PruneCell* in your scripts: +Example of code to use the :py:class:`n2d2.quantizer.PruneCell` in your scripts: .. code-block:: python for cell in model: - ### Add Pruning ### - if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc): - cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct") + ### Add Pruning ### + if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc): + cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct") -Some explanations with the differents options of the *PruneCell*: +Some explanations of the different options of the :py:class:`n2d2.quantizer.PruneCell`: Pruning mode ^^^^^^^^^^^^ @@ -42,7 +42,7 @@ For example, to update each two epochs, write: n2d2.quantizer.PruneCell(prune_mode="Gradual", threshold=0.3, stepsize=2*DATASET_SIZE) -Where *DATASET_SIZE* is the size of the dataset you are using. +Where ``DATASET_SIZE`` is the size of the dataset you are using. Pruning filler ^^^^^^^^^^^^^^ @@ -53,7 +53,7 @@ Pruning filler - IterNonStruct: all weights below than the ``delta`` factor are pruned. If this is not enough to reach ``threshold``, all the weights below 2 "delta" are pruned and so on... -**Important**: With *PruneCell*, ``quant_mode`` and ``range`` are not used. 
+**Important**: With :py:class:`n2d2.quantizer.PruneCell`, ``quant_mode`` and ``range`` are not used. Example with INI file diff --git a/python/pytorch_to_n2d2/pytorch_interface.py b/python/pytorch_to_n2d2/pytorch_interface.py index af6df233..332e5920 100755 --- a/python/pytorch_to_n2d2/pytorch_interface.py +++ b/python/pytorch_to_n2d2/pytorch_interface.py @@ -331,6 +331,8 @@ def __exit__(self, exc_type, exc_value, traceback): def wrap(torch_model:torch.nn.Module, input_size: Union[list, tuple], opset_version:int=11, + in_names:list=None, + out_names:list=None, verbose:bool=False) -> Block: """Function generating a ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`. The torch_model is exported to N2D2 via ONNX. @@ -341,6 +343,10 @@ def wrap(torch_model:torch.nn.Module, :type input_size: ``list`` :param opset_version: Opset version used to generate the intermediate ONNX file, default=11 :type opset_version: int, optional + :param in_names: Names to assign to the network inputs + :type in_names: list, optional + :param out_names: Names to assign to the network outputs + :type out_names: list, optional :param verbose: Enable the verbose output of torch onnx export, default=False :type verbose: bool, optional :return: A custom ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`. @@ -361,6 +367,8 @@ def wrap(torch_model:torch.nn.Module, dummy_in, raw_model_path, verbose=verbose, + input_names=in_names, + output_names=out_names, export_params=True, opset_version=opset_version, training=torch.onnx.TrainingMode.TRAINING,