diff --git a/.pylintrc b/.pylintrc index f2405979..29a978a8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,3 +1,7 @@ +[BASIC] +# Good variable names which should always be accepted, separated by a comma. +good-names=x,y,z + [DESIGN] # Maximum number of arguments for function / method max-args=12 @@ -16,7 +20,7 @@ max-public-methods=20 max-line-length=120 [MESSAGES] -disable=logging-fstring-interpolation +disable=logging-fstring-interpolation,no-self-use [SIMILARITIES] # Minimum lines number of a similarity. diff --git a/README.md b/README.md index 7839d917..a0e9fcad 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Below you can find some examples of use. ### Convert ```python import torch -from onnx2torch.converter import convert +from onnx2torch import convert # Path to ONNX model onnx_model_path = '/some/path/mobile_net_v2.onnx' @@ -60,8 +60,34 @@ print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7)) ## Models We have tested the following models: -- [x] ResNet50 -- [x] SSDLite with MobileNetV2 backbone + +Segmentation models: +- [x] DeepLabv3plus +- [x] DeepLabv3 resnet50 (torchvision) +- [x] HRNet +- [x] UNet (torchvision) +- [x] FCN resnet50 (torchvision) +- [x] lraspp mobilenetv3 (torchvision) + +Detection from MMdetection: +- [x] [SSDLite with MobileNetV2 backbone](https://github.com/open-mmlab/mmdetection) +- [x] [RetinaNet R50](https://github.com/open-mmlab/mmdetection) +- [x] [SSD300 with VGG backbone](https://github.com/open-mmlab/mmdetection) +- [x] [Yolov3_d53](https://github.com/open-mmlab/mmdetection) +- [x] [Yolov5](https://github.com/ultralytics/yolov5) + +Classification from __torchvision__: +- [x] Resnet18 +- [x] Resnet50 +- [x] MobileNet v2 +- [x] MobileNet v3 large +- [x] EfficientNet_b{0, 1, 2, 3} +- [x] WideResNet50 +- [x] ResNext50 +- [x] VGG16 +- [x] GoogleleNet +- [x] MnasNet +- [x] RegNet ## How to add new operations to converter @@ -86,24 +112,19 @@ If Operation's behaviour differs from one opset version to another, you should i ```python class OnnxExpand(nn.Module): - @staticmethod - def _do_forward(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: - return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device) - - def forward(self, *args) -> torch.Tensor: + def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: + output = input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(*args) - return _ExpandExportToOnnx.set_output_and_apply(output, *args) + return _ExpandExportToOnnx.set_output_and_apply(output, input_tensor, shape) - return self._do_forward(*args) + return output class _ExpandExportToOnnx(CustomExportToOnnx): @staticmethod - def symbolic(graph: torch_C.Graph, *args, **kwargs) -> torch_C.Value: - return graph.op('Expand', *args, **kwargs, outputs=1) + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + return graph.op('Expand', *args, outputs=1) @add_converter(operation_type='Expand', version=8) diff --git a/onnx2torch/__init__.py b/onnx2torch/__init__.py index e69de29b..a7f6aa13 100644 --- a/onnx2torch/__init__.py +++ b/onnx2torch/__init__.py @@ -0,0 +1 @@ +from onnx2torch.converter import convert diff --git a/onnx2torch/converter.py b/onnx2torch/converter.py index 9fd560fa..c0b5ca9d 100644 --- a/onnx2torch/converter.py +++ b/onnx2torch/converter.py @@ -39,29 +39,37 @@ def forward(self, *args, **kwargs): # pylint: disable=no-self-use raise RuntimeError('Got unexpected "forward" on constant container') -def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mapping: bool = False): +def convert( + onnx_model_or_path: Union[str, Path, ModelProto], + save_input_names: bool = False, + attach_onnx_mapping: bool = False, +) -> fx.GraphModule: """Convert model from onnx to PyTorch. This function build torch.fx GraphModule from onnx ModelProto using operations from the converter registry. - The registered operation can be found in onnx2torch/node_converters + The registered operation can be found in onnx2torch/node_converters. Usage example: - from onnx2torch.converter import convert + from onnx2torch import convert torch_module = convert('path/to/onnx_model.onnx') Parameters ---------- - onnx_model_or_path: + onnx_model_or_path : Union[str, Path, ModelProto] Onnx ModelProto or model path to convert. - attach_onnx_mapping: + save_input_names : bool + Whether to use original onnx inputs names as fx graph placeholders names or to use generated names (input_n). + False by default. + attach_onnx_mapping : bool Whether to attach info about mapping to original onnx tensors names. Returns ------- - : + fx.GraphModule PyTorch GraphModule + """ if isinstance(onnx_model_or_path, ModelProto): @@ -90,8 +98,16 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin torch_nodes = {} # create input nodes - for name in onnx_graph.input_values: - torch_nodes[name] = torch_graph.placeholder(name=name) + for i, name in enumerate(onnx_graph.input_values, 1): + if save_input_names: + if not name.isidentifier(): + raise ValueError(f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.') + + placeholder_name = name + else: + placeholder_name = f'input_{i}' + + torch_nodes[name] = torch_graph.placeholder(name=placeholder_name) # create intermediate nodes # IMPORTANT: nodes already topologically sorted @@ -131,10 +147,16 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin args.append(torch_input_node) elif value_type == ValueType.GRAPH_INITIALIZER: + # The name of putorch buffer must not contain '.'(dot) + len_torch_initializers = sum(1 for _ in torch_initializers.buffers()) + torch_buffer_name = f'onnx_initializer_{len_torch_initializers}' if value_name not in torch_nodes: - torch_initializers.add_initializer(value_name, onnx_graph.initializers[value_name].to_torch()) - torch_nodes[value_name] = torch_graph.get_attr(f'initializers.{value_name}') - args.append(torch_nodes[value_name]) + torch_initializers.add_initializer( + torch_buffer_name, + onnx_graph.initializers[value_name].to_torch(), + ) + torch_nodes[torch_buffer_name] = torch_graph.get_attr(f'initializers.{torch_buffer_name}') + args.append(torch_nodes[torch_buffer_name]) elif value_type == ValueType.EMPTY: args.append(None) @@ -147,12 +169,11 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin if None in args: first_skipped_arg = args.index(None) forward_args = tuple(inspect.signature(torch_module.forward).parameters.keys()) - forward_args = forward_args[first_skipped_arg:] - - for arg_name in forward_args: - arg_value = args.pop(first_skipped_arg) - if arg_value is not None: - kwargs[arg_name] = arg_value + forward_args = forward_args[first_skipped_arg:len(args)] + args, kwargs_values = args[:first_skipped_arg], args[first_skipped_arg:] + kwargs.update( + {name: value for name, value in zip(forward_args, kwargs_values) if value is not None} + ) torch_nodes[name] = torch_graph.call_module(module_name=name, args=tuple(args), kwargs=kwargs) diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index e8286636..beee8f5f 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -1,4 +1,5 @@ from onnx2torch.node_converters.activations import * +from onnx2torch.node_converters.average_pool import * from onnx2torch.node_converters.batch_norm import * from onnx2torch.node_converters.binary_math_operations import * from onnx2torch.node_converters.cast import * @@ -10,11 +11,13 @@ from onnx2torch.node_converters.conv import * from onnx2torch.node_converters.expand import * from onnx2torch.node_converters.flatten import * +from onnx2torch.node_converters.functions import * from onnx2torch.node_converters.gather import * from onnx2torch.node_converters.gemm import * from onnx2torch.node_converters.global_average_pool import * from onnx2torch.node_converters.identity import * from onnx2torch.node_converters.logical import * +from onnx2torch.node_converters.matmul import * from onnx2torch.node_converters.max_pool import * from onnx2torch.node_converters.nms import * from onnx2torch.node_converters.pow import * @@ -22,9 +25,12 @@ from onnx2torch.node_converters.reduce import * from onnx2torch.node_converters.reshape import * from onnx2torch.node_converters.resize import * +from onnx2torch.node_converters.roialign import * +from onnx2torch.node_converters.roundings import * from onnx2torch.node_converters.scatter_nd import * from onnx2torch.node_converters.shape import * from onnx2torch.node_converters.slice import * +from onnx2torch.node_converters.split import * from onnx2torch.node_converters.squeeze import * from onnx2torch.node_converters.tile import * from onnx2torch.node_converters.topk import * diff --git a/onnx2torch/node_converters/activations.py b/onnx2torch/node_converters/activations.py index cd641011..bd5dea42 100644 --- a/onnx2torch/node_converters/activations.py +++ b/onnx2torch/node_converters/activations.py @@ -1,22 +1,23 @@ -__all__ = ['OnnxExp', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11'] +__all__ = ['OnnxErf', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11'] import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxExp(nn.Module): +class OnnxErf(nn.Module, OnnxToTorchModule): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - return torch.exp(input_tensor) + return torch.erf(input_tensor) -class OnnxHardSigmoid(nn.Module): +class OnnxHardSigmoid(nn.Module, OnnxToTorchModule): def __init__(self, alpha: float = 0.2, beta: float = 0.5): super().__init__() self.alpha = alpha @@ -26,24 +27,25 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return torch.clip(self.alpha * input_tensor + self.beta, min=0.0, max=1.0) -class OnnxSoftmaxV1V11(nn.Module): - def __init__(self, axis: int = 1): +class OnnxSoftmaxV1V11(nn.Module, OnnxToTorchModule): + def __init__(self, axis: int = 1, is_log: bool = False): super().__init__() self.axis = axis + self.is_log = is_log def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: shape = input_tensor.shape result = torch.flatten(input_tensor, start_dim=self.axis) - result = torch.softmax(result, -1) + result = torch.log_softmax(result, -1) if self.is_log else torch.softmax(result, -1) return torch.reshape(result, shape) -@add_converter(operation_type='Exp', version=6) -@add_converter(operation_type='Exp', version=13) +@add_converter(operation_type='Erf', version=9) +@add_converter(operation_type='Erf', version=13) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument return OperationConverterResult( - torch_module=OnnxExp(), + torch_module=OnnxErf(), onnx_mapping=onnx_mapping_from_node(node=node), ) @@ -71,6 +73,23 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: ) +@add_converter(operation_type='LogSoftmax', version=13) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + return OperationConverterResult( + torch_module=nn.LogSoftmax(dim=node.attributes.get('axis', -1)), + onnx_mapping=onnx_mapping_from_node(node=node), + ) + + +@add_converter(operation_type='LogSoftmax', version=1) +@add_converter(operation_type='LogSoftmax', version=11) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + return OperationConverterResult( + torch_module=OnnxSoftmaxV1V11(axis=node.attributes.get('axis', 1), is_log=True), + onnx_mapping=onnx_mapping_from_node(node=node), + ) + + @add_converter(operation_type='Relu', version=6) @add_converter(operation_type='Relu', version=13) @add_converter(operation_type='Relu', version=14) @@ -102,9 +121,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: @add_converter(operation_type='Softmax', version=13) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument - axis = node.attributes.get('axis', -1) - return OperationConverterResult( - torch_module=torch.nn.Softmax(dim=axis), + torch_module=torch.nn.Softmax(dim=node.attributes.get('axis', -1)), onnx_mapping=onnx_mapping_from_node(node=node), ) diff --git a/onnx2torch/node_converters/average_pool.py b/onnx2torch/node_converters/average_pool.py index 96d62ffa..6413ce8c 100644 --- a/onnx2torch/node_converters/average_pool.py +++ b/onnx2torch/node_converters/average_pool.py @@ -2,59 +2,51 @@ from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import get_shape_from_value_info -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_shape_from_value_info +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.common import onnx_padding_to_torch_padding _AVGPOOL_CLASS_FROM_SPATIAL_RANK = { 1: nn.AvgPool1d, 2: nn.AvgPool2d, - 3: nn.AvgPool2d, + 3: nn.AvgPool3d, } @add_converter(operation_type='AveragePool', version=7) @add_converter(operation_type='AveragePool', version=10) +@add_converter(operation_type='AveragePool', version=11) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: input_value_info = graph.value_info[node.input_values[0]] input_shape = get_shape_from_value_info(input_value_info) spatial_rank = len(input_shape) - 2 - maxpool_class = _AVGPOOL_CLASS_FROM_SPATIAL_RANK.get(spatial_rank, None) - if maxpool_class is None: - raise NotImplementedError(f'Convolution operation with spatial rank == {spatial_rank} is not implemented') + avgpool_class = _AVGPOOL_CLASS_FROM_SPATIAL_RANK.get(spatial_rank, None) + if avgpool_class is None: + raise NotImplementedError(f'Average pool operation with spatial rank == {spatial_rank} is not implemented') node_attributes = node.attributes + # required + kernel_shape = node_attributes['kernel_shape'] + # optional ceil_mode = node_attributes.get('ceil_mode', 0) - padding = node_attributes.get('pads', [0] * spatial_rank * 2) - kernel_shape = node_attributes.get('kernel_shape', None) strides = node_attributes.get('strides', 1) - storage_order = node_attributes.get('storage_order', 0) - if storage_order != 0: - raise NotImplementedError('Only row major (0) order is supported.') - if kernel_shape is None: - raise RuntimeError('Kernel shape for MaxPool not specified. Kernel shape is mandatory parameters in onnx.') + count_include_pad = node_attributes.get('count_include_pad', 0) - auto_pad = node_attributes.get('auto_pad', 'NOTSET') - if auto_pad == 'NOTSET': - half_len = len(padding) // 2 - if tuple(padding[:half_len]) != tuple(padding[half_len:]): - raise NotImplementedError(f'Only symmetric padding is implemented ({padding})') - - padding = padding[:half_len] - elif auto_pad in ('SAME_UPPER', 'SAME_LOWER', 'VALID'): - raise NotImplementedError(f'"{auto_pad}" auto_pad is not implemented') - else: - raise ValueError(f'Got unexpected auto_pad value "{auto_pad}"') + padding = onnx_padding_to_torch_padding( + node_attributes.get('pads', [0] * spatial_rank * 2), + node_attributes.get('auto_pad', 'NOTSET'), + ) - torch_module = maxpool_class( + torch_module = avgpool_class( kernel_size=kernel_shape, stride=strides, padding=padding, - count_include_pad=False, + count_include_pad=count_include_pad == 1, ceil_mode=ceil_mode == 1, ) diff --git a/onnx2torch/node_converters/batch_norm.py b/onnx2torch/node_converters/batch_norm.py index d32c16fb..ebbff09d 100644 --- a/onnx2torch/node_converters/batch_norm.py +++ b/onnx2torch/node_converters/batch_norm.py @@ -3,12 +3,12 @@ import torch from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import get_shape_from_value_info from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_shape_from_value_info _BN_CLASS_FROM_SPATIAL_RANK = { 0: nn.BatchNorm1d, diff --git a/onnx2torch/node_converters/binary_math_operations.py b/onnx2torch/node_converters/binary_math_operations.py index d8bd932e..2dc880fe 100644 --- a/onnx2torch/node_converters/binary_math_operations.py +++ b/onnx2torch/node_converters/binary_math_operations.py @@ -7,12 +7,13 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import old_style_broadcast -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import old_style_broadcast +from onnx2torch.utils.common import onnx_mapping_from_node _TORCH_FUNCTION_FROM_ONNX_TYPE = { 'Add': torch.add, @@ -22,7 +23,7 @@ } -class OnnxBinaryMathOperation(nn.Module): +class OnnxBinaryMathOperation(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None): super().__init__() diff --git a/onnx2torch/node_converters/cast.py b/onnx2torch/node_converters/cast.py index f274b569..e9a4deec 100644 --- a/onnx2torch/node_converters/cast.py +++ b/onnx2torch/node_converters/cast.py @@ -4,11 +4,12 @@ from onnx import TensorProto from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node # pylint: disable=no-member TENSOR_TYPE_TO_TORCH_TYPE = { @@ -27,7 +28,7 @@ # pylint: enable=no-member -class OnnxCast(nn.Module): +class OnnxCast(nn.Module, OnnxToTorchModule): def __init__(self, onnx_dtype: int): super().__init__() @@ -44,9 +45,9 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: @add_converter(operation_type='Cast', version=13) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument node_attributes = node.attributes - to = node_attributes.get('to', None) + onnx_dtype = node_attributes.get('to', None) return OperationConverterResult( - torch_module=OnnxCast(to), + torch_module=OnnxCast(onnx_dtype), onnx_mapping=onnx_mapping_from_node(node=node), ) diff --git a/onnx2torch/node_converters/clip.py b/onnx2torch/node_converters/clip.py index 00b6ed16..3a4123ba 100644 --- a/onnx2torch/node_converters/clip.py +++ b/onnx2torch/node_converters/clip.py @@ -5,16 +5,17 @@ import torch from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import get_const_value -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_const_value +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxClip(nn.Module): +class OnnxClip(nn.Module, OnnxToTorchModule): def __init__( self, diff --git a/onnx2torch/node_converters/comparisons.py b/onnx2torch/node_converters/comparisons.py index f2192612..a4fee2dd 100644 --- a/onnx2torch/node_converters/comparisons.py +++ b/onnx2torch/node_converters/comparisons.py @@ -3,11 +3,12 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node _TORCH_FUNCTION_FROM_ONNX_TYPE = { 'Equal': torch.eq, @@ -18,14 +19,14 @@ } -class OnnxCompare(nn.Module): +class OnnxCompare(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str): super().__init__() self.compare_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] - def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - return self.compare_function(a, b) + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return self.compare_function(x, y) @add_converter(operation_type='Equal', version=7) diff --git a/onnx2torch/node_converters/concat.py b/onnx2torch/node_converters/concat.py index 30346fd0..ff98d1d5 100644 --- a/onnx2torch/node_converters/concat.py +++ b/onnx2torch/node_converters/concat.py @@ -3,14 +3,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxConcat(nn.Module): +class OnnxConcat(nn.Module, OnnxToTorchModule): def __init__(self, axis: int): super().__init__() diff --git a/onnx2torch/node_converters/constant.py b/onnx2torch/node_converters/constant.py index ad56b223..3dfbc661 100644 --- a/onnx2torch/node_converters/constant.py +++ b/onnx2torch/node_converters/constant.py @@ -5,11 +5,12 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node _CONSTANT_PARSING_MAPPING = { 'value': lambda x: x.to_torch(), @@ -22,7 +23,7 @@ } -class OnnxConstant(nn.Module): +class OnnxConstant(nn.Module, OnnxToTorchModule): def __init__(self, value: Any): super().__init__() diff --git a/onnx2torch/node_converters/constant_of_shape.py b/onnx2torch/node_converters/constant_of_shape.py index 979a9da0..6882842d 100644 --- a/onnx2torch/node_converters/constant_of_shape.py +++ b/onnx2torch/node_converters/constant_of_shape.py @@ -5,32 +5,34 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxConstantOfShape(nn.Module): +class OnnxConstantOfShape(nn.Module, OnnxToTorchModule): def __init__(self, value: Optional[torch.Tensor] = None): super().__init__() if value is None: - value = torch.Tensor([0.0]) + value = torch.Tensor(0.0, dtype=torch.float32) if value.numel() != 1: raise ValueError('parameter "value" must be scalar') - self.value = value + self.value: torch.Tensor + self.register_buffer('value', value) def forward(self, shape: torch.Tensor) -> torch.Tensor: return torch.full( size=torch.Size(shape), fill_value=self.value.item(), dtype=self.value.dtype, - device=shape.device, + device=self.value.device, ) diff --git a/onnx2torch/node_converters/conv.py b/onnx2torch/node_converters/conv.py index 8c230b8c..50a33fac 100644 --- a/onnx2torch/node_converters/conv.py +++ b/onnx2torch/node_converters/conv.py @@ -3,11 +3,12 @@ import torch from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_padding_to_torch_padding _CONV_CLASS_FROM_SPATIAL_RANK = { 1: nn.Conv1d, @@ -37,25 +38,17 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: node_attributes = node.attributes kernel_size = node_attributes.get('kernel_shape', weights.shape[2:]) stride = node_attributes.get('strides', 1) - padding = node_attributes.get('pads', [0]*4) dilation = node_attributes.get('dilations', 1) groups = node_attributes.get('group', 1) + padding = onnx_padding_to_torch_padding( + node_attributes.get('pads', [0] * spatial_rank * 2), + node_attributes.get('auto_pad', 'NOTSET'), + ) + out_channels = weights.shape[0] in_channels = weights.shape[1]*groups - auto_pad = node_attributes.get('auto_pad', 'NOTSET') - if auto_pad == 'NOTSET': - half_len = len(padding) // 2 - if tuple(padding[:half_len]) != tuple(padding[half_len:]): - raise NotImplementedError(f'Only symmetric padding is implemented ({padding})') - - padding = padding[:half_len] - elif auto_pad in ('SAME_UPPER', 'SAME_LOWER', 'VALID'): - raise NotImplementedError(f'"{auto_pad}" auto_pad is not implemented') - else: - raise ValueError(f'Got unexpected auto_pad value "{auto_pad}"') - torch_module = conv_class( in_channels=in_channels, out_channels=out_channels, diff --git a/onnx2torch/node_converters/expand.py b/onnx2torch/node_converters/expand.py index b7272024..70d1818a 100644 --- a/onnx2torch/node_converters/expand.py +++ b/onnx2torch/node_converters/expand.py @@ -4,31 +4,26 @@ import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.common import SkipTorchTracing -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxExpand(nn.Module): - - @staticmethod - def _do_forward(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: - return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device) +class OnnxExpand(nn.Module, OnnxToTorchModuleWithCustomExport): def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: + output = input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(input_tensor, shape) - return _ExpandExportToOnnx.set_output_and_apply(output, input_tensor, shape) + return _ExpandExportToOnnx.set_output_and_apply(output, input_tensor, shape) - return self._do_forward(input_tensor, shape) + return output -class _ExpandExportToOnnx(CustomExportToOnnx): +class _ExpandExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: diff --git a/onnx2torch/node_converters/flatten.py b/onnx2torch/node_converters/flatten.py index d4141d96..229881cf 100644 --- a/onnx2torch/node_converters/flatten.py +++ b/onnx2torch/node_converters/flatten.py @@ -1,14 +1,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxFlatten(nn.Module): +class OnnxFlatten(nn.Module, OnnxToTorchModule): def __init__(self, axis: int = 1): super().__init__() diff --git a/onnx2torch/node_converters/functions.py b/onnx2torch/node_converters/functions.py new file mode 100644 index 00000000..60a362f3 --- /dev/null +++ b/onnx2torch/node_converters/functions.py @@ -0,0 +1,59 @@ +__all__ = ['OnnxFunction'] + +import torch +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node + +# Exporting from pytorch to onnx operators atanh, asinh, acosh, cosh, sinh are not supported +_TORCH_FUNCTION_FROM_ONNX_TYPE = { + 'Abs': torch.abs, + 'Acos': torch.acos, + 'Asin': torch.asin, + 'Atan': torch.atan, + 'Cos': torch.cos, + 'Exp': torch.exp, + 'Log': torch.log, + 'Sign': torch.sign, + 'Sin': torch.sin, + 'Tan': torch.tan, + 'Tanh': torch.tanh, +} + + +class OnnxFunction(nn.Module, OnnxToTorchModule): + + def __init__(self, function_type: str): + super().__init__() + self.function = _TORCH_FUNCTION_FROM_ONNX_TYPE[function_type] + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + return self.function(input_tensor) + + +@add_converter(operation_type='Abs', version=13) +@add_converter(operation_type='Abs', version=6) +@add_converter(operation_type='Acos', version=7) +@add_converter(operation_type='Asin', version=7) +@add_converter(operation_type='Atan', version=7) +@add_converter(operation_type='Cos', version=7) +@add_converter(operation_type='Exp', version=6) +@add_converter(operation_type='Exp', version=13) +@add_converter(operation_type='Log', version=13) +@add_converter(operation_type='Log', version=6) +@add_converter(operation_type='Sign', version=13) +@add_converter(operation_type='Sign', version=9) +@add_converter(operation_type='Sin', version=7) +@add_converter(operation_type='Tan', version=7) +@add_converter(operation_type='Tanh', version=13) +@add_converter(operation_type='Tanh', version=6) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + return OperationConverterResult( + torch_module=OnnxFunction(node.operation_type), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/onnx2torch/node_converters/gather.py b/onnx2torch/node_converters/gather.py index c1596876..37e5d326 100644 --- a/onnx2torch/node_converters/gather.py +++ b/onnx2torch/node_converters/gather.py @@ -1,24 +1,26 @@ __all__ = ['OnnxGather'] from typing import List -from typing import Optional from typing import Tuple from typing import Union import torch +import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxGather(nn.Module): +class OnnxGather(nn.Module, OnnxToTorchModuleWithCustomExport): """ONNX gather implementation (or numpy.take implementation)""" - def __init__(self, axis: Optional[int] = 0): + def __init__(self, axis: int = 0): super().__init__() self.axis = axis @@ -38,7 +40,18 @@ def forward(self, input_tensor: torch.Tensor, indices: torch.Tensor) -> torch.Te # But torch.take does not support different axis. So we make it by yourself # numpy.take is input_data[:, :, indices] where we pass NONE slices AXIS time slice_for_take = self.slice_from_axis(input_tensor, self.axis, indices) - return input_tensor[slice_for_take] + output = input_tensor[slice_for_take] + if torch.onnx.is_in_onnx_export(): + return _GatherExportToOnnx.set_output_and_apply(output, input_tensor, indices, self.axis) + + return output + + +class _GatherExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method + @staticmethod + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + input_tensor, indices, axis = args + return graph.op('Gather', input_tensor, indices, axis_i=axis, outputs=1) @add_converter(operation_type='Gather', version=1) diff --git a/onnx2torch/node_converters/gemm.py b/onnx2torch/node_converters/gemm.py index 11d0f643..e56d1421 100644 --- a/onnx2torch/node_converters/gemm.py +++ b/onnx2torch/node_converters/gemm.py @@ -4,14 +4,15 @@ import torch.nn.functional as F from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult -class OnnxGeneralLinear(nn.Linear): +class OnnxGeneralLinear(nn.Linear, OnnxToTorchModule): """General Linear layer with functionality of ONNX GEMM node. For additional info https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm diff --git a/onnx2torch/node_converters/global_average_pool.py b/onnx2torch/node_converters/global_average_pool.py index db3e8c75..df6ffb68 100644 --- a/onnx2torch/node_converters/global_average_pool.py +++ b/onnx2torch/node_converters/global_average_pool.py @@ -3,14 +3,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxGlobalAveragePool(nn.Module): +class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModule): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use x_dims = list(range(2, len(input_tensor.shape))) diff --git a/onnx2torch/node_converters/identity.py b/onnx2torch/node_converters/identity.py index f4173b26..2ab7ba01 100644 --- a/onnx2torch/node_converters/identity.py +++ b/onnx2torch/node_converters/identity.py @@ -1,14 +1,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxCopyIdentity(nn.Module): +class OnnxCopyIdentity(nn.Module, OnnxToTorchModule): def forward(self, x: torch.Tensor): return x.clone() diff --git a/onnx2torch/node_converters/logical.py b/onnx2torch/node_converters/logical.py index 25beba4b..00a9e5e1 100644 --- a/onnx2torch/node_converters/logical.py +++ b/onnx2torch/node_converters/logical.py @@ -6,14 +6,15 @@ import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import SkipTorchTracing -from onnx2torch.common import old_style_broadcast -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import old_style_broadcast +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport _TORCH_FUNCTION_FROM_ONNX_TYPE = { 'Or': torch.logical_or, @@ -22,28 +23,24 @@ } -class OnnxNot(nn.Module): - - def _do_forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - return torch.logical_not(input_tensor) +class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + output = torch.logical_not(input_tensor) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(input_tensor) - return _NotExportToOnnx.set_output_and_apply(output, input_tensor) + return _NotExportToOnnx.set_output_and_apply(output, input_tensor) - return self._do_forward(input_tensor) + return output -class _NotExportToOnnx(CustomExportToOnnx): +class _NotExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: return graph.op('Not', *args, outputs=1) -class OnnxLogical(nn.Module): +class OnnxLogical(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None): super().__init__() self.broadcast = broadcast diff --git a/onnx2torch/node_converters/matmul.py b/onnx2torch/node_converters/matmul.py new file mode 100644 index 00000000..1f0f1e8f --- /dev/null +++ b/onnx2torch/node_converters/matmul.py @@ -0,0 +1,27 @@ +__all__ = ['OnnxMatMul'] + +import torch +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node + + +class OnnxMatMul(nn.Module, OnnxToTorchModule): + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.matmul(x, y) + + +@add_converter(operation_type='MatMul', version=1) +@add_converter(operation_type='MatMul', version=9) +@add_converter(operation_type='MatMul', version=13) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + return OperationConverterResult( + torch_module=OnnxMatMul(), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/onnx2torch/node_converters/max_pool.py b/onnx2torch/node_converters/max_pool.py index 73abe56f..faae03c2 100644 --- a/onnx2torch/node_converters/max_pool.py +++ b/onnx2torch/node_converters/max_pool.py @@ -2,12 +2,13 @@ from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import get_shape_from_value_info -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_shape_from_value_info +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.common import onnx_padding_to_torch_padding _MAXPOOL_CLASS_FROM_SPATIAL_RANK = { 1: nn.MaxPool1d, @@ -27,31 +28,23 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: spatial_rank = len(input_shape) - 2 maxpool_class = _MAXPOOL_CLASS_FROM_SPATIAL_RANK.get(spatial_rank, None) if maxpool_class is None: - raise NotImplementedError(f'Convolution operation with spatial rank == {spatial_rank} is not implemented') + raise NotImplementedError(f'Max pool operation with spatial rank == {spatial_rank} is not implemented') node_attributes = node.attributes + # required + kernel_shape = node_attributes['kernel_shape'] + # optional ceil_mode = node_attributes.get('ceil_mode', 0) dilation = node_attributes.get('dilations', 1) - padding = node_attributes.get('pads', [0] * spatial_rank * 2) - kernel_shape = node_attributes.get('kernel_shape', None) strides = node_attributes.get('strides', 1) storage_order = node_attributes.get('storage_order', 0) if storage_order != 0: raise NotImplementedError('Only row major (0) order is supported.') - if kernel_shape is None: - raise RuntimeError('Kernel shape for MaxPool not specified. Kernel shape is mandatory parameters in onnx.') - auto_pad = node_attributes.get('auto_pad', 'NOTSET') - if auto_pad == 'NOTSET': - half_len = len(padding) // 2 - if tuple(padding[:half_len]) != tuple(padding[half_len:]): - raise NotImplementedError(f'Only symmetric padding is implemented ({padding})') - - padding = padding[:half_len] - elif auto_pad in ('SAME_UPPER', 'SAME_LOWER', 'VALID'): - raise NotImplementedError(f'"{auto_pad}" auto_pad is not implemented') - else: - raise ValueError(f'Got unexpected auto_pad value "{auto_pad}"') + padding = onnx_padding_to_torch_padding( + node_attributes.get('pads', [0] * spatial_rank * 2), + node_attributes.get('auto_pad', 'NOTSET'), + ) torch_module = maxpool_class( kernel_size=kernel_shape, diff --git a/onnx2torch/node_converters/nms.py b/onnx2torch/node_converters/nms.py index fac6637c..c9058119 100644 --- a/onnx2torch/node_converters/nms.py +++ b/onnx2torch/node_converters/nms.py @@ -7,16 +7,16 @@ import torchvision from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.common import SkipTorchTracing -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxNonMaxSuppression(nn.Module): +class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, center_point_box: bool = False): super().__init__() @@ -47,7 +47,11 @@ def _do_forward( filtered_batch_boxes = batch_boxes[confidence_indexes] if self.center_point_box: - filtered_batch_boxes = torchvision.ops.box_convert(filtered_batch_boxes, in_fmt='cxcywh', out_fmt='xyxy') + filtered_batch_boxes = torchvision.ops.box_convert( + filtered_batch_boxes, + in_fmt='cxcywh', + out_fmt='xyxy', + ) nms_indexes = torchvision.ops.nms( boxes=filtered_batch_boxes, @@ -62,6 +66,8 @@ def _do_forward( [batch_index, class_index, box_index] for box_index in indexes ) + if len(out) == 0: + return torch.empty([0, 3], dtype=torch.int64, device=boxes.device) return torch.tensor(out, dtype=torch.int64, device=boxes.device) @@ -73,31 +79,29 @@ def forward( iou_threshold: Optional[torch.Tensor] = None, score_threshold: Optional[torch.Tensor] = None, ) -> torch.Tensor: + output = self._do_forward(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) - - if max_output_boxes_per_class is None: - max_output_boxes_per_class = torch.tensor([0], dtype=torch.int64) - if iou_threshold is None: - iou_threshold = torch.tensor([0.0], dtype=torch.float32) - if score_threshold is None: - score_threshold = torch.tensor([0.0], dtype=torch.float32) - - return _NmsExportToOnnx.set_output_and_apply( - output, - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - int(self.center_point_box), - ) - - return self._do_forward(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) - - -class _NmsExportToOnnx(CustomExportToOnnx): + if max_output_boxes_per_class is None: + max_output_boxes_per_class = torch.tensor([0], dtype=torch.int64) + if iou_threshold is None: + iou_threshold = torch.tensor([0.0], dtype=torch.float32) + if score_threshold is None: + score_threshold = torch.tensor([0.0], dtype=torch.float32) + + return _NmsExportToOnnx.set_output_and_apply( + output, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + int(self.center_point_box), + ) + + return output + + +class _NmsExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: diff --git a/onnx2torch/node_converters/pow.py b/onnx2torch/node_converters/pow.py index f97d2228..243b455f 100644 --- a/onnx2torch/node_converters/pow.py +++ b/onnx2torch/node_converters/pow.py @@ -5,15 +5,16 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import old_style_broadcast -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import old_style_broadcast +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxPow(nn.Module): +class OnnxPow(nn.Module, OnnxToTorchModule): def __init__(self, broadcast: Optional[int] = None, axis: Optional[int] = None): super().__init__() self.axis = axis @@ -26,7 +27,7 @@ def forward(self, input_tensor: torch.Tensor, exponent: torch.Tensor) -> torch.T return torch.pow(input_tensor, exponent) -class OnnxSqrt(nn.Module): +class OnnxSqrt(nn.Module, OnnxToTorchModule): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return torch.sqrt(input_tensor) diff --git a/onnx2torch/node_converters/range.py b/onnx2torch/node_converters/range.py index 3489f9ea..5305a3a3 100644 --- a/onnx2torch/node_converters/range.py +++ b/onnx2torch/node_converters/range.py @@ -5,14 +5,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxRange(nn.Module): +class OnnxRange(nn.Module, OnnxToTorchModule): def __init__(self): super().__init__() diff --git a/onnx2torch/node_converters/reduce.py b/onnx2torch/node_converters/reduce.py index d6dff689..58dba30a 100644 --- a/onnx2torch/node_converters/reduce.py +++ b/onnx2torch/node_converters/reduce.py @@ -5,6 +5,7 @@ ] from functools import partial +from typing import Any from typing import List from typing import Optional from typing import Tuple @@ -14,15 +15,24 @@ import torch._C as torch_C from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import SkipTorchTracing -from onnx2torch.common import get_const_value -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_const_value +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport + + +@torch.fx.wrap +def _get_element(x: Union[List, Tuple], index: int = 0) -> Any: + if isinstance(x, (tuple, list)): + return x[index] + + return x def _initialize_none_dim(dim: Optional[Union[int, Tuple[int, ...]]], input_dim: int): @@ -73,7 +83,7 @@ def _sum_square( } -class OnnxReduceSumDynamicAxes(nn.Module): +class OnnxReduceSumDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, keepdims: int = 1, noop_with_empty_axes: int = 0): super().__init__() @@ -96,24 +106,23 @@ def _do_forward(self, input_tensor: torch.Tensor, axes: torch.Tensor) -> torch.T return torch.sum(input_tensor, dim=axes, keepdim=self.keepdims) def forward(self, input_tensor: torch.Tensor, axes: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self._do_forward(input_tensor, axes) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - args = [input_tensor, axes] - output = self._do_forward(*args) - if axes is None: - args.pop() - - return _ReduceSumExportToOnnx.set_output_and_apply( - output, - *args, - int(self.keepdims), - int(self.noop_with_empty_axes), - ) + args = [input_tensor] + if axes is not None: + args.append(axes) + + return _ReduceSumExportToOnnx.set_output_and_apply( + output, + *args, + int(self.keepdims), + int(self.noop_with_empty_axes), + ) - return self._do_forward(input_tensor, axes) + return output -class _ReduceSumExportToOnnx(CustomExportToOnnx): +class _ReduceSumExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: @@ -127,7 +136,7 @@ def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: ) -class OnnxReduceSumStaticAxes(nn.Module): +class OnnxReduceSumStaticAxes(nn.Module, OnnxToTorchModule): def __init__( self, @@ -152,11 +161,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return self.math_op_function(input_tensor) self.axes = list(range(input_tensor.dim())) - + return torch.sum(input_tensor, dim=self.axes, keepdim=self.keepdims) -class OnnxReduceStaticAxes(nn.Module): +class OnnxReduceStaticAxes(nn.Module, OnnxToTorchModule): def __init__( self, @@ -191,8 +200,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: dim=axis if self.keepdims else axis - passed_dims, keepdim=self.keepdims, ) - if isinstance(result, tuple): - result = result[0] + result = _get_element(result, 0) return result @@ -269,4 +277,3 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: torch_module=OnnxReduceSumDynamicAxes(keepdims=keepdims, noop_with_empty_axes=noop_with_empty_axes), onnx_mapping=onnx_mapping_from_node(node), ) - diff --git a/onnx2torch/node_converters/registry.py b/onnx2torch/node_converters/registry.py index 16265178..7ed8b7b0 100644 --- a/onnx2torch/node_converters/registry.py +++ b/onnx2torch/node_converters/registry.py @@ -4,9 +4,9 @@ from onnx import defs -from onnx2torch.common import OperationConverterResult from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult _LOGGER = logging.getLogger(__name__) _CONVERTER_REGISTRY = {} diff --git a/onnx2torch/node_converters/reshape.py b/onnx2torch/node_converters/reshape.py index 785a8ff3..0ada2a35 100644 --- a/onnx2torch/node_converters/reshape.py +++ b/onnx2torch/node_converters/reshape.py @@ -1,18 +1,22 @@ __all__ = ['OnnxReshape'] import torch +import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxReshape(nn.Module): +class OnnxReshape(nn.Module, OnnxToTorchModuleWithCustomExport): - def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use + @staticmethod + def _do_reshape(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: if torch.any(shape == 0): shape = [ input_tensor.shape[i] if dim_size == 0 else dim_size @@ -21,6 +25,21 @@ def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tens return torch.reshape(input_tensor, torch.Size(shape)) + def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use + output = self._do_reshape(input_tensor, shape) + + if torch.onnx.is_in_onnx_export(): + return _ReshapeExportToOnnx.set_output_and_apply(output, input_tensor, shape) + + return output + + +class _ReshapeExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method + + @staticmethod + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + return graph.op('Reshape', *args, outputs=1) + @add_converter(operation_type='Reshape', version=5) @add_converter(operation_type='Reshape', version=13) diff --git a/onnx2torch/node_converters/resize.py b/onnx2torch/node_converters/resize.py index 87375f8d..86ac2e90 100644 --- a/onnx2torch/node_converters/resize.py +++ b/onnx2torch/node_converters/resize.py @@ -6,13 +6,14 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -_TORCH_MODES = { +_MODES_MAPPING = { ('nearest', 1): 'nearest', ('nearest', 2): 'nearest', ('nearest', 3): 'nearest', @@ -30,15 +31,15 @@ def _get_torch_align_corners(mode: str, coordinate_transformation_mode: str) -> return coordinate_transformation_mode == 'align_corners' -def _dimension_mode(mode: str, dim_size: int) -> str: - torch_mode = _TORCH_MODES.get((mode, dim_size), None) +def _onnx_mode_to_torch_mode(onnx_mode: str, dim_size: int) -> str: + torch_mode = _MODES_MAPPING.get((onnx_mode, dim_size), None) if torch_mode is None: - raise NotImplementedError(f'{dim_size}D input is not implemented for "{mode}" mode.') + raise NotImplementedError(f'{dim_size}D input is not implemented for "{onnx_mode}" mode.') return torch_mode -class OnnxResize(nn.Module): +class OnnxResize(nn.Module, OnnxToTorchModule): def __init__( self, @@ -46,7 +47,7 @@ def __init__( align_corners: Optional[bool] = None, ): super().__init__() - self.mode = mode + self.onnx_mode = mode self.align_corners = align_corners def forward( @@ -56,36 +57,41 @@ def forward( scales: Optional[torch.Tensor] = None, sizes: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self.mode = _dimension_mode(self.mode, input_tensor.dim() - 2) - if roi is not None: + torch_mode = _onnx_mode_to_torch_mode(self.onnx_mode, input_tensor.dim() - 2) + if roi is not None and roi.nelement() != 0: raise NotImplementedError('roi logic is not implemented.') # Format of onnx scales and sizes is [n, c, d, h, w] # But in torch only [d, h, w] (without batch and channel dimensions) - input_shape = list(input_tensor.shape) if sizes is not None: - sizes = sizes.tolist() - if input_shape[:2] != sizes[:2]: - raise NotImplementedError('Pytorch\'s interpolate cannot resize channel or batch dimensions.') - sizes = sizes[2:] - elif scales is not None: - scales = scales.tolist() - if scales[:2] != [1, 1]: - raise NotImplementedError('Pytorch\'s interpolate cannot scale channel or batch dimensions.') - scales = scales[2:] - else: - raise ValueError('One of scales or sizes should be defined.') + if sizes.nelement() != 0: + sizes = sizes.tolist() + input_shape = list(input_tensor.shape) + if input_shape[:2] != sizes[:2]: + raise NotImplementedError('Pytorch\'s interpolate cannot resize channel or batch dimensions.') + sizes = sizes[2:] + else: + sizes = None + + if scales is not None: + if scales.nelement() != 0: + scales = scales.tolist() + if scales[:2] != [1, 1]: + raise NotImplementedError('Pytorch\'s interpolate cannot scale channel or batch dimensions.') + scales = scales[2:] + else: + scales = None return torch.nn.functional.interpolate( input_tensor, size=sizes, scale_factor=scales, - mode=self.mode, + mode=torch_mode, align_corners=self.align_corners, ) -class OnnxResizeV10(nn.Module): +class OnnxResizeV10(nn.Module, OnnxToTorchModule): def __init__(self, mode: str = 'nearest'): super().__init__() diff --git a/onnx2torch/node_converters/roialign.py b/onnx2torch/node_converters/roialign.py new file mode 100644 index 00000000..91bcdfcb --- /dev/null +++ b/onnx2torch/node_converters/roialign.py @@ -0,0 +1,70 @@ +__all__ = ['OnnxRoiAlign'] + +import torch +from torch import nn +from torchvision.ops import roi_align + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node + + +class OnnxRoiAlign(nn.Module, OnnxToTorchModule): + + def __init__( + self, + mode: str = 'avg', + output_height: int = 1, + output_width: int = 1, + sampling_ratio: int = 0, + spatial_scale: float = 1.0, + ): + super().__init__() + + if mode != 'avg': + raise NotImplementedError(f'"{mode}" roi align mode is not implemented.') + + self._output_size = (output_height, output_width) + self._sampling_ratio = sampling_ratio + self._spatial_scale = spatial_scale + + def forward( + self, + input_tensor: torch.Tensor, + rois: torch.Tensor, + batch_indices: torch.Tensor, + ) -> torch.Tensor: + batched_rois = torch.concat([batch_indices.unsqueeze(1).to(rois.dtype), rois], dim=1) + + return roi_align( + input=input_tensor, + boxes=batched_rois, + output_size=self._output_size, + spatial_scale=self._spatial_scale, + sampling_ratio=self._sampling_ratio, + aligned=False, + ) + + +@add_converter(operation_type='RoiAlign', version=10) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + node_attributes = node.attributes + mode = node_attributes.get('mode', 'avg') + output_height = node_attributes.get('output_height', 1) + output_width = node_attributes.get('output_width', 1) + sampling_ratio = node_attributes.get('sampling_ratio', 0) + spatial_scale = node_attributes.get('spatial_scale', 1.0) + + return OperationConverterResult( + torch_module=OnnxRoiAlign( + mode=mode, + output_height=output_height, + output_width=output_width, + sampling_ratio=sampling_ratio, + spatial_scale=spatial_scale, + ), + onnx_mapping=onnx_mapping_from_node(node), + ) diff --git a/onnx2torch/node_converters/roundings.py b/onnx2torch/node_converters/roundings.py new file mode 100644 index 00000000..344016aa --- /dev/null +++ b/onnx2torch/node_converters/roundings.py @@ -0,0 +1,39 @@ +__all__ = ['OnnxRound'] + +import torch +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node + +_TORCH_ROUND_FROM_ONNX_TYPE = { + 'Ceil': torch.ceil, + 'Floor': torch.floor, + 'Round': torch.round, +} + + +class OnnxRound(nn.Module, OnnxToTorchModule): + + def __init__(self, round_type: str): + super().__init__() + self.round_function = _TORCH_ROUND_FROM_ONNX_TYPE[round_type] + + def forward(self, input_tensor: torch.Tensor): + return self.round_function(input_tensor) + + +@add_converter(operation_type='Ceil', version=13) +@add_converter(operation_type='Ceil', version=6) +@add_converter(operation_type='Floor', version=13) +@add_converter(operation_type='Floor', version=6) +@add_converter(operation_type='Round', version=11) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + return OperationConverterResult( + torch_module=OnnxRound(node.operation_type), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/onnx2torch/node_converters/scatter_nd.py b/onnx2torch/node_converters/scatter_nd.py index 5b54a511..cf6b2f86 100644 --- a/onnx2torch/node_converters/scatter_nd.py +++ b/onnx2torch/node_converters/scatter_nd.py @@ -1,43 +1,39 @@ __all__ = ['OnnxScatterND'] -from typing import Optional - -import numpy as np import torch import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.common import SkipTorchTracing -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxScatterND(nn.Module): +class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport): - def _do_forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: # There is no scatter nd for torch, use following formula: # https://github.com/onnx/onnx/blob/master/docs/Operators.md#ScatterND output = data.clone() - update_indices = indices.shape[:-1] - for idx in np.ndindex(update_indices): - output[indices[idx]] = updates[idx] - - return output - def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(data, indices, updates) - return _ScatterNDExportToOnnx.set_output_and_apply(output, data, indices, updates) + return _ScatterNDExportToOnnx.set_output_and_apply(output, data, indices, updates) + + ind_dim = indices.dim() + # last dimension is a partial-index into data + indices = indices.reshape((-1, indices.shape[-1])).T.tolist() + # update.shape = indices.shape[0:ind_dim-1] ++ data.shape[indices.shape[-1]:data.dim()-1] + updates = updates.reshape((-1, *updates.shape[ind_dim - 1:])) + output[indices] = updates - return self._do_forward(data, indices, updates) + return output -class _ScatterNDExportToOnnx(CustomExportToOnnx): +class _ScatterNDExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: return graph.op('ScatterND', *args, outputs=1) diff --git a/onnx2torch/node_converters/shape.py b/onnx2torch/node_converters/shape.py index 7fd5d3ae..8fb887bb 100644 --- a/onnx2torch/node_converters/shape.py +++ b/onnx2torch/node_converters/shape.py @@ -3,16 +3,19 @@ from typing import Optional import torch +import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxShape(nn.Module): +class OnnxShape(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, start: Optional[int] = None, end: Optional[int] = None): super().__init__() @@ -20,10 +23,37 @@ def __init__(self, start: Optional[int] = None, end: Optional[int] = None): self.end = end def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - return torch.tensor( + output = torch.tensor( input_tensor.shape[self.start:self.end], device=input_tensor.device, ) + if torch.onnx.is_in_onnx_export(): + args = [ + input_tensor, + ] + if self.start is not None: + args.append(self.start) + if self.end is not None: + args.append(self.end) + elif self.end is not None: + args += [0, self.end] + + return _ShapeExportToOnnx.set_output_and_apply(output, *args) + + return output + + +class _ShapeExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method + + @staticmethod + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + if len(args) == 2: + return graph.op('Shape', args[0], start_i=args[1], outputs=1) + + if len(args) == 3: + return graph.op('Shape', args[0], start_i=args[1], end_i=args[2], outputs=1) + + return graph.op('Shape', *args, outputs=1) @add_converter(operation_type='Shape', version=1) diff --git a/onnx2torch/node_converters/slice.py b/onnx2torch/node_converters/slice.py index b0f234e7..227769d6 100644 --- a/onnx2torch/node_converters/slice.py +++ b/onnx2torch/node_converters/slice.py @@ -7,13 +7,17 @@ import numpy as np import torch +import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport def _get_slices( @@ -63,7 +67,7 @@ def _do_slice(x: torch.Tensor, flip_dims: List, pos_axes_slices: List, neg_axes_ return x -class OnnxSliceV9(nn.Module): +class OnnxSliceV9(nn.Module, OnnxToTorchModule): def __init__(self, starts: np.ndarray, ends: np.ndarray, axes: Optional[np.ndarray] = None): super().__init__() @@ -73,7 +77,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return _do_slice(input_tensor, self.flip_dims, self.pos_axes_slices, self.neg_axes_slices) -class OnnxSlice(nn.Module): +class OnnxSlice(nn.Module, OnnxToTorchModuleWithCustomExport): def forward( self, @@ -84,7 +88,24 @@ def forward( steps: Optional[torch.Tensor] = None, ) -> torch.Tensor: flip_dims, pos_axes_slices, neg_axes_slices = _get_slices(starts, ends, axes, steps) - return _do_slice(input_tensor, flip_dims, pos_axes_slices, neg_axes_slices) + output = _do_slice(input_tensor, flip_dims, pos_axes_slices, neg_axes_slices) + if torch.onnx.is_in_onnx_export(): + args = [input_tensor, starts, ends] + if axes is not None: + args.append(axes) + if steps is not None: + args.append(steps) + + return _SliceExportToOnnx.set_output_and_apply(output, *args) + + return output + + +class _SliceExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method + + @staticmethod + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + return graph.op('Slice', *args, outputs=1) @add_converter(operation_type='Slice', version=9) diff --git a/onnx2torch/node_converters/split.py b/onnx2torch/node_converters/split.py new file mode 100644 index 00000000..9d305c49 --- /dev/null +++ b/onnx2torch/node_converters/split.py @@ -0,0 +1,75 @@ +__all__ = ['OnnxSplit', 'OnnxSplit13'] + +from typing import List +from typing import Optional + +import torch +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node + + +class OnnxSplit13(nn.Module, OnnxToTorchModule): + def __init__(self, num_splits: int, axis: int = 0): + super().__init__() + + self.axis = axis + self.num_splits = num_splits + + def forward( + self, + input_tensor: torch.Tensor, + split: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if split is None: + axis_len = input_tensor.shape[self.axis] + split_size_or_sections = axis_len // self.num_splits + else: + split_size_or_sections = split.tolist() + + return torch.split(input_tensor, split_size_or_sections, dim=self.axis) + + +class OnnxSplit(nn.Module, OnnxToTorchModule): + def __init__(self, num_splits: int, axis: int = 0, split: Optional[List[int]] = None): + super().__init__() + + self.axis = axis + self.num_splits = num_splits + self.split = split + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + if self.split is None: + axis_len = input_tensor.shape[self.axis] + split_size_or_sections = axis_len // self.num_splits + else: + split_size_or_sections = self.split + + return torch.split(input_tensor, split_size_or_sections, dim=self.axis) + + +@add_converter(operation_type='Split', version=13) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + axis = node.attributes.get('axis', 0) + num_splits = len(node.output_values) + return OperationConverterResult( + torch_module=OnnxSplit13(axis=axis, num_splits=num_splits), + onnx_mapping=onnx_mapping_from_node(node=node), + ) + + +@add_converter(operation_type='Split', version=11) +@add_converter(operation_type='Split', version=2) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + axis = node.attributes.get('axis', 0) + split = node.attributes.get('split', None) + num_splits = len(node.output_values) + return OperationConverterResult( + torch_module=OnnxSplit(axis=axis, split=split, num_splits=num_splits), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/onnx2torch/node_converters/squeeze.py b/onnx2torch/node_converters/squeeze.py index c67a5056..7cb97136 100644 --- a/onnx2torch/node_converters/squeeze.py +++ b/onnx2torch/node_converters/squeeze.py @@ -10,18 +10,19 @@ import torch._C as torch_C from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import SkipTorchTracing -from onnx2torch.common import get_const_value -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_const_value +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxSqueezeStaticAxes(nn.Module): +class OnnxSqueezeStaticAxes(nn.Module, OnnxToTorchModule): def __init__(self, axes: Optional[List[int]] = None): super().__init__() @@ -41,34 +42,35 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return result -class OnnxSqueezeDynamicAxes(nn.Module): +class OnnxSqueezeDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport): - def _do_forward(self, input_tensor: torch.Tensor, axes: Optional[torch.Tensor]) -> torch.Tensor: + @staticmethod + def _do_forward(input_tensor: torch.Tensor, axes: Optional[torch.Tensor]) -> torch.Tensor: if axes is None or axes.nelement() == 0: return torch.squeeze(input_tensor) + result = input_tensor for axes_id in torch.sort(axes, descending=True).values: - input_tensor = torch.squeeze(input_tensor, dim=axes_id) + result = torch.squeeze(result, dim=axes_id) - return input_tensor + return result def forward(self, input_tensor: torch.Tensor, axes: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self._do_forward(input_tensor, axes) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - args = [input_tensor, axes] - output = self._do_forward(*args) - if axes is None: - args.pop() - return _SqueezeExportToOnnx.set_output_and_apply(output, *args) - - return self._do_forward(input_tensor, axes) + args = [input_tensor] + if axes is not None: + args.append(axes) + + return _SqueezeExportToOnnx.set_output_and_apply(output, *args) + + return output -class _SqueezeExportToOnnx(CustomExportToOnnx): +class _SqueezeExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: - print(graph.__dir__()) return graph.op('Squeeze', *args, outputs=1) diff --git a/onnx2torch/node_converters/tile.py b/onnx2torch/node_converters/tile.py index 86189ea0..05c9e079 100644 --- a/onnx2torch/node_converters/tile.py +++ b/onnx2torch/node_converters/tile.py @@ -1,20 +1,34 @@ __all__ = ['OnnxTile'] import torch +import torch._C as torch_C from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxTile(nn.Module): +class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport): def forward(self, input_tensor: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: # torch.tile(input_tensor, repeats) is not supported for exporting - return input_tensor.repeat(torch.Size(repeats)) + output = input_tensor.repeat(torch.Size(repeats)) + if torch.onnx.is_in_onnx_export(): + return _TileExportToOnnx.set_output_and_apply(output, input_tensor, repeats) + + return output + + +class _TileExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method + + @staticmethod + def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: + return graph.op('Tile', *args, outputs=1) @add_converter(operation_type='Tile', version=6) diff --git a/onnx2torch/node_converters/topk.py b/onnx2torch/node_converters/topk.py index 8371c5e6..d32e03e1 100644 --- a/onnx2torch/node_converters/topk.py +++ b/onnx2torch/node_converters/topk.py @@ -6,14 +6,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxTopK(nn.Module): +class OnnxTopK(nn.Module, OnnxToTorchModule): def __init__(self, dim: int = -1, largest: int = 1, sorted_: int = 1): super().__init__() diff --git a/onnx2torch/node_converters/transpose.py b/onnx2torch/node_converters/transpose.py index e38a716f..61859b70 100644 --- a/onnx2torch/node_converters/transpose.py +++ b/onnx2torch/node_converters/transpose.py @@ -6,14 +6,15 @@ import torch from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult -class OnnxTranspose(nn.Module): +class OnnxTranspose(nn.Module, OnnxToTorchModule): def __init__(self, perm: Optional[List[int]] = None): super().__init__() @@ -22,6 +23,7 @@ def __init__(self, perm: Optional[List[int]] = None): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: if self.perm is None: self.perm = list(range(input_tensor.dim()))[::-1] + return input_tensor.permute(self.perm) diff --git a/onnx2torch/node_converters/unsqueeze.py b/onnx2torch/node_converters/unsqueeze.py index 861b8ff8..ce3acabb 100644 --- a/onnx2torch/node_converters/unsqueeze.py +++ b/onnx2torch/node_converters/unsqueeze.py @@ -10,18 +10,19 @@ import torch._C as torch_C from torch import nn -from onnx2torch.common import OnnxMapping -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import SkipTorchTracing -from onnx2torch.common import get_const_value -from onnx2torch.common import onnx_mapping_from_node -from onnx2torch.custom_export_to_onnx import CustomExportToOnnx from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_const_value +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxUnsqueezeStaticAxes(nn.Module): +class OnnxUnsqueezeStaticAxes(nn.Module, OnnxToTorchModule): def __init__(self, axes: Optional[List[int]] = None): super().__init__() @@ -35,9 +36,10 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return result -class OnnxUnsqueezeDynamicAxes(nn.Module): +class OnnxUnsqueezeDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport): - def _do_forward(self, input_tensor: torch.Tensor, axes: torch.Tensor) -> torch.Tensor: + @staticmethod + def _do_forward(input_tensor: torch.Tensor, axes: torch.Tensor) -> torch.Tensor: result = input_tensor for axes_id in torch.sort(axes).values: result = torch.unsqueeze(result, dim=axes_id) @@ -45,19 +47,17 @@ def _do_forward(self, input_tensor: torch.Tensor, axes: torch.Tensor) -> torch.T return result def forward(self, input_tensor: torch.Tensor, axes: torch.Tensor) -> torch.Tensor: + output = self._do_forward(input_tensor, axes) if torch.onnx.is_in_onnx_export(): - with SkipTorchTracing(): - output = self._do_forward(input_tensor, axes) - return _UnsqueezeExportToOnnx.set_output_and_apply(output, input_tensor, axes) + return _UnsqueezeExportToOnnx.set_output_and_apply(output, input_tensor, axes) - return self._do_forward(input_tensor, axes) + return output -class _UnsqueezeExportToOnnx(CustomExportToOnnx): +class _UnsqueezeExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method @staticmethod def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: - print(graph.__dir__()) return graph.op('Unsqueeze', *args, outputs=1) diff --git a/onnx2torch/node_converters/where.py b/onnx2torch/node_converters/where.py index b04dfc87..0a6a0c42 100644 --- a/onnx2torch/node_converters/where.py +++ b/onnx2torch/node_converters/where.py @@ -3,14 +3,15 @@ import torch from torch import nn -from onnx2torch.common import OperationConverterResult -from onnx2torch.common import onnx_mapping_from_node from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node -class OnnxWhere(nn.Module): +class OnnxWhere(nn.Module, OnnxToTorchModule): def forward(self, condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.where(condition, x, y) diff --git a/onnx2torch/utils/__init__.py b/onnx2torch/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/onnx2torch/common.py b/onnx2torch/utils/common.py similarity index 72% rename from onnx2torch/common.py rename to onnx2torch/utils/common.py index d337c676..aed54dbf 100644 --- a/onnx2torch/common.py +++ b/onnx2torch/utils/common.py @@ -2,19 +2,22 @@ from typing import NamedTuple from typing import Tuple from typing import Union -from warnings import catch_warnings -from warnings import filterwarnings import torch -import torch._C as torch_C from onnx import ValueInfoProto from torch import nn -from torch.jit import TracerWarning from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode +class OnnxToTorchModule: + """ + Marker class for onnx2torch modules. + """ + pass + + class OnnxMapping(NamedTuple): inputs: Tuple[str, ...] outputs: Tuple[str, ...] @@ -58,21 +61,6 @@ def get_const_value(name: str, graph: OnnxGraph) -> Union[torch.Tensor, float, i raise KeyError(f'Tensor "{name}" is not found in constant values') -class SkipTorchTracing: - def __init__(self): - self._catch_warnings = catch_warnings() - self._state = None - - def __enter__(self): - self._state = torch_C._get_tracing_state() - self._catch_warnings.__enter__() - filterwarnings(action='ignore', category=TracerWarning) - - def __exit__(self, exc_type, exc_val, exc_tb): - torch_C._set_tracing_state(self._state) - self._catch_warnings.__exit__(exc_type, exc_val, exc_tb) - - def old_style_broadcast(first: torch.Tensor, second: torch.Tensor, axis: int) -> torch.Tensor: rank = len(first.shape) axis = axis + rank if axis < 0 else axis @@ -81,3 +69,18 @@ def old_style_broadcast(first: torch.Tensor, second: torch.Tensor, axis: int) -> second_shape = second_shape + [1]*(rank - len(second_shape)) return second.view(second_shape) + + +def onnx_padding_to_torch_padding(padding: Tuple[int, ...], auto_pad: str) -> Tuple[int, ...]: + if auto_pad == 'NOTSET': + half_len = len(padding) // 2 + if padding[:half_len] != padding[half_len:]: + raise NotImplementedError(f'Only symmetric padding is implemented ({padding})') + + padding = padding[:half_len] + elif auto_pad in ('SAME_UPPER', 'SAME_LOWER', 'VALID'): + raise NotImplementedError(f'"{auto_pad}" auto_pad is not implemented') + else: + raise ValueError(f'Got unexpected auto_pad value "{auto_pad}"') + + return padding diff --git a/onnx2torch/custom_export_to_onnx.py b/onnx2torch/utils/custom_export_to_onnx.py similarity index 70% rename from onnx2torch/custom_export_to_onnx.py rename to onnx2torch/utils/custom_export_to_onnx.py index cab65e43..959d29fc 100644 --- a/onnx2torch/custom_export_to_onnx.py +++ b/onnx2torch/utils/custom_export_to_onnx.py @@ -1,10 +1,22 @@ -__all__ = ['CustomExportToOnnx'] +__all__ = [ + 'CustomExportToOnnx', + 'OnnxToTorchModuleWithCustomExport', +] from typing import Any import torch from torch import _C as torch_C +from onnx2torch.utils.common import OnnxToTorchModule + + +class OnnxToTorchModuleWithCustomExport(OnnxToTorchModule): + """ + Marker class for onnx2torch modules with custom export to onnx. + """ + pass + class CustomExportToOnnx(torch.autograd.Function): _NEXT_OUTPUT = None diff --git a/tests/models/models_test.py b/tests/models/models_test.py index b6cd218a..22a8dc04 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -3,10 +3,12 @@ import numpy as np import onnx import pytest +import torchvision from PIL import Image from onnx import version_converter -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model +from tests.utils.common import check_torch_model from tests.utils.resources import get_minimal_dataset_path from tests.utils.resources import get_model_path @@ -14,12 +16,15 @@ _COCO_STD = np.array([0.225, 0.224, 0.229], dtype=np.float32) -def create_test_batch(n: int = 32, target_size: Tuple[int, int] = (224, 224)) -> np.ndarray: +def create_test_batch( + bs: int = 32, + target_size: Tuple[int, int] = (224, 224), +) -> np.ndarray: minimal_dataset_path = get_minimal_dataset_path() batch = [] for i, image_path in enumerate(minimal_dataset_path.glob('*.jpg')): - if i >= n: + if i >= bs: break image = Image.open(image_path).convert('RGB') @@ -45,7 +50,7 @@ def test_resnet50(): input_name: np.random.randn(1, 3, 224, 224).astype(dtype=np.float32) } - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10 ** -5, @@ -55,16 +60,116 @@ def test_resnet50(): @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') -def test_ssdlite() -> None: - model_path = get_model_path('ssdlite') +@pytest.mark.parametrize( + 'model,resolution', + ( + ('retinanet', (604, 604)), + ('ssd300_vgg', (604, 604)), + ('ssdlite', (224, 224)), + ('yolov3_d53', (604, 604)), + ('yolov5_ultralitics', (672, 256)), + ('deeplabv3_mnv3_large', (320, 320)), + ('deeplabv3_plus_resnet101', (486, 500)), + ('hrnet', (321, 321)), + ('unet', (320, 320)), + ), +) +def test_onnx_models(model: str, resolution: Tuple[int, int]) -> None: + model_path = get_model_path(model) model = onnx.load_model(str(model_path.resolve())) input_name = model.graph.input[0].name test_inputs = { - input_name: create_test_batch(n=32), + input_name: create_test_batch(bs=1, target_size=resolution), } - check_model( + check_onnx_model( + model, + test_inputs, + atol_onnx_torch=10 ** -3, + atol_torch_cpu_cuda=10 ** -3, + atol_onnx_torch2onnx=10 ** -3, + ) + + +@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') +@pytest.mark.parametrize( + 'model', + ( + 'resnet18', + 'resnet50', + 'mobilenet_v2', + 'mobilenet_v3_large', + 'efficientnet_b0', + 'efficientnet_b1', + 'efficientnet_b2', + 'efficientnet_b3', + 'wide_resnet50_2', + 'resnext50_32x4d', + 'vgg16', + 'googlenet', + 'mnasnet1_0', + 'regnet_y_400mf', + 'regnet_y_16gf', + ) +) +def test_torchvision_classification(model: str) -> None: + torch_model = getattr(torchvision.models, model)(pretrained=True) + test_inputs = { + 'inputs': create_test_batch(bs=32), + } + + check_torch_model( + torch_model, + test_inputs, + atol_onnx_torch=10 ** -4, + atol_torch_cpu_cuda=10 ** -4, + atol_onnx_torch2onnx=10 ** -4, + ) + + +@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') +@pytest.mark.parametrize( + 'model', + ( + 'fcn_resnet50', + 'deeplabv3_resnet50', + 'lraspp_mobilenet_v3_large', + ) +) +def test_torchvision_segmentation(model: str) -> None: + torch_model = getattr(torchvision.models.segmentation, model)(pretrained=True) + test_inputs = { + 'inputs': create_test_batch(bs=8), + } + + check_torch_model( + torch_model, + test_inputs, + atol_onnx_torch=10 ** -3, + atol_torch_cpu_cuda=10 ** -3, + atol_onnx_torch2onnx=10 ** -3, + ) + + +@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') +@pytest.mark.parametrize( + 'model', + ( + 'vit', + 'swin', + ) +) +def test_transformer_models(model: str) -> None: + model_path = get_model_path(model) + model = onnx.load_model(str(model_path.resolve())) + + input_name = model.graph.input[0].name + test_inputs = { + input_name: create_test_batch(bs=8, target_size=(224, 224)), + } + + check_onnx_model( model, test_inputs, atol_onnx_torch=10 ** -4, diff --git a/tests/node_converters/activations_test.py b/tests/node_converters/activations_test.py index 56d1ef87..64632ff2 100644 --- a/tests/node_converters/activations_test.py +++ b/tests/node_converters/activations_test.py @@ -5,7 +5,7 @@ import onnx import pytest -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -19,17 +19,18 @@ def _test_activation(activation: str, data: np.ndarray, opset_version, **kwargs) opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.parametrize( 'activation,input_shape', ( - ('Relu', [8, 3, 32, 32]), - ('Exp', [8, 3, 32, 32]), - ('Sigmoid', [8, 3, 32, 32]), + ('Erf', [8, 3, 32, 32]), ('HardSigmoid', [8, 3, 32, 32]), ('LeakyRelu', [8, 3, 32, 32]), + ('LogSoftmax', [8, 3, 32, 32]), + ('Relu', [8, 3, 32, 32]), + ('Sigmoid', [8, 3, 32, 32]), ), ) def test_common_activations(activation: str, input_shape: List[int]) -> None: @@ -54,9 +55,10 @@ def test_common_activations(activation: str, input_shape: List[int]) -> None: ([8, 3, 32, 32], -1, 13), ), ) -def test_softmax(input_shape: List[int], axis: Optional[int], opset_version: int) -> None: +@pytest.mark.parametrize('activation', ('Softmax', 'LogSoftmax')) +def test_softmax(activation: str, input_shape: List[int], axis: Optional[int], opset_version: int) -> None: data = np.random.randn(*input_shape).astype(np.float32) if axis is None: - _test_activation('Softmax', data=data, opset_version=opset_version) + _test_activation(activation, data=data, opset_version=opset_version) else: - _test_activation('Softmax', data=data, opset_version=opset_version, axis=axis) + _test_activation(activation, data=data, opset_version=opset_version, axis=axis) diff --git a/tests/node_converters/average_pool_max_pool_test.py b/tests/node_converters/average_pool_max_pool_test.py new file mode 100644 index 00000000..063280ca --- /dev/null +++ b/tests/node_converters/average_pool_max_pool_test.py @@ -0,0 +1,71 @@ +from typing import Dict +from typing import List + +import numpy as np +import onnx +import pytest + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def _test_pool_op( + op_type, + input_shape: List[int], + atol_onnx_torch: float = 0.0, + **kwargs, +) -> None: + x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) + test_inputs = {'x': x} + + node = onnx.helper.make_node( + op_type, + inputs=['x'], + outputs=['y'], + **kwargs, + ) + model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) + check_onnx_model( + model, + test_inputs, + atol_onnx_torch=atol_onnx_torch, + atol_torch_cpu_cuda=0, + atol_onnx_torch2onnx=0, + ) + + +@pytest.mark.parametrize( + 'op', + ( + 'MaxPool', + 'AveragePool', + ) +) +@pytest.mark.parametrize( + 'input_shape,kernel_shape,optional_attrs', + ( + # 1d + ([2, 3, 16], [2], {}), + ([2, 3, 16], [1], {}), + ([2, 3, 16], [3], {}), + ([2, 3, 16], [2], {'strides': [3]}), + ([2, 3, 16], [2], {'ceil_mode': 1}), + # 2d + ([2, 3, 16, 16], [2, 2], {}), + ([2, 3, 16, 16], [1, 2], {}), + ([2, 3, 16, 16], [3, 2], {}), + ([2, 3, 16, 16], [2, 2], {'strides': [2, 3]}), + ([2, 3, 16, 16], [2, 2], {'ceil_mode': 1}), + # 3d + ([2, 3, 16, 16, 16], [2, 2, 2], {}), + ([2, 3, 16, 16, 16], [1, 2, 3], {}), + ([2, 3, 16, 16, 16], [3, 2, 1], {}), + ([2, 3, 16, 16, 16], [2, 2, 2], {'strides': [1, 2, 3]}), + ([2, 3, 16, 16, 16], [2, 2, 2], {'ceil_mode': 1}), + ), +) +def test_max_pool_average_pool(op: str, input_shape: List[int], kernel_shape: List[int], optional_attrs: Dict) -> None: + if op == 'AveragePool': + optional_attrs['atol_onnx_torch'] = 10**-7 + + _test_pool_op(op, input_shape=input_shape, kernel_shape=kernel_shape, **optional_attrs) diff --git a/tests/node_converters/batch_norm_test.py b/tests/node_converters/batch_norm_test.py index be7348b3..4f67dc19 100644 --- a/tests/node_converters/batch_norm_test.py +++ b/tests/node_converters/batch_norm_test.py @@ -1,7 +1,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -32,7 +32,7 @@ def _test_batch_norm( ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10**-7, diff --git a/tests/node_converters/binary_operations_test.py b/tests/node_converters/binary_operations_test.py index 02617a17..48825ce9 100644 --- a/tests/node_converters/binary_operations_test.py +++ b/tests/node_converters/binary_operations_test.py @@ -2,7 +2,7 @@ import onnx import pytest -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -29,4 +29,4 @@ def test_math_binary_operation(op_type: str) -> None: ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) diff --git a/tests/node_converters/clip_test.py b/tests/node_converters/clip_test.py index 0d7c5d7f..1766163e 100644 --- a/tests/node_converters/clip_test.py +++ b/tests/node_converters/clip_test.py @@ -4,7 +4,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -32,7 +32,7 @@ def _test_clip( **kwargs, ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def _test_clip_opset9( @@ -49,7 +49,7 @@ def _test_clip_opset9( **kwargs, ) model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, opset_version=9) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_clip() -> None: diff --git a/tests/node_converters/comparisons_test.py b/tests/node_converters/comparisons_test.py index 8832a793..20961a20 100644 --- a/tests/node_converters/comparisons_test.py +++ b/tests/node_converters/comparisons_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -33,7 +33,7 @@ def _test_comparison(op_type: str, a: np.ndarray, b: np.ndarray, opset_version: outputs_info=outputs_info, opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.parametrize( diff --git a/tests/node_converters/concat_test.py b/tests/node_converters/concat_test.py index 6fcb1c58..5baf537a 100644 --- a/tests/node_converters/concat_test.py +++ b/tests/node_converters/concat_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -37,7 +37,7 @@ def _test_concat( outputs_info=outputs_info, opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_concat() -> None: diff --git a/tests/node_converters/constant_of_shape_test.py b/tests/node_converters/constant_of_shape_test.py index d468a0f2..fdaea80c 100644 --- a/tests/node_converters/constant_of_shape_test.py +++ b/tests/node_converters/constant_of_shape_test.py @@ -7,7 +7,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -30,7 +30,7 @@ def _test_constant_of_shape(shape: np.ndarray, value: np.ndarray) -> None: inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/constant_test.py b/tests/node_converters/constant_test.py index ea196805..0cf4bc7e 100644 --- a/tests/node_converters/constant_test.py +++ b/tests/node_converters/constant_test.py @@ -7,7 +7,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -28,7 +28,7 @@ def _test_constant_as_tensor(shape: Tuple[int, ...], dtype: np.dtype) -> None: inputs_example={}, outputs_info=outputs_info, ) - check_model(model, onnx_inputs={}) + check_onnx_model(model, onnx_inputs={}) @pytest.mark.filterwarnings('ignore:No input args') diff --git a/tests/node_converters/conv_test.py b/tests/node_converters/conv_test.py index a248cd4f..a5e38f26 100644 --- a/tests/node_converters/conv_test.py +++ b/tests/node_converters/conv_test.py @@ -5,7 +5,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -34,7 +34,7 @@ def _test_conv( ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10**-4, diff --git a/tests/node_converters/expand_test.py b/tests/node_converters/expand_test.py index 3b381e83..a1643b28 100644 --- a/tests/node_converters/expand_test.py +++ b/tests/node_converters/expand_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -24,7 +24,7 @@ def _test_expand( make_tensor_value_info( name='y', elem_type=NP_TYPE_TO_TENSOR_TYPE[data.dtype], - shape=[], + shape=[None] * len(shape), ), ] @@ -34,7 +34,7 @@ def _test_expand( inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.parametrize( diff --git a/tests/node_converters/flatten_test.py b/tests/node_converters/flatten_test.py index a72d0c84..bdf799c0 100644 --- a/tests/node_converters/flatten_test.py +++ b/tests/node_converters/flatten_test.py @@ -3,7 +3,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -22,7 +22,7 @@ def _test_flatten( **kwargs, ) model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_flatten() -> None: diff --git a/tests/node_converters/gather_test.py b/tests/node_converters/gather_test.py index 95d1a329..42afa36c 100644 --- a/tests/node_converters/gather_test.py +++ b/tests/node_converters/gather_test.py @@ -1,7 +1,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -29,7 +29,7 @@ def _test_gather( inputs_example=test_inputs, opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_gather() -> None: diff --git a/tests/node_converters/gemm_test.py b/tests/node_converters/gemm_test.py index 45be67dc..b291ed1f 100644 --- a/tests/node_converters/gemm_test.py +++ b/tests/node_converters/gemm_test.py @@ -3,7 +3,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -33,7 +33,7 @@ def _test_gemm( **kwargs, ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10**-5, diff --git a/tests/node_converters/global_avg_pool_test.py b/tests/node_converters/global_avg_pool_test.py index 7b2017d8..253dc211 100644 --- a/tests/node_converters/global_avg_pool_test.py +++ b/tests/node_converters/global_avg_pool_test.py @@ -3,7 +3,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -22,7 +22,7 @@ def _test_global_avg_pool( **kwargs, ) model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10**-7, diff --git a/tests/node_converters/logical_test.py b/tests/node_converters/logical_test.py index f163cf51..84b70dbf 100644 --- a/tests/node_converters/logical_test.py +++ b/tests/node_converters/logical_test.py @@ -2,7 +2,7 @@ import onnx import pytest -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -33,7 +33,7 @@ def test_logical_operation(op_type: str) -> None: initializers=initializers, inputs_example=test_inputs, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_not() -> None: @@ -54,4 +54,4 @@ def test_not() -> None: ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) diff --git a/tests/node_converters/matmul_test.py b/tests/node_converters/matmul_test.py new file mode 100644 index 00000000..b1cab558 --- /dev/null +++ b/tests/node_converters/matmul_test.py @@ -0,0 +1,41 @@ +import numpy as np +import onnx + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def test_matmul() -> None: + a_variants = [ + np.random.randn(3, 4).astype(np.float32), + np.random.randn(2, 3, 4).astype(np.float32), + np.random.randn(1, 2, 3, 4).astype(np.float32), + ] + + b_variants = [ + np.random.randn(4, 3).astype(np.float32), + np.random.randn(2, 4, 3).astype(np.float32), + np.random.randn(1, 2, 4, 3).astype(np.float32), + ] + + for a, b in zip(a_variants, b_variants): + test_inputs = {'a': a, 'b': b} + initializers = {} + node = onnx.helper.make_node( + op_type='MatMul', + inputs=['a', 'b'], + outputs=['z'], + ) + + model = make_model_from_nodes( + nodes=node, + initializers=initializers, + inputs_example=test_inputs, + ) + check_onnx_model( + model, + test_inputs, + atol_onnx_torch=10 ** -6, + atol_torch_cpu_cuda=10 ** -6, + atol_onnx_torch2onnx=10 ** -6, + ) diff --git a/tests/node_converters/max_pool_test.py b/tests/node_converters/max_pool_test.py deleted file mode 100644 index 68ff571c..00000000 --- a/tests/node_converters/max_pool_test.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List - -import numpy as np -import onnx - -from tests.utils.common import check_model -from tests.utils.common import make_model_from_nodes - - -def _test_max_pool( - input_shape: List[int], - **kwargs, -) -> None: - - x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) - test_inputs = {'x': x} - - node = onnx.helper.make_node( - 'MaxPool', - inputs=['x'], - outputs=['y'], - **kwargs, - ) - model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) - check_model(model, test_inputs) - - -def test_max_pool() -> None: - _test_max_pool(input_shape=[2, 3, 16, 16], kernel_shape=[2, 2], strides=[2, 2]) - _test_max_pool(input_shape=[2, 3, 16, 16, 16], kernel_shape=[2, 2, 2], strides=[2, 2, 2]) - _test_max_pool(input_shape=[2, 3, 16, 16], kernel_shape=[2, 2], strides=[2, 2], ceil_mode=1) diff --git a/tests/node_converters/nms_test.py b/tests/node_converters/nms_test.py index 05c09509..a8bd7127 100644 --- a/tests/node_converters/nms_test.py +++ b/tests/node_converters/nms_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -63,7 +63,7 @@ def _test_nms( inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) _BOXES = np.array([[ diff --git a/tests/node_converters/pow_test.py b/tests/node_converters/pow_test.py index 54f32d69..f5dca03a 100644 --- a/tests/node_converters/pow_test.py +++ b/tests/node_converters/pow_test.py @@ -1,7 +1,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -35,7 +35,7 @@ def test_pow() -> None: initializers=initializers, inputs_example=test_inputs, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_sqrt() -> None: @@ -51,4 +51,4 @@ def test_sqrt() -> None: ) model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) diff --git a/tests/node_converters/range_test.py b/tests/node_converters/range_test.py index 011ca193..1f3591d7 100644 --- a/tests/node_converters/range_test.py +++ b/tests/node_converters/range_test.py @@ -4,7 +4,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -30,7 +30,7 @@ def _test_range( inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/reduce_test.py b/tests/node_converters/reduce_test.py index dca295c6..c761a5a2 100644 --- a/tests/node_converters/reduce_test.py +++ b/tests/node_converters/reduce_test.py @@ -8,7 +8,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -25,7 +25,7 @@ def _test_reduce(input_tensor: np.ndarray, op_type: str, tol: float, **kwargs) - initializers={}, inputs_example=test_inputs, ) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=tol, @@ -80,7 +80,7 @@ def _test_reduce_sum( shape=output_shape, ),), ) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10 ** -5, diff --git a/tests/node_converters/reshape_test.py b/tests/node_converters/reshape_test.py index 222cc744..93bbdcec 100644 --- a/tests/node_converters/reshape_test.py +++ b/tests/node_converters/reshape_test.py @@ -4,7 +4,7 @@ import onnx import pytest -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -29,7 +29,7 @@ def _test_reshape( inputs_example=test_inputs, opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/resize_test.py b/tests/node_converters/resize_test.py index 4c5a047a..88361915 100644 --- a/tests/node_converters/resize_test.py +++ b/tests/node_converters/resize_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -51,7 +51,7 @@ def _test_resize( outputs_info=outputs_info, opset_version=13, ) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10 ** -6, @@ -84,7 +84,7 @@ def _test_resize_v10( outputs_info=outputs_info, opset_version=10, ) - check_model( + check_onnx_model( model, test_inputs, atol_onnx_torch=10 ** -7, diff --git a/tests/node_converters/roialign_test.py b/tests/node_converters/roialign_test.py new file mode 100644 index 00000000..15f8e3b2 --- /dev/null +++ b/tests/node_converters/roialign_test.py @@ -0,0 +1,110 @@ +from typing import List + +import numpy as np +import onnx +import pytest +from onnx.helper import make_tensor_value_info +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def get_roi_align_input_values(): # type: ignore + x = np.array( + [ + [ + [ + [ + 0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250, + ], + [ + 0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467, + ], + [ + 0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162, + ], + [ + 0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799, + ], + [ + 0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119, + ], + [ + 0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119, + ], + [ + 0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689, + ], + [ + 0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928, + ], + [ + 0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514, + ], + [ + 0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502, + ], + ] + ] + ], + dtype=np.float32, + ) + batch_indices = np.array([0, 0, 0], dtype=np.int64) + rois = np.array([[0, 0, 9, 9], [0, 5, 4, 9], [5, 5, 9, 9]], dtype=np.float32) + return x, batch_indices, rois + + +def _test_roi( + input_tensor: np.ndarray, + rois: np.ndarray, + batch_indices: np.ndarray, + **kwargs, +) -> None: + test_inputs = {'X': input_tensor, 'rois': rois, 'batch_indices': batch_indices} + + node = onnx.helper.make_node( + op_type='RoiAlign', + inputs=list(test_inputs), + outputs=['y'], + **kwargs, + ) + onnx_type = NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')] + outputs_info = [make_tensor_value_info(name='y', elem_type=onnx_type, shape=None)] + model = make_model_from_nodes( + nodes=node, + initializers={}, + inputs_example=test_inputs, + outputs_info=outputs_info, + ) + check_onnx_model(model, test_inputs) + + +@pytest.mark.parametrize( + 'spatial_scale,sampling_ratio,output_height,output_width', + ( + (1.0, 2, 5, 5), + (0.25, 0, 7, 7), + (0.125, 0, 7, 7), + (0.6, 0, 1, 1), + (None, None, None, None), + ) +) +@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') +def test_roi(spatial_scale: float, sampling_ratio: int, output_height: int, output_width:int) -> None: + x, batch_indices, rois = get_roi_align_input_values() + kwargs = {} + if spatial_scale is not None: + kwargs['spatial_scale'] = spatial_scale + if sampling_ratio is not None: + kwargs['sampling_ratio'] = sampling_ratio + if output_height is not None: + kwargs['output_height'] = output_height + if output_width is not None: + kwargs['output_width'] = output_width + _test_roi( + input_tensor=x, + rois=rois, + batch_indices=batch_indices, + **kwargs, + ) diff --git a/tests/node_converters/scatter_nd_test.py b/tests/node_converters/scatter_nd_test.py index 33cd5330..ca4e3615 100644 --- a/tests/node_converters/scatter_nd_test.py +++ b/tests/node_converters/scatter_nd_test.py @@ -1,7 +1,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -21,7 +21,7 @@ def _test_scatter_nd( ) model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_scatter_nd() -> None: @@ -29,14 +29,34 @@ def test_scatter_nd() -> None: [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], - [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]] + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], + ], dtype=np.float32) + + indices = np.array([[0, 1, 2], [1, 2, 3]], dtype=np.int64) + updates = np.array([1232, 5463], dtype=np.float32) + _test_scatter_nd( + data=data, + indices=indices, + updates=updates, + ) + + indices = np.array([[0, 1], [1, 2]], dtype=np.int64) + updates = np.array([ + [8, 7, 6, 5], + [4, 3, 2, 1], ], dtype=np.float32) + + _test_scatter_nd( + data=data, + indices=indices, + updates=updates, + ) + indices = np.array([[0], [2]], dtype=np.int64) updates = np.array([ [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], - [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]] + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], ], dtype=np.float32) - _test_scatter_nd( data=data, indices=indices, diff --git a/tests/node_converters/shape_test.py b/tests/node_converters/shape_test.py index d81ed437..cc1e1c2e 100644 --- a/tests/node_converters/shape_test.py +++ b/tests/node_converters/shape_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -33,7 +33,7 @@ def _test_shape( outputs_info=outputs_info, opset_version=opset_version, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/slice_test.py b/tests/node_converters/slice_test.py index d7923c9d..cb1e21cf 100644 --- a/tests/node_converters/slice_test.py +++ b/tests/node_converters/slice_test.py @@ -6,7 +6,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -44,7 +44,7 @@ def _test_slice( inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/split_test.py b/tests/node_converters/split_test.py new file mode 100644 index 00000000..7ac846af --- /dev/null +++ b/tests/node_converters/split_test.py @@ -0,0 +1,90 @@ +from typing import List +from typing import Optional + +import numpy as np +import onnx +import pytest +from onnx.helper import make_tensor_value_info +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def _test_split( + x: np.ndarray, + expected_output: List[np.ndarray], + opset_version: int, + **kwargs, +) -> None: + inputs = ['x', ] + test_inputs = {'x': x} + + if opset_version >= 13 and kwargs.get('split') is not None: + split = kwargs.pop('split') + test_inputs['split'] = split + inputs.append('split') + + node = onnx.helper.make_node( + op_type='Split', + inputs=inputs, + outputs=[f'output_{i}' for i, _ in enumerate(expected_output)], + **kwargs, + ) + + outputs_info = [ + make_tensor_value_info( + name=f'output_{i}', + elem_type=NP_TYPE_TO_TENSOR_TYPE[out.dtype], + shape=out.shape, + ) + for i, out in enumerate(expected_output) + ] + + model = make_model_from_nodes( + nodes=node, + initializers={}, + inputs_example=test_inputs, + outputs_info=outputs_info, + opset_version=opset_version, + ) + check_onnx_model(model, test_inputs) + + +INPUT_1D = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32) +INPUT_2D = np.array([ + [1., 2., 3., 4., 5., 6.], + [7., 8., 9., 10., 11., 12.] +]).astype(np.float32) + +EMPTY_INPUT = np.array([]).astype(np.float32) +EXPECTED_EMPTY_OUT = [np.array([]).astype(np.float32), np.array([]).astype(np.float32), np.array([]).astype(np.float32)] + + +@pytest.mark.parametrize( + 'input_array,expected_out,axis,split', + ( + (INPUT_1D, np.split(INPUT_1D, 3), None, None), + (INPUT_1D, np.split(INPUT_1D, 3), 0, None), + (INPUT_1D, np.split(INPUT_1D, [2]), None, np.array([2, 4]).astype(np.int64)), + (INPUT_1D, np.split(INPUT_1D, [2]), 0, np.array([2, 4]).astype(np.int64)), + (INPUT_2D, np.split(INPUT_2D, 2, axis=1), 1, None), + (INPUT_2D, np.split(INPUT_2D, [2], axis=1), 1, np.array([2, 4]).astype(np.int64)), + (EMPTY_INPUT, EXPECTED_EMPTY_OUT, None, np.array([0, 0, 0]).astype(np.int64)) + ), +) +@pytest.mark.parametrize('opset_version', (13, 11, 2)) +def test_split( + input_array: np.ndarray, + expected_out: List[np.ndarray], + axis: Optional[int], + split: Optional[np.ndarray], + opset_version: int, +) -> None: + kwargs = {} + if axis is not None: + kwargs['axis'] = axis + if split is not None: + kwargs['split'] = split + + _test_split(input_array, expected_out, opset_version=opset_version, **kwargs) diff --git a/tests/node_converters/squeeze_test.py b/tests/node_converters/squeeze_test.py index 6f6b0f96..6801a3b0 100644 --- a/tests/node_converters/squeeze_test.py +++ b/tests/node_converters/squeeze_test.py @@ -9,7 +9,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -53,7 +53,7 @@ def _test_squeeze( shape=output_shape, ),), ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/test_functions.py b/tests/node_converters/test_functions.py new file mode 100644 index 00000000..b3780575 --- /dev/null +++ b/tests/node_converters/test_functions.py @@ -0,0 +1,81 @@ +from typing import List + +import numpy as np +import onnx +import pytest + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def _test_functions(function: str, data: np.ndarray, opset_version, **kwargs) -> None: + test_inputs = {'input_tensor': data} + + node = onnx.helper.make_node(op_type=function, inputs=['input_tensor'], outputs=['y'], **kwargs) + model = make_model_from_nodes( + nodes=node, initializers={}, + inputs_example=test_inputs, + opset_version=opset_version, + ) + + check_onnx_model(model, test_inputs) + + +@pytest.mark.parametrize( + 'function,input_shape', + ( + ('Ceil', [8, 3, 32, 32]), + ('Floor', [8, 3, 32, 32]), + ('Round', [8, 3, 32, 32]), + ), +) +def test_roundings(function: str, input_shape: List[int]) -> None: + data = np.random.randn(*input_shape).astype(np.float32) + _test_functions(function, data=data, opset_version=11) + + +@pytest.mark.parametrize( + 'function,input_shape', + ( + ('Abs', [8, 3, 32, 32]), + ('Cos', [8, 3, 32, 32]), + ('Exp', [8, 3, 32, 32]), + ('Log', [8, 3, 32, 32]), + ('Sign', [8, 3, 32, 32]), + ('Sin', [8, 3, 32, 32]), + ('Tan', [8, 3, 32, 32]) + ), +) +def test_common_functions(function: str, input_shape: List[int]) -> None: + data = np.random.randn(*input_shape).astype(np.float32) + if function == 'Log': + data[data <= 0] = 10**-4 + _test_functions(function, data=data, opset_version=11) + + +@pytest.mark.parametrize( + 'function,input_shape', + ( + ('Acos', [8, 3, 32, 32]), + ('Asin', [8, 3, 32, 32]), + ('Atan', [8, 3, 32, 32]), + ), +) +def test_arc_functions(function: str, input_shape: List[int]) -> None: + if function in ['Acos', 'Asin']: + data = np.random.uniform(-1, 1, input_shape).astype(np.float32) + else: + data = np.random.randn(*input_shape).astype(np.float32) + + _test_functions(function, data=data, opset_version=11) + + +@pytest.mark.parametrize( + 'function,input_shape', + ( + ('Tanh', [8, 3, 32, 32]), + ), +) +def test_hyperbolic_functions(function: str, input_shape: List[int]) -> None: + data = np.random.randn(*input_shape).astype(np.float32) + _test_functions(function, data=data, opset_version=11) diff --git a/tests/node_converters/tile_test.py b/tests/node_converters/tile_test.py index 466ad71a..f4b78734 100644 --- a/tests/node_converters/tile_test.py +++ b/tests/node_converters/tile_test.py @@ -4,7 +4,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -32,7 +32,7 @@ def _test_tile( inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) @pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') diff --git a/tests/node_converters/topk_test.py b/tests/node_converters/topk_test.py index a04fa0e5..53de112c 100644 --- a/tests/node_converters/topk_test.py +++ b/tests/node_converters/topk_test.py @@ -3,7 +3,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -26,7 +26,7 @@ def _test_topk(data: np.ndarray, k: np.ndarray, **kwargs) -> None: inputs_example=test_inputs, outputs_info=outputs_info, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_topk() -> None: diff --git a/tests/node_converters/transpose_test.py b/tests/node_converters/transpose_test.py index 193c2085..5d55840f 100644 --- a/tests/node_converters/transpose_test.py +++ b/tests/node_converters/transpose_test.py @@ -3,7 +3,7 @@ import numpy as np import onnx -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -20,7 +20,7 @@ def _test_transpose(data: np.ndarray, **kwargs) -> None: initializers={}, inputs_example=test_inputs, ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_transpose() -> None: diff --git a/tests/node_converters/unsqueeze_test.py b/tests/node_converters/unsqueeze_test.py index a0ddfbd8..61060ad4 100644 --- a/tests/node_converters/unsqueeze_test.py +++ b/tests/node_converters/unsqueeze_test.py @@ -8,7 +8,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -44,7 +44,7 @@ def _test_unsqueeze( shape=np.expand_dims(input_tensor, axis=axes).shape, ),), ) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) # Known warning. Shape Inference do not work properly in opset_version=9 and negative indices. diff --git a/tests/node_converters/where_test.py b/tests/node_converters/where_test.py index 3bc09d1c..eef0ffa6 100644 --- a/tests/node_converters/where_test.py +++ b/tests/node_converters/where_test.py @@ -3,7 +3,7 @@ from onnx.helper import make_tensor_value_info from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from tests.utils.common import check_model +from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -26,7 +26,7 @@ def where_test( ) ] model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, outputs_info=outputs_info) - check_model(model, test_inputs) + check_onnx_model(model, test_inputs) def test_where() -> None: diff --git a/tests/utils/common.py b/tests/utils/common.py index 494a36f6..b3970ab1 100644 --- a/tests/utils/common.py +++ b/tests/utils/common.py @@ -189,12 +189,13 @@ def convert_onnx2torch2onnx( return onnx.load_from_string(tmp_file.getvalue()) -def _check_model( +def _check_onnx_model( onnx_model: ModelProto, onnx_inputs: Dict[str, Any], onnx_torch_check_function: Callable, torch_cpu_cuda_check_function: Optional[Callable] = None, onnx_torch2onnx_check_function: Optional[Callable] = None, + opset_version: int = 13, ) -> None: ort_outputs = calc_ort_outputs(onnx_model, onnx_inputs) torch_outputs = calc_torch_outputs(onnx_model, onnx_inputs, device='cpu') @@ -206,17 +207,18 @@ def _check_model( torch_cpu_cuda_check_function(torch_outputs, torch_cuda_outputs) if onnx_torch2onnx_check_function is not None: - torch2onnx_model = convert_onnx2torch2onnx(onnx_model, inputs=onnx_inputs) + torch2onnx_model = convert_onnx2torch2onnx(onnx_model, inputs=onnx_inputs, opset_version=opset_version) ort_torch2onnx_outputs = calc_ort_outputs(torch2onnx_model, onnx_inputs, skip_unused_inputs=True) onnx_torch2onnx_check_function(ort_outputs, ort_torch2onnx_outputs) -def check_model( +def check_onnx_model( onnx_model: ModelProto, onnx_inputs: Dict[str, Any], atol_onnx_torch: float = 0.0, atol_torch_cpu_cuda: float = 0.0, atol_onnx_torch2onnx: float = 0.0, + opset_version: int = 13, ) -> None: def onnx_torch_check_function(onnx_output, torch_output): if len(onnx_output) == 1: @@ -226,7 +228,7 @@ def onnx_torch_check_function(onnx_output, torch_output): assert np.all(np.isclose(a, b, atol=atol_onnx_torch)), 'ort and torch outputs have significant difference' def torch_cpu_cuda_check_function(torch_cpu_output, torch_cuda_output): - if not isinstance(torch_cpu_output, List): + if not isinstance(torch_cpu_output, (List, Tuple)): torch_cpu_output = [torch_cpu_output] torch_cuda_output = [torch_cuda_output] @@ -243,10 +245,37 @@ def onnx_torch2onnx_check_function(onnx_output, torch2onnx_output): return True - _check_model( + _check_onnx_model( onnx_model=onnx_model, onnx_inputs=onnx_inputs, onnx_torch_check_function=onnx_torch_check_function, torch_cpu_cuda_check_function=torch_cpu_cuda_check_function, onnx_torch2onnx_check_function=onnx_torch2onnx_check_function, + opset_version=opset_version, ) + + +def check_torch_model( + torch_model: torch.nn.Module, + onnx_inputs: Dict[str, Any], + atol_onnx_torch: float = 0.0, + atol_torch_cpu_cuda: float = 0.0, + atol_onnx_torch2onnx: float = 0.0, + opset_version: int = 13, +) -> None: + arguments = locals() + input_names = list(onnx_inputs.keys()) + args = tuple(torch.tensor(arg) for arg in onnx_inputs.values()) + + with io.BytesIO() as tmp_file: + torch.onnx.export( + model=torch_model, + args=args, + f=tmp_file, + input_names=input_names, + opset_version=opset_version, + ) + + arguments.pop('torch_model') + arguments['onnx_model'] = onnx.load_from_string(tmp_file.getvalue()) + check_onnx_model(**arguments) diff --git a/tests/utils/resources.py b/tests/utils/resources.py index cc39c20a..3db3add3 100644 --- a/tests/utils/resources.py +++ b/tests/utils/resources.py @@ -1,18 +1,31 @@ import tarfile +import urllib.request from pathlib import Path -import requests from google_drive_downloader import GoogleDriveDownloader from tests import DATASETS_DIR from tests import MODELS_DIR -_ONNX_MODELS_URLS = { - 'resnet50': 'https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.onnx', -} +_BASE_URL = 'https://gitlab.expasoft.com/p.ivanov/onnx2torch_data/-/raw/main/models_for_tests' + +_CHKP_DETECTION_URL = f'{_BASE_URL}/detection' +_CHKP_SEGMENTATION_URL = f'{_BASE_URL}/segmentation' +_CHKP_TRANSFORMERS_URL = f'{_BASE_URL}/transformers' _ONNX_MODELS_IDS = { - 'ssdlite': '1b_daJsjdIeOOWUEKIfru_0hzyE49JMqf', + 'deeplabv3_mnv3_large': f'{_CHKP_SEGMENTATION_URL}/deeplabv3_mobilenet_v3_large.onnx', + 'deeplabv3_plus_resnet101': f'{_CHKP_SEGMENTATION_URL}/deeplabv3_resnet101_dimans.onnx', + 'hrnet': f'{_CHKP_SEGMENTATION_URL}/hrnet.onnx', + 'unet': f'{_CHKP_SEGMENTATION_URL}/unet_resnet34.onnx', + 'retinanet': f'{_CHKP_DETECTION_URL}/retinanet_r50_fpn.onnx', + 'ssd300_vgg': f'{_CHKP_DETECTION_URL}/ssd300.onnx', + 'ssdlite': f'{_CHKP_DETECTION_URL}/ssdlite.onnx', + 'yolov3_d53': f'{_CHKP_DETECTION_URL}/yolov3_d53_tuned_shape.onnx', + 'yolov5_ultralitics': f'{_CHKP_DETECTION_URL}/yolov5_ultralitics.onnx', + 'swin': f'{_CHKP_TRANSFORMERS_URL}/swin.onnx', + 'vit': f'{_CHKP_TRANSFORMERS_URL}/vit.onnx', + 'resnet50': 'https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx', } _MINIMAL_DATASETS_ID = '1Vd7qfQotrRADPLFxViA2tRpz7tBymR31' @@ -21,18 +34,9 @@ def get_model_path(name: str) -> Path: model_path = MODELS_DIR / f'{name}.onnx' if not model_path.exists(): - if name in _ONNX_MODELS_URLS: - url = _ONNX_MODELS_URLS[name] - with model_path.open(mode='wb') as model_file: - response = requests.get(url, stream=True) - for chunk in response.iter_content(chunk_size=4*1024): - model_file.write(chunk) - elif name in _ONNX_MODELS_IDS: - GoogleDriveDownloader.download_file_from_google_drive( - file_id=_ONNX_MODELS_IDS[name], - dest_path=model_path, - overwrite=True, - ) + if name in _ONNX_MODELS_IDS: + url = _ONNX_MODELS_IDS[name] + urllib.request.urlretrieve(url=url, filename=model_path) else: raise RuntimeError('Cannot find model path.')