From 0df1bf76fc133f12f33700e4ffd378a16349afe5 Mon Sep 17 00:00:00 2001 From: Arseny <82811840+senysenyseny16@users.noreply.github.com> Date: Fri, 7 Jul 2023 13:09:00 +0300 Subject: [PATCH] feat: LayerNormalization Co-authored-by: Mason Ma --- onnx2torch/node_converters/__init__.py | 1 + onnx2torch/node_converters/layer_norm.py | 78 ++++++++++++++++++++++++ operators.md | 1 + tests/node_converters/layer_norm_test.py | 76 +++++++++++++++++++++++ 4 files changed, 156 insertions(+) create mode 100644 onnx2torch/node_converters/layer_norm.py create mode 100644 tests/node_converters/layer_norm_test.py diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index dd40354a..c850d006 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -21,6 +21,7 @@ from onnx2torch.node_converters.global_average_pool import * from onnx2torch.node_converters.identity import * from onnx2torch.node_converters.instance_norm import * +from onnx2torch.node_converters.layer_norm import * from onnx2torch.node_converters.logical import * from onnx2torch.node_converters.lrn import * from onnx2torch.node_converters.matmul import * diff --git a/onnx2torch/node_converters/layer_norm.py b/onnx2torch/node_converters/layer_norm.py new file mode 100644 index 00000000..c31c7a2e --- /dev/null +++ b/onnx2torch/node_converters/layer_norm.py @@ -0,0 +1,78 @@ +__all__ = [ + 'OnnxLayerNorm', +] + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_shape_from_value_info +from onnx2torch.utils.common import onnx_mapping_from_node + +AXIS_DEFAULT_VALUE = -1 +EPSILON_DEFAULT_VALUE = 1e-5 + + +class OnnxLayerNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring + def __init__(self, axis: int, epsilon: float): + super().__init__() + self.axis = axis + self.epsilon = epsilon + + def forward( # pylint: disable=missing-function-docstring + self, + inputs: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + normalized_shape = inputs.shape[self.axis :] + return F.layer_norm( + input=inputs, + normalized_shape=normalized_shape, + weight=scale, + bias=bias, + eps=self.epsilon, + ) + + +@add_converter(operation_type='LayerNormalization', version=17) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + node_attributes = node.attributes + + axis = node_attributes.get('axis', AXIS_DEFAULT_VALUE) + epsilon = node_attributes.get('epsilon', EPSILON_DEFAULT_VALUE) + + if all(value_name in graph.initializers for value_name in node.input_values[1:]): + input_value_info = graph.value_info[node.input_values[0]] + input_shape = get_shape_from_value_info(input_value_info) + + torch_module = nn.LayerNorm( + normalized_shape=input_shape[axis:], + eps=epsilon, + elementwise_affine=True, + ) + + scale_value_name = node.input_values[1] + bias_value_name = node.input_values[2] if len(node.input_values) > 2 else None + + with torch.no_grad(): + torch_module.weight.data = graph.initializers[scale_value_name].to_torch() + if bias_value_name is not None: + torch_module.bias.data = 
graph.initializers[bias_value_name].to_torch() + + onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values) + else: + input_value_info = graph.value_info[node.input_values[0]] + input_shape = get_shape_from_value_info(input_value_info) + torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon) + onnx_mapping = onnx_mapping_from_node(node) + + return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping) diff --git a/operators.md b/operators.md index c0aa5768..7753ee34 100644 --- a/operators.md +++ b/operators.md @@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops | InstanceNormalization | Y | | | IsInf | N | | | IsNaN | N | | +| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented | | LRN | Y | | | LSTM | N | | | LeakyRelu | Y | | diff --git a/tests/node_converters/layer_norm_test.py b/tests/node_converters/layer_norm_test.py new file mode 100644 index 00000000..8341f7a6 --- /dev/null +++ b/tests/node_converters/layer_norm_test.py @@ -0,0 +1,76 @@ +# pylint: disable=missing-function-docstring +from typing import List +from typing import Optional + +import numpy as np +import onnx +import pytest + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +def _test_layer_norm( + x: np.ndarray, + scale: np.ndarray, + bias: Optional[np.ndarray], + axis: int, + parameters_as_inputs: bool, +) -> None: + inputs = {'input': x} + parameters = {'scale': scale} + if bias is not None: + parameters['bias'] = bias + + initializers = {} + + if parameters_as_inputs: + inputs.update(parameters) + else: + initializers.update(parameters) + + node = onnx.helper.make_node( + op_type='LayerNormalization', + inputs=['input', 'scale', 'bias'] if bias is not None else ['input', 'scale'], + outputs=['y'], + axis=axis, + ) + model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17) + check_onnx_model( + onnx_model=model, + onnx_inputs=inputs, + atol_onnx_torch=1e-5, + atol_torch_cpu_cuda=1e-5, + atol_onnx_torch2onnx=1e-5, + ) + + +@pytest.mark.parametrize('parameters_as_inputs', (True, False)) +@pytest.mark.parametrize( + 'input_shape', + ( + [2, 3, 16], + [3, 1, 224], + [4, 3, 16, 16], + [5, 1, 32, 32], + [6, 3, 16, 16, 8], + [7, 1, 7, 7, 16], + ), +) +def test_layer_norm(input_shape: List[int], parameters_as_inputs: bool) -> None: + x = np.random.randn(*input_shape).astype(np.float32) + + for axis in [*range(len(input_shape))] + [-1]: + normalized_shape = input_shape[axis:] + + scale = np.random.randn(*normalized_shape).astype(np.float32) + bias = np.random.randn(*normalized_shape).astype(np.float32) + + for bias_ in [bias, None]: + _test_layer_norm( + x=x, + scale=scale, + bias=bias_, + axis=axis, + parameters_as_inputs=parameters_as_inputs, + )
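Usage sketch (illustrative only, outside the diff): with this converter registered, an ONNX model that contains a LayerNormalization node (opset 17) can be converted through onnx2torch's public convert() entry point. The file name and input shape below are placeholders, not taken from the patch.

    import onnx
    import torch
    from onnx2torch import convert

    # Placeholder: any ONNX model (opset 17) that contains a LayerNormalization node.
    onnx_model = onnx.load('model_with_layer_norm.onnx')

    # convert() returns a torch.fx.GraphModule. Per this patch, LayerNormalization
    # becomes nn.LayerNorm when scale and bias are initializers, and OnnxLayerNorm otherwise.
    torch_model = convert(onnx_model)

    # Run the converted module on a dummy input; the shape is assumed for illustration.
    dummy_input = torch.randn(2, 3, 16)
    with torch.no_grad():
        output = torch_model(dummy_input)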