Skip to content

Commit

Permalink
Limit the ONNX & PT operations with bias
Browse files Browse the repository at this point in the history
  • Loading branch information
KodiaqQ committed Mar 6, 2024
1 parent c0fe3a7 commit bf6b06e
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 23 deletions.
2 changes: 0 additions & 2 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@
# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
onnx_metatypes.ONNXConvolutionTransposeMetatype,
onnx_metatypes.ONNXGemmMetatype,
onnx_metatypes.ONNXMatMulMetatype,
]
Expand Down
48 changes: 40 additions & 8 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import Dict, List, Optional, Type

import onnx
Expand Down Expand Up @@ -44,14 +45,15 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
@classmethod
def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Optional[Type[OperatorMetatype]]:
matches = []
for subtype in cls.get_subtypes():
subtypes_list = deque(cls.get_subtypes())
while subtypes_list:
subtype = subtypes_list.popleft()
if subtype.matches(model, node):
subtypes_list.extend(subtype.get_subtypes())
matches.append(subtype)
if len(matches) > 1:
raise nncf.InternalError("Multiple subtypes match operator call - cannot determine single subtype.")
if not matches:
return None
return matches[0]
return matches[-1]


class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
Expand Down Expand Up @@ -85,6 +87,22 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_depthwise_conv(model, node)


@ONNX_OPERATION_METATYPES.register()
class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "GroupConvOp"
op_names = ["Conv"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
weight_channel_axis = 0
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]

@classmethod
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_group_conv(node)


@ONNX_OPERATION_METATYPES.register()
class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "ConvOp"
Expand All @@ -94,7 +112,7 @@ class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]
subtypes = [ONNXGroupConvolutionMetatype]


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -699,6 +717,23 @@ def get_tensor_edge_name(
return None


def _is_group_conv(node: onnx.NodeProto) -> bool:
"""
Returns True if the convolution is group, False - otherwise.
Group convolution is a convolution with the group attribute.
:param node: Convolution node to check whether it is depthwise.
:return: True if the convolution is group, False - otherwise.
"""
conv_group = None
for attribute in node.attribute:
if attribute.name == "group":
conv_group = onnx.helper.get_attribute_value(attribute)
if conv_group is None or conv_group == 1:
return False
return True


def _is_depthwise_conv(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
"""
Returns True if the convolution is depthwise, False - otherwise.
Expand All @@ -711,12 +746,9 @@ def _is_depthwise_conv(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
:param node: Convolution node to check whether it is depthwise.
:return: True if the convolution is depthwise, False - otherwise.
"""
conv_group = None
for attribute in node.attribute:
if attribute.name == "group":
conv_group = onnx.helper.get_attribute_value(attribute)
if conv_group is None:
return False
weight_tensor_value = None
initializer_name = node.input[1]
for init in model.graph.initializer:
Expand Down
14 changes: 7 additions & 7 deletions nncf/openvino/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,17 @@
]


# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS = [
ov_metatypes.OVConvolutionMetatype,
ov_metatypes.OVMatMulMetatype,
]


CONV_OPERATIONS = [
ov_metatypes.OVConvolutionMetatype,
ov_metatypes.OVDepthwiseConvolutionMetatype,
ov_metatypes.OVGroupConvolutionMetatype,
ov_metatypes.OVConvolutionBackpropDataMetatype,
ov_metatypes.OVGroupConvolutionBackpropDataMetatype,
]


# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS = [
*CONV_OPERATIONS,
ov_metatypes.OVMatMulMetatype,
]
8 changes: 2 additions & 6 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,12 +1038,8 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
PTModuleConv1dMetatype,
PTModuleConv2dMetatype,
PTModuleConv3dMetatype,
PTDepthwiseConv1dSubtype,
PTDepthwiseConv2dSubtype,
PTDepthwiseConv3dSubtype,
PTModuleConvTranspose1dMetatype,
PTModuleConvTranspose2dMetatype,
PTModuleConvTranspose3dMetatype,
# Need to verify that Linear handles correctly
PTModuleLinearMetatype,
]

OPERATORS_FUSED_METATYPES = [
Expand Down

0 comments on commit bf6b06e

Please sign in to comment.