feat: LayerNormalization
Co-authored-by: Mason Ma <[email protected]>
senysenyseny16 and JohnMasoner authored Jul 7, 2023
1 parent ef6f06a commit 0df1bf7
Showing 4 changed files with 156 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
@@ -21,6 +21,7 @@
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.instance_norm import *
from onnx2torch.node_converters.layer_norm import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.lrn import *
from onnx2torch.node_converters.matmul import *
78 changes: 78 additions & 0 deletions onnx2torch/node_converters/layer_norm.py
@@ -0,0 +1,78 @@
__all__ = [
'OnnxLayerNorm',
]

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_shape_from_value_info
from onnx2torch.utils.common import onnx_mapping_from_node

AXIS_DEFAULT_VALUE = -1
EPSILON_DEFAULT_VALUE = 1e-5


class OnnxLayerNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, axis: int, epsilon: float):
super().__init__()
self.axis = axis
self.epsilon = epsilon

def forward( # pylint: disable=missing-function-docstring
self,
inputs: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normalized_shape = inputs.shape[self.axis :]
return F.layer_norm(
input=inputs,
normalized_shape=normalized_shape,
weight=scale,
bias=bias,
eps=self.epsilon,
)


@add_converter(operation_type='LayerNormalization', version=17)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
node_attributes = node.attributes

axis = node_attributes.get('axis', AXIS_DEFAULT_VALUE)
epsilon = node_attributes.get('epsilon', EPSILON_DEFAULT_VALUE)

    # Scale (and bias, if present) are static initializers: bake them into nn.LayerNorm.
    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
input_value_info = graph.value_info[node.input_values[0]]
input_shape = get_shape_from_value_info(input_value_info)

torch_module = nn.LayerNorm(
normalized_shape=input_shape[axis:],
eps=epsilon,
elementwise_affine=True,
)

scale_value_name = node.input_values[1]
bias_value_name = node.input_values[2] if len(node.input_values) > 2 else None

with torch.no_grad():
torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
if bias_value_name is not None:
torch_module.bias.data = graph.initializers[bias_value_name].to_torch()

onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
else:
        # Scale/bias arrive as runtime tensors, so they cannot be baked in;
        # fall back to the dynamic OnnxLayerNorm module.
        torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon)
onnx_mapping = onnx_mapping_from_node(node)

return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
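The converter above picks one of two equivalent paths: a static nn.LayerNorm with baked-in parameters when scale/bias are initializers, and the dynamic OnnxLayerNorm otherwise. A minimal sketch (shapes and values are illustrative, not from this commit) checking that the two paths agree numerically:

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(2, 3, 16)
axis, eps = -1, 1e-5
scale = torch.randn(x.shape[axis:])
bias = torch.randn(x.shape[axis:])

# Static path: parameters copied into nn.LayerNorm at conversion time.
static = nn.LayerNorm(normalized_shape=x.shape[axis:], eps=eps)
with torch.no_grad():
    static.weight.copy_(scale)
    static.bias.copy_(bias)

# Dynamic path: scale/bias passed at call time, as in OnnxLayerNorm.forward.
dynamic = F.layer_norm(x, x.shape[axis:], weight=scale, bias=bias, eps=eps)

assert torch.allclose(static(x), dynamic, atol=1e-6)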
1 change: 1 addition & 0 deletions operators.md
@@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| InstanceNormalization | Y | |
| IsInf | N | |
| IsNaN | N | |
| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented |
| LRN | Y | |
| LSTM | N | |
| LeakyRelu | Y | |
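For context on the note in the table: in opset 17, ONNX LayerNormalization defines two optional extra outputs, Mean and InvStdDev, which this converter does not produce. A NumPy reference decomposition following the ONNX spec (a sketch for illustration, not part of this commit):

import numpy as np

def layer_norm_reference(x, scale, bias, axis=-1, epsilon=1e-5):
    # Normalization axes run from `axis` through the last dimension.
    axes = tuple(range(axis % x.ndim, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    deviation = x - mean
    variance = (deviation ** 2).mean(axis=axes, keepdims=True)
    inv_std_dev = 1.0 / np.sqrt(variance + epsilon)
    y = deviation * inv_std_dev * scale + bias
    return y, mean, inv_std_dev  # Y, Mean, InvStdDev per the ONNX spec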
76 changes: 76 additions & 0 deletions tests/node_converters/layer_norm_test.py
@@ -0,0 +1,76 @@
# pylint: disable=missing-function-docstring
from typing import List
from typing import Optional

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_layer_norm(
x: np.ndarray,
scale: np.ndarray,
bias: Optional[np.ndarray],
axis: int,
parameters_as_inputs: bool,
) -> None:
inputs = {'input': x}
parameters = {'scale': scale}
if bias is not None:
parameters['bias'] = bias

initializers = {}

if parameters_as_inputs:
inputs.update(parameters)
else:
initializers.update(parameters)

node = onnx.helper.make_node(
op_type='LayerNormalization',
inputs=['input', 'scale', 'bias'] if bias is not None else ['input', 'scale'],
outputs=['y'],
axis=axis,
)
model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17)
check_onnx_model(
onnx_model=model,
onnx_inputs=inputs,
atol_onnx_torch=1e-5,
atol_torch_cpu_cuda=1e-5,
atol_onnx_torch2onnx=1e-5,
)


@pytest.mark.parametrize('parameters_as_inputs', (True, False))
@pytest.mark.parametrize(
'input_shape',
(
[2, 3, 16],
[3, 1, 224],
[4, 3, 16, 16],
[5, 1, 32, 32],
[6, 3, 16, 16, 8],
[7, 1, 7, 7, 16],
),
)
def test_layer_norm(input_shape: List[int], parameters_as_inputs: bool) -> None:
x = np.random.randn(*input_shape).astype(np.float32)

for axis in [*range(len(input_shape))] + [-1]:
normalized_shape = input_shape[axis:]

scale = np.random.randn(*normalized_shape).astype(np.float32)
bias = np.random.randn(*normalized_shape).astype(np.float32)

for bias_ in [bias, None]:
_test_layer_norm(
x=x,
scale=scale,
bias=bias_,
axis=axis,
parameters_as_inputs=parameters_as_inputs,
)

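With this commit in place, a model containing LayerNormalization goes through the library's usual entry point. A usage sketch (the model path and input shape are illustrative):

import onnx
import torch
from onnx2torch import convert

onnx_model = onnx.load('model_with_layer_norm.onnx')  # illustrative path
torch_model = convert(onnx_model)

x = torch.randn(2, 3, 16)
with torch.no_grad():
    y = torch_model(x)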