Skip to content

Commit

Permalink
feat: add LRN
Browse files Browse the repository at this point in the history
* add lrn 13 operation but no test yet

* switch lrn support to y

* add vs code settings to gitignore

* code linted and cleaned accordingly

* formatted with black

* init file (linted and formatted)

* lrn also safe for onnx opset 1

* update after PR comments

* don't check onnx model cuda if unavailable

* Revert "don't check onnx model cuda if unavailable"

This reverts commit 4f1ba13.

* refactor imports with isort
  • Loading branch information
monicathieu authored Sep 9, 2022
1 parent 5424c8c commit 478ac2b
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ venv.bak/
# Rope project settings
.ropeproject

# VS Code project settings
.vscode

# mkdocs documentation
/site

Expand Down
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.lrn import *
from onnx2torch.node_converters.matmul import *
from onnx2torch.node_converters.max_pool import *
from onnx2torch.node_converters.mean import *
Expand Down
23 changes: 23 additions & 0 deletions onnx2torch/node_converters/lrn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
__all__ = []

from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


@add_converter(operation_type='LRN', version=13)
@add_converter(operation_type='LRN', version=1)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
size = node.attributes.get('size')
alpha = node.attributes.get('alpha', 0.0001)
beta = node.attributes.get('beta', 0.75)
k = node.attributes.get('bias', 1) # pylint: disable=invalid-name

return OperationConverterResult(
torch_module=nn.LocalResponseNorm(size=size, alpha=alpha, beta=beta, k=k),
onnx_mapping=onnx_mapping_from_node(node=node),
)
2 changes: 1 addition & 1 deletion operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 15, recommended ops
| InstanceNormalization | N | |
| IsInf | N | |
| IsNaN | N | |
| LRN | N | |
| LRN | Y | |
| LSTM | N | |
| LeakyRelu | Y | |
| Less | Y | |
Expand Down
37 changes: 37 additions & 0 deletions tests/node_converters/lrn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from random import randrange

import numpy as np
import onnx

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_lrn(data: np.ndarray, alpha: float, beta: float, bias: float, size: int) -> None:
test_inputs = {'input_tensor': data}
node = onnx.helper.make_node(
op_type='LRN',
inputs=list(test_inputs),
outputs=['y'],
alpha=alpha, # ONNX attributes are passed as regular keyword arguments.
beta=beta,
bias=bias,
size=size,
)

model = make_model_from_nodes(
nodes=node,
initializers={},
inputs_example=test_inputs,
)
check_onnx_model(model, test_inputs)


def test_lrn() -> None: # pylint: disable=missing-function-docstring
shape = (1, 3, 227, 227)
data = np.random.random_sample(shape).astype(np.float32)
alpha = np.random.uniform(low=0.0, high=1.0)
beta = np.random.uniform(low=0.0, high=1.0)
bias = np.random.uniform(low=1.0, high=5.0)
size = randrange(start=1, stop=10, step=2) # diameter of channels, not radius, must be odd
_test_lrn(data=data, alpha=alpha, beta=beta, bias=bias, size=size)

0 comments on commit 478ac2b

Please sign in to comment.