diff --git a/nncf/onnx/graph/metatypes/groups.py b/nncf/onnx/graph/metatypes/groups.py
index 56ea16842dc..6fbbd5db13e 100644
--- a/nncf/onnx/graph/metatypes/groups.py
+++ b/nncf/onnx/graph/metatypes/groups.py
@@ -10,6 +10,8 @@
 # limitations under the License.
 
 from nncf.onnx.graph.metatypes import onnx_metatypes
+from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
+from nncf.onnx.graph.metatypes.onnx_metatypes import get_operator_metatypes
 
 QUANTIZE_AGNOSTIC_OPERATIONS = [
     onnx_metatypes.ONNXGlobalMaxPoolMetatype,
@@ -67,14 +69,19 @@
     onnx_metatypes.ONNXMinimumMetatype,
 ]
 
-
 CONSTANT_WEIGHT_LAYER_METATYPES = [
-    onnx_metatypes.ONNXConvolutionMetatype,
-    onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
-    onnx_metatypes.ONNXConvolutionTransposeMetatype,
-    onnx_metatypes.ONNXEmbeddingMetatype,
+    metatype
+    for metatype in get_operator_metatypes()
+    if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.weight_port_ids
 ]
 
+POSSIBLE_WEIGHT_LAYER_METATYPES = [
+    metatype
+    for metatype in get_operator_metatypes()
+    if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.possible_weight_ports
+]
+
+OPERATIONS_WITH_WEIGHTS = list(set().union(CONSTANT_WEIGHT_LAYER_METATYPES, POSSIBLE_WEIGHT_LAYER_METATYPES))
 
 LINEAR_OPERATIONS = [
     onnx_metatypes.ONNXConvolutionMetatype,
@@ -124,11 +131,6 @@
     onnx_metatypes.ONNXMeanMetatype,
 ]
 
-OPERATIONS_WITH_WEIGHTS = [
-    *CONSTANT_WEIGHT_LAYER_METATYPES,
-    *MATMUL_METATYPES,
-]
-
 
 BATCH_NORMALIZATION_OPERATIONS = [
     onnx_metatypes.ONNXBatchNormMetatype,
diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py
index 075b2597643..213d348533a 100644
--- a/nncf/onnx/graph/metatypes/onnx_metatypes.py
+++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py
@@ -58,17 +58,17 @@ def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Opti
 class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
     """
     Metatype which could have weights.
-
-    :param weight_channel_axis: Axis for weight per-channel quantization, meaning the number of output filters.
-    :param weight_port_ids: Input ports of the node's weight.
-    If the value is None the weight_port_id should be determined dynamically.
-    :param bias_port_id: Input port of the node's bias.
-    If the value is None it means that the Metatype does not have bias.
+    :param weight_channel_axis: Axis for weight per-channel quantization.
+    :param weight_port_ids: Constant input ports of the node's weight. Defaults to an empty list.
+    :param bias_port_id: Input port of the node's bias. If the value is None,
+    it means that the Metatype does not have bias. Defaults to None.
+    :param possible_weight_ports: Input ports on which weight could be laid. Defaults to an empty list.
     """
 
     weight_channel_axis: int
-    weight_port_ids: Optional[List[int]] = None
+    weight_port_ids: List[int] = []
     bias_port_id: Optional[int] = None
+    possible_weight_ports: List[int] = []
 
 
 @ONNX_OPERATION_METATYPES.register(is_subtype=True)
@@ -131,19 +131,17 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
     op_names = ["Gemm"]
     hw_config_names = [HWConfigOpName.MATMUL]
     weight_channel_axis = -1  # For port_id=1
-    weight_port_ids = None
     bias_port_id = 2
     possible_weight_ports = [0, 1]
     output_channel_axis = -1
 
 
 @ONNX_OPERATION_METATYPES.register()
-class ONNXMatMulMetatype(ONNXOpMetatype):
+class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
     name = "MatMulOp"
     op_names = ["MatMul"]
     hw_config_names = [HWConfigOpName.MATMUL]
     weight_channel_axis = -1  # For port_id=1
-    weight_port_ids = None
     bias_port_id = 2
     possible_weight_ports = [0, 1]
     output_channel_axis = -1
@@ -454,7 +452,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):
 
 
 @ONNX_OPERATION_METATYPES.register(is_subtype=True)
-class ONNXEmbeddingMetatype(ONNXOpMetatype):
+class ONNXEmbeddingMetatype(ONNXOpWithWeightsMetatype):
     name = "EmbeddingOp"
     hw_config_names = [HWConfigOpName.EMBEDDING]
     weight_port_ids = [0]
diff --git a/nncf/onnx/graph/nncf_graph_builder.py b/nncf/onnx/graph/nncf_graph_builder.py
index 834fcbccced..ccdc870d17b 100644
--- a/nncf/onnx/graph/nncf_graph_builder.py
+++ b/nncf/onnx/graph/nncf_graph_builder.py
@@ -23,8 +23,8 @@
 from nncf.common.graph.operator_metatypes import InputNoopMetatype
 from nncf.common.graph.operator_metatypes import OutputNoopMetatype
 from nncf.onnx.graph.metatypes.groups import CONSTANT_WEIGHT_LAYER_METATYPES
-from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
 from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
+from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
@@ -95,7 +95,7 @@ def get_possible_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
     :param metatype: Metatype.
     :return: Port ids.
     """
-    if metatype in MATMUL_METATYPES:
+    if metatype in POSSIBLE_WEIGHT_LAYER_METATYPES:
         return metatype.possible_weight_ports
     return []
 
diff --git a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py
index 6a8d59ce312..cd77cc06678 100644
--- a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py
+++ b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py
@@ -81,8 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto
     @staticmethod
     def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
         activation_port = 0
-
-        if hasattr(node.metatype, "possible_weight_ports"):
+        if node.metatype.possible_weight_ports:
             activation_ports = deepcopy(node.metatype.possible_weight_ports)
             for weight_port in node.layer_attributes.weight_attrs:
                 activation_ports.remove(weight_port)
diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml
index 94f70f3a931..5953c1789d9 100644
--- a/tests/post_training/data/ptq_reference_data.yaml
+++ b/tests/post_training/data/ptq_reference_data.yaml
@@ -41,7 +41,7 @@ torchvision/resnet18_backend_FX_TORCH:
 torchvision/mobilenet_v3_small_BC_backend_FP32:
   metric_value: 0.6766
 torchvision/mobilenet_v3_small_BC_backend_OV:
-  metric_value: 0.6669
+  metric_value: 0.6681
 torchvision/mobilenet_v3_small_BC_backend_ONNX:
   metric_value: 0.6679
 torchvision/mobilenet_v3_small_BC_backend_FX_TORCH:
@@ -103,7 +103,7 @@ timm/dpn68_backend_CUDA_TORCH:
 timm/dpn68_backend_FP32:
   metric_value: 0.76342
 timm/dpn68_backend_ONNX:
-  metric_value: 0.75906
+  metric_value: 0.7592
 timm/dpn68_backend_OV:
   metric_value: 0.75972
 timm/dpn68_backend_TORCH:
@@ -201,7 +201,7 @@ timm/regnetx_002_backend_CUDA_TORCH:
 timm/regnetx_002_backend_FP32:
   metric_value: 0.68756
 timm/regnetx_002_backend_ONNX:
-  metric_value: 0.6848
+  metric_value: 0.6854
 timm/regnetx_002_backend_OV:
   metric_value: 0.6852
 timm/regnetx_002_backend_TORCH:
@@ -211,7 +211,7 @@ timm/resnest14d_backend_CUDA_TORCH:
 timm/resnest14d_backend_FP32:
   metric_value: 0.75516
 timm/resnest14d_backend_ONNX:
-  metric_value: 0.75428
+  metric_value: 0.7538
 timm/resnest14d_backend_OV:
   metric_value: 0.75
 timm/resnest14d_backend_TORCH:
@@ -253,7 +253,7 @@ timm/visformer_small_backend_CUDA_TORCH:
 timm/visformer_small_backend_FP32:
   metric_value: 0.82098
 timm/visformer_small_backend_ONNX:
-  metric_value: 0.81562
+  metric_value: 0.8160
 timm/visformer_small_backend_OV:
   metric_value: 0.81674
 timm/visformer_small_backend_TORCH:
diff --git a/tests/post_training/data/ptq_reference_data_2024.5.yaml b/tests/post_training/data/ptq_reference_data_2024.5.yaml
deleted file mode 100644
index 998fc76f889..00000000000
--- a/tests/post_training/data/ptq_reference_data_2024.5.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-torchvision/mobilenet_v3_small_BC_backend_OV:
-  metric_value: 0.6681
diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml
index 8fbc14a4396..6c48904c91a 100644
--- a/tests/post_training/data/wc_reference_data.yaml
+++ b/tests/post_training/data/wc_reference_data.yaml
@@ -7,15 +7,15 @@ tinyllama_data_aware_backend_OV:
   num_int4: 94
   num_int8: 124
 tinyllama_data_aware_awq_stateful_backend_OV:
-  metric_value: 0.85571
+  metric_value: 0.85616
   num_int4: 94
   num_int8: 124
 tinyllama_data_aware_awq_scale_estimation_backend_OV:
-  metric_value: 0.86355
+  metric_value: 0.85502
   num_int4: 94
   num_int8: 124
 tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.86355
+  metric_value: 0.85502
   num_int4: 94
   num_int8: 124
 tinyllama_int8_data_free_backend_TORCH:
@@ -23,12 +23,12 @@ tinyllama_int8_data_free_backend_TORCH:
   num_int4: 0
   num_int8: 312
 tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.86697
+  metric_value: 0.86503
   num_int4: 94
   num_int8: 124
   metrics_xfail_reason: "Issue-148819"
 tinyllama_scale_estimation_per_channel_backend_OV:
-  metric_value: 0.80798
+  metric_value: 0.81389
   num_int4: 188
   num_int8: 124
 tinyllama_data_aware_lora_stateful_backend_OV:
@@ -36,11 +36,15 @@ tinyllama_data_aware_lora_stateful_backend_OV:
   num_int4: 94
   num_int8: 500
 tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
-  metric_value: 0.87132
+  metric_value: 0.88663
   num_int4: 11
   num_int8: 290
   metrics_xfail_reason: "Issue-148819"
 tinyllama_awq_backup_mode_none_backend_OV:
-  metric_value: 0.85679
+  metric_value: 0.84783
   num_int4: 208
   num_int8: 0
+tinyllama_int4_data_free_backend_TORCH:
+  metric_value: 0.73873
+  num_int4: 114
+  num_int8: 84
diff --git a/tests/post_training/data/wc_reference_data_2024.5.yaml b/tests/post_training/data/wc_reference_data_2024.5.yaml
deleted file mode 100644
index e55a9f03dcd..00000000000
--- a/tests/post_training/data/wc_reference_data_2024.5.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
-  metric_value: 0.88663
-  num_int4: 11
-  num_int8: 290
-  metrics_xfail_reason: "Issue-148819"
-tinyllama_int4_data_free_backend_TORCH:
-  metric_value: 0.73873
-  num_int4: 114
-  num_int8: 84
-tinyllama_awq_backup_mode_none_backend_OV:
-  metric_value: 0.84783
-  num_int4: 208
-  num_int8: 0
-tinyllama_data_aware_awq_scale_estimation_backend_OV:
-  metric_value: 0.85502
-  num_int4: 94
-  num_int8: 124
-tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.85502
-  num_int4: 94
-  num_int8: 124
-tinyllama_data_aware_awq_stateful_backend_OV:
-  metric_value: 0.85616
-  num_int4: 94
-  num_int8: 124
-tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.86503
-  num_int4: 94
-  num_int8: 124
-  metrics_xfail_reason: "Issue-148819"
-tinyllama_scale_estimation_per_channel_backend_OV:
-  metric_value: 0.81389
-  num_int4: 188
-  num_int8: 124
\ No newline at end of file