Skip to content

Commit

Permalink
[ONNX] Add support of correct quantization of MatMul (#1917)
Browse files Browse the repository at this point in the history
### Changes

Add correct handling of MatMul ops during quantization for every case: activations-only and activation-with-weight.

- Introduce ONNXLayerAttributes, which are assigned to every NNCFNode.
- Split weight_port_ids into constant_port_ids and possible_weight_ports.
possible_weight_ports are used to determine weights dynamically.
- Add logic to determine whether a node has a weight
(_get_weight_edge_name)
- Add transpose attribute for GEMM node


### Reason for changes

To get the most optimized performance after quantization.

### Related tickets

112530
95156

### Tests

Synthetic models were added in graph tests
gpt-2 and bertsquad were added in graph tests
Add test on scales after quantization
Add test on transpose_axis func
  • Loading branch information
kshpv authored Jul 18, 2023
1 parent 2a8604b commit 044bfa7
Show file tree
Hide file tree
Showing 56 changed files with 16,774 additions and 6,038 deletions.
119 changes: 91 additions & 28 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Type

import onnx
Expand All @@ -34,7 +33,7 @@ def get_subtypes(cls) -> List[Type[OperatorMetatype]]:
return cls.subtypes

@classmethod
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Optional[bool]:
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return node.op_type in cls.op_names

@classmethod
Expand All @@ -50,33 +49,30 @@ def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Opti
return matches[0]


@dataclass
class OpWeightDef:
class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
"""
Contains the information about the weight and bias of the operation.
Metatype which could have weights.
:param weight_channel_axis: Axis for weight per-channel quantization, meaning the number of output filters.
:param weight_port_id: Input port of the node's weight.
:param weight_port_ids: Input ports of the node's weight.
If the value is None the weight_port_id should be determined dynamically.
:param bias_port_id: Input port of the node's bias.
If the value is None it means that the Metatype does not have bias.
"""

weight_channel_axis: int
weight_port_id: Optional[int] = None
weight_port_ids: Optional[List[int]] = None
bias_port_id: Optional[int] = None


class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
weight_definitions = None # type: OpWeightDef


@ONNX_OPERATION_METATYPES.register()
class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "DepthwiseConvOp"
op_names = ["Conv"]
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
weight_definitions = OpWeightDef(weight_channel_axis=0, weight_port_id=1, bias_port_id=2)
weight_channel_axis = 0
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1

@classmethod
Expand All @@ -89,7 +85,9 @@ class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "ConvOp"
op_names = ["Conv"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
weight_definitions = OpWeightDef(weight_channel_axis=0, weight_port_id=1, bias_port_id=2)
weight_channel_axis = 0
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]

Expand All @@ -99,17 +97,33 @@ class ONNXConvolutionTransposeMetatype(ONNXOpWithWeightsMetatype):
name = "ConvTransposeOp"
op_names = ["ConvTranspose"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
weight_definitions = OpWeightDef(weight_channel_axis=1, weight_port_id=1, bias_port_id=2)
weight_channel_axis = 1
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1


@ONNX_OPERATION_METATYPES.register()
class ONNXLinearMetatype(ONNXOpWithWeightsMetatype):
name = "LinearOp"
class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
name = "GemmOp"
op_names = ["Gemm"]
hw_config_names = [HWConfigOpName.MATMUL]
# TODO(kshpv): ticket:95156
weight_definitions = OpWeightDef(weight_channel_axis=0, weight_port_id=1, bias_port_id=2)
weight_channel_axis = -1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1


@ONNX_OPERATION_METATYPES.register()
class ONNXMatMulMetatype(ONNXOpMetatype):
    # MatMul has no statically fixed weight input: either input (port 0 or 1)
    # may turn out to be a weight, so the actual weight port must be determined
    # dynamically from the graph (hence weight_port_ids is None and both ports
    # are listed in possible_weight_ports).
    name = "MatMulOp"
    op_names = ["MatMul"]
    hw_config_names = [HWConfigOpName.MATMUL]
    weight_channel_axis = -1  # last axis of the weight tensor
    weight_port_ids = None  # None: weight port cannot be fixed statically
    bias_port_id = 2  # NOTE(review): ONNX MatMul itself takes only 2 inputs — presumably refers to a fused bias; confirm
    possible_weight_ports = [0, 1]
    output_channel_axis = -1


Expand Down Expand Up @@ -413,13 +427,6 @@ class ONNXRoiAlignMetatype(ONNXOpMetatype):
op_names = ["RoiAlign"]


@ONNX_OPERATION_METATYPES.register()
class ONNXMatMulMetatype(ONNXOpMetatype):
name = "MatMulOp"
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherMetatype(ONNXOpMetatype):
name = "GatherOp"
Expand Down Expand Up @@ -589,13 +596,16 @@ class ONNXDeformableConvolutionMetatype(ONNXOpMetatype):
op_names = ["DeformConv"]


WEIGHT_LAYER_METATYPES = [
CONSTANT_WEIGHT_LAYER_METATYPES = [
ONNXConvolutionMetatype,
ONNXDepthwiseConvolutionMetatype,
ONNXConvolutionTransposeMetatype,
ONNXLinearMetatype,
]

MATMUL_METATYPES = [ONNXGemmMetatype, ONNXMatMulMetatype]

GENERAL_WEIGHT_LAYER_METATYPES = CONSTANT_WEIGHT_LAYER_METATYPES + MATMUL_METATYPES

# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS_METATYPES = [
ONNXConvolutionMetatype,
Expand All @@ -607,11 +617,64 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
:return: List of operator metatypes .
:return: List of operator metatypes.
"""
return list(ONNX_OPERATION_METATYPES.registry_dict.values())


def get_metatype(model: onnx.ModelProto, node: onnx.NodeProto) -> ONNXOpMetatype:
    """
    Resolves the ONNXOpMetatype that matches the given ONNX node.

    :param model: ONNX model containing the node.
    :param node: Node from the ONNX model.
    :return: Matched metatype, refined to a subtype when one applies.
    """
    base_metatype = ONNX_OPERATION_METATYPES.get_operator_metatype_by_op_name(node.op_type)
    if not base_metatype.get_subtypes():
        return base_metatype
    refined = base_metatype.determine_subtype(model, node)
    return base_metatype if refined is None else refined


def get_constant_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
    """
    Returns the input port ids on which the metatype must have a weight,
    based on the operation definition.

    :param metatype: Metatype to look up.
    :return: Weight port ids; empty list when the operation has no statically defined weight ports.
    """
    has_constant_weight = metatype in CONSTANT_WEIGHT_LAYER_METATYPES
    return metatype.weight_port_ids if has_constant_weight else []


def get_possible_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
    """
    Returns the input port ids on which the metatype could have a weight.
    Example: ONNXMatMulMetatype could have activations or weights on input port ids: 0, 1

    :param metatype: Metatype to look up.
    :return: Candidate weight port ids; empty list when none apply.
    """
    if metatype not in MATMUL_METATYPES:
        return []
    return metatype.possible_weight_ports


def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int]:
    """
    Returns the input port id on which a bias tensor is expected for the given metatype.

    :param metatype: Metatype to look up.
    :return: Input port id of the bias tensor, or None if the metatype can not have a bias.
    """
    # Fixed the docstring: it previously documented a nonexistent `node` parameter
    # and described the bias as something the port "should output".
    if metatype in OPERATIONS_WITH_BIAS_METATYPES:
        return metatype.bias_port_id
    return None


def _is_depthwise_conv(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
"""
Returns True if the convolution is depthwise, False - otherwise.
Expand Down
Loading

0 comments on commit 044bfa7

Please sign in to comment.