* add lrn 13 operation but no test yet
* switch lrn support to y
* add vs code settings to gitignore
* code linted and cleaned accordingly
* formatted with black
* init file (linted and formatted)
* lrn also safe for onnx opset 1
* update after PR comments
* don't check onnx model cuda if unavailable
* Revert "don't check onnx model cuda if unavailable" (this reverts commit 4f1ba13)
* refactor imports with isort
1 parent 5424c8c, commit 478ac2b
Showing 5 changed files with 65 additions and 1 deletion.
@@ -0,0 +1,23 @@
__all__ = []

from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


@add_converter(operation_type='LRN', version=13)
@add_converter(operation_type='LRN', version=1)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
    size = node.attributes.get('size')
    alpha = node.attributes.get('alpha', 0.0001)
    beta = node.attributes.get('beta', 0.75)
    k = node.attributes.get('bias', 1)  # pylint: disable=invalid-name

    return OperationConverterResult(
        torch_module=nn.LocalResponseNorm(size=size, alpha=alpha, beta=beta, k=k),
        onnx_mapping=onnx_mapping_from_node(node=node),
    )
@@ -0,0 +1,37 @@
from random import randrange

import numpy as np
import onnx

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_lrn(data: np.ndarray, alpha: float, beta: float, bias: float, size: int) -> None:
    test_inputs = {'input_tensor': data}
    node = onnx.helper.make_node(
        op_type='LRN',
        inputs=list(test_inputs),
        outputs=['y'],
        alpha=alpha,  # ONNX attributes are passed as regular keyword arguments.
        beta=beta,
        bias=bias,
        size=size,
    )

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
    )
    check_onnx_model(model, test_inputs)


def test_lrn() -> None:  # pylint: disable=missing-function-docstring
    shape = (1, 3, 227, 227)
    data = np.random.random_sample(shape).astype(np.float32)
    alpha = np.random.uniform(low=0.0, high=1.0)
    beta = np.random.uniform(low=0.0, high=1.0)
    bias = np.random.uniform(low=1.0, high=5.0)
    size = randrange(start=1, stop=10, step=2)  # diameter of channels, not radius; must be odd
    _test_lrn(data=data, alpha=alpha, beta=beta, bias=bias, size=size)
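To cross-check the attribute mapping by hand (in particular ONNX `bias` mapping to torch's `k`), a plain NumPy reference following the ONNX LRN definition can be used. The helper below is an illustrative sketch, not part of the test suite:

import math

import numpy as np


def lrn_reference(data: np.ndarray, size: int, alpha: float, beta: float, bias: float) -> np.ndarray:
    # data has layout (N, C, ...); LRN divides each value by a power of a sum of squares
    # taken over a window of `size` channels centered on its own channel.
    channels = data.shape[1]
    square_sum = np.zeros_like(data)
    for c in range(channels):
        lo = max(0, c - (size - 1) // 2)
        hi = min(channels - 1, c + math.ceil((size - 1) / 2))
        square_sum[:, c] = np.sum(np.square(data[:, lo:hi + 1]), axis=1)
    return data / np.power(bias + (alpha / size) * square_sum, beta)

For odd `size` this agrees with torch.nn.LocalResponseNorm(size, alpha=alpha, beta=beta, k=bias), which is why the test draws `size` from randrange(start=1, stop=10, step=2).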