Skip to content

Commit

Permalink
feat: InstanceNormalization
Browse files Browse the repository at this point in the history
feat: InstanceNormalization
  • Loading branch information
senysenyseny16 authored Feb 6, 2023
2 parents c2ba45a + 903f43f commit 0d4ec08
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 1 deletion.
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from onnx2torch.node_converters.gemm import *
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.instance_norm import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.lrn import *
from onnx2torch.node_converters.matmul import *
Expand Down
88 changes: 88 additions & 0 deletions onnx2torch/node_converters/instance_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Only the wrapper module class is part of the public API; the converter
# function below is registered via the ``add_converter`` decorator and
# does not need to be exported.
__all__ = [
    'OnnxInstanceNorm',
]

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_shape_from_value_info
from onnx2torch.utils.common import onnx_mapping_from_node

# Maps the number of spatial dimensions of the input (tensor rank minus the
# batch and channel axes) to the matching torch InstanceNorm class.
# NOTE(review): spatial rank 0 (an input of shape (N, C)) is also mapped to
# InstanceNorm1d — presumably deferring shape validation to torch; confirm.
_IN_CLASS_FROM_SPATIAL_RANK = {
    0: nn.InstanceNorm1d,
    1: nn.InstanceNorm1d,
    2: nn.InstanceNorm2d,
    3: nn.InstanceNorm3d,
}


class OnnxInstanceNorm(nn.Module, OnnxToTorchModule):
    """Instance normalization whose scale and bias arrive as runtime inputs.

    Used when the ONNX ``scale``/``bias`` values are not graph initializers
    and therefore cannot be baked into a concrete ``torch.nn.InstanceNormNd``
    module at conversion time.
    """

    def __init__(self, momentum: float, epsilon: float):
        super().__init__()
        # Stored verbatim and forwarded to F.instance_norm on every call.
        self.momentum = momentum
        self.epsilon = epsilon

    def forward(
        self,
        input_data: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``input_data`` per instance using batch statistics."""
        # No running_mean/running_var buffers are kept; statistics always
        # come from the current batch (use_input_stats=True), matching ONNX
        # InstanceNormalization semantics.
        return F.instance_norm(
            input_data,
            None,  # running_mean
            None,  # running_var
            weight,
            bias,
            True,  # use_input_stats
            self.momentum,
            self.epsilon,
        )


@add_converter(operation_type='InstanceNormalization', version=1)
@add_converter(operation_type='InstanceNormalization', version=6)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
    """Convert ONNX InstanceNormalization to a torch module.

    When both scale and bias are graph initializers, the operation is mapped
    onto the concrete ``torch.nn.InstanceNormNd`` class matching the input's
    spatial rank, with the parameters baked in. Otherwise a generic
    ``OnnxInstanceNorm`` wrapper is emitted that receives scale and bias as
    runtime inputs.
    """
    node_attributes = node.attributes
    epsilon = node_attributes.get('epsilon', 1e-5)  # ONNX default epsilon
    # No running statistics are tracked (track_running_stats=False), but the
    # torch modules require a momentum value; use torch's default.
    momentum = 0.1

    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
        input_value_info = graph.value_info[node.input_values[0]]
        input_shape = get_shape_from_value_info(input_value_info)
        # Spatial rank = tensor rank minus batch and channel dimensions.
        spatial_rank = len(input_shape) - 2
        try:
            in_class = _IN_CLASS_FROM_SPATIAL_RANK[spatial_rank]
        except KeyError as exc:
            raise NotImplementedError(
                f'InstanceNorm operation with spatial rank == {spatial_rank} is not implemented'
            ) from exc

        # Convert each initializer to a torch tensor exactly once and reuse it
        # (the original fetched/converted the scale initializer twice).
        scale = graph.initializers[node.input_values[1]].to_torch()
        bias = graph.initializers[node.input_values[2]].to_torch()

        torch_module = in_class(
            num_features=scale.size()[0],
            eps=epsilon,
            momentum=momentum,
            affine=True,
            track_running_stats=False,
        )
        with torch.no_grad():
            torch_module.weight.data = scale
            torch_module.bias.data = bias

        # Scale/bias are baked into the module, so only the data input remains.
        onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
    else:
        # Dynamic parameters: pass scale and bias through as runtime inputs.
        torch_module = OnnxInstanceNorm(momentum=momentum, epsilon=epsilon)
        onnx_mapping = onnx_mapping_from_node(node)

    return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
2 changes: 1 addition & 1 deletion operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| Hardmax | N | |
| Identity | Y | |
| If | N | |
| InstanceNormalization | N | |
| InstanceNormalization | Y | |
| IsInf | N | |
| IsNaN | N | |
| LRN | Y | |
Expand Down
47 changes: 47 additions & 0 deletions tests/node_converters/instance_norm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


@pytest.mark.parametrize('parameters_as_inputs', (True, False))
@pytest.mark.parametrize(
    'input_shape',
    (
        # 1d
        [2, 3, 16],
        [2, 1, 7],
        # 2d
        [2, 3, 16, 16],
        [2, 1, 7, 16],
        # 3d
        [2, 3, 16, 16, 16],
        [2, 1, 16, 7, 16],
    ),
)
def test_instance_norm(  # pylint: disable=missing-function-docstring
    input_shape: List[int],
    parameters_as_inputs: bool,
) -> None:
    channels = input_shape[1]
    data = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32)
    scale = np.random.uniform(low=0.0, high=1.0, size=channels).astype(np.float32)
    bias = np.random.uniform(low=-1.0, high=1.0, size=channels).astype(np.float32)

    inputs = {'input': data}
    initializers = {}
    # Exercise both converter paths: scale/bias supplied either as graph
    # inputs (dynamic parameters) or as initializers (static parameters).
    parameters = {'scale': scale, 'bias': bias}
    target = inputs if parameters_as_inputs else initializers
    target.update(parameters)

    node = onnx.helper.make_node(op_type='InstanceNormalization', inputs=['input', 'scale', 'bias'], outputs=['y'])

    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs)
    check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)

0 comments on commit 0d4ec08

Please sign in to comment.