From 529523a58c17cccfa6033d2e80e30757e33dced3 Mon Sep 17 00:00:00 2001 From: Lyalyushkin Nikolay Date: Tue, 25 Jul 2023 12:39:06 +0200 Subject: [PATCH] Shift+Scale and Input+Shift+Scale pattern for PT (#1989) ### Changes Introduced a fused Shift+Scale pattern and an Input+Shift+Scale pattern to insert FakeQuantize operations optimally for CPU. ### Reason for changes A customer has a model that is not quantized optimally: ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/59a228fd-1336-4e91-80c4-b67f76febcb8) The FakeQuantize between subtraction and division is redundant and introduces additional runtime cost. The FakeQuantize between the input and pre-processing is not needed in the case of a single edge from the input, because pre-processing can be fused into the FQ that follows it. ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/2037597e-cf0d-45f3-b36b-83c1fa7f0de2) When there are multiple edges from the input and one edge goes to pre-processing, it is optimal to have a common FakeQuantize shared by all edges. ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/a150a4b2-1e34-461a-9683-46955abb6ffc) If pre-processing is represented via the normalize op from torchvision (e.g. like here https://github.com/PeterL1n/RobustVideoMatting/blob/master/model/mobilenetv3.py#L37), NNCF doesn't insert an FQ between subtraction and division, nor between the input and pre-processing. 
It happens because pre-processing is implemented via in-place operations, and since FQ is not in-place it can't be applied (see more details here: https://github.com/openvinotoolkit/nncf/pull/1565) ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/17ecd154-a7d8-468c-95d0-31d99ca3185f) ### Related tickets 112934 ### Tests synthetic tests for pre-processing --- nncf/common/graph/patterns/patterns.py | 1 + nncf/onnx/hardware/fused_patterns.py | 57 +++++++++++-------- nncf/openvino/hardware/fused_patterns.py | 50 +++++++++------- nncf/torch/hardware/fused_patterns.py | 19 +++++++ .../ShiftScale__multi_input_branch.dot | 23 ++++++++ ...ftScale__normalize__multi_input_branch.dot | 27 +++++++++ ...tScale__normalize__single_input_branch.dot | 19 +++++++ .../ShiftScale__single_input_branch.dot | 15 +++++ tests/torch/test_compressed_graph.py | 18 ++++++ tests/torch/test_models/synthetic.py | 31 ++++++++++ tests/torch/test_pattern_manager.py | 1 - 11 files changed, 215 insertions(+), 46 deletions(-) create mode 100644 tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__multi_input_branch.dot create mode 100644 tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot create mode 100644 tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot create mode 100644 tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__single_input_branch.dot diff --git a/nncf/common/graph/patterns/patterns.py b/nncf/common/graph/patterns/patterns.py index 32c3d02ac53..021addf0799 100644 --- a/nncf/common/graph/patterns/patterns.py +++ b/nncf/common/graph/patterns/patterns.py @@ -296,6 +296,7 @@ class HWFusedPatternNames(Enum): MVN_SCALE_SHIFT = PatternDesc("mvn_scale_shift") NORMALIZE_L2_MULTIPLY = PatternDesc("normalize_l2_multiply") SCALE_SHIFT = PatternDesc("scale_shift") + SHIFT_SCALE = PatternDesc("shift_scale") SE_BLOCK = 
PatternDesc("se_block") SOFTMAX_DIV = PatternDesc("softmax_div") diff --git a/nncf/onnx/hardware/fused_patterns.py b/nncf/onnx/hardware/fused_patterns.py index 2b5c3580a89..9572fd42970 100644 --- a/nncf/onnx/hardware/fused_patterns.py +++ b/nncf/onnx/hardware/fused_patterns.py @@ -44,6 +44,26 @@ def create_scale_shift() -> GraphPattern: return pattern +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE) +def create_shift_scale() -> GraphPattern: + pattern = GraphPattern() + add_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "ADD, SUBTRACT", + GraphPattern.METATYPE_ATTR: [om.ONNXAddLayerMetatype, om.ONNXSubMetatype], + } + ) + mul_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "MULTIPLY, DIV", + GraphPattern.METATYPE_ATTR: [om.ONNXMulLayerMetatype, om.ONNXDivLayerMetatype], + } + ) + + pattern.add_edge(add_node, mul_node) + return pattern + + @ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SWISH_WITH_SIGMOID) def create_swish_with_sigmoid() -> GraphPattern: pattern = GraphPattern() @@ -113,24 +133,23 @@ def create_hswish() -> GraphPattern: # INPUT PROCESSING +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SCALE_SHIFT) +def create_input_scale_shift() -> GraphPattern: + pattern = GraphPattern() + pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype}) + scale_shift = create_scale_shift() + + pattern.join_patterns(scale_shift) + return pattern + + @ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE) def create_input_shift_scale() -> GraphPattern: pattern = GraphPattern() - input_node = pattern.add_node( - **{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype} - ) - add_node = pattern.add_node( - **{ - GraphPattern.LABEL_ATTR: "ADD, SUBTRACT", - GraphPattern.METATYPE_ATTR: [om.ONNXAddLayerMetatype, om.ONNXSubMetatype], - } - ) - multiply_node = pattern.add_node( - **{GraphPattern.LABEL_ATTR: "MULTIPLY", 
GraphPattern.METATYPE_ATTR: om.ONNXMulLayerMetatype} - ) + pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype}) + shift_scale = create_shift_scale() - pattern.add_edge(input_node, add_node) - pattern.add_edge(add_node, multiply_node) + pattern.join_patterns(shift_scale) return pattern @@ -151,16 +170,6 @@ def create_input_add() -> GraphPattern: return pattern -@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SCALE_SHIFT) -def create_input_scale_shift() -> GraphPattern: - pattern = GraphPattern() - pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype}) - scale_shift = create_scale_shift() - - pattern.join_patterns(scale_shift) - return pattern - - # COMBINATIONS diff --git a/nncf/openvino/hardware/fused_patterns.py b/nncf/openvino/hardware/fused_patterns.py index b89e548fade..7a1724e26e8 100644 --- a/nncf/openvino/hardware/fused_patterns.py +++ b/nncf/openvino/hardware/fused_patterns.py @@ -128,6 +128,25 @@ def create_scale_shift() -> GraphPattern: return pattern +@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE) +def create_shift_scale() -> GraphPattern: + pattern = GraphPattern() + add_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "ADD, SUBTRACT", + GraphPattern.METATYPE_ATTR: [om.OVAddMetatype, om.OVSubtractMetatype], + } + ) + mul_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "MULTIPLY, DIV", + GraphPattern.METATYPE_ATTR: [om.OVMultiplyMetatype, om.OVDivideMetatype], + } + ) + pattern.add_edge(add_node, mul_node) + return pattern + + @OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK) def create_se_block() -> GraphPattern: pattern = GraphPattern() @@ -308,27 +327,6 @@ def create_softmax() -> GraphPattern: # INPUT PROCESSING -@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE) -def create_input_shift_scale() -> GraphPattern: - pattern = GraphPattern() - model_input 
= pattern.add_node( - **{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: om.OVParameterMetatype} - ) - add_node = pattern.add_node( - **{ - GraphPattern.LABEL_ATTR: "ADD, SUBTRACT", - GraphPattern.METATYPE_ATTR: [om.OVAddMetatype, om.OVSubtractMetatype], - } - ) - multiply_node = pattern.add_node( - **{GraphPattern.LABEL_ATTR: "MULTIPLY", GraphPattern.METATYPE_ATTR: om.OVMultiplyMetatype} - ) - - pattern.add_edge(model_input, add_node) - pattern.add_edge(add_node, multiply_node) - return pattern - - @OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_CONVERT_TRANSPOSE_PROCESSING) def create_input_convert_transpose_processing() -> GraphPattern: input_convert_transpose = create_input_convert_transpose() @@ -461,6 +459,16 @@ def create_input_scale_shift() -> GraphPattern: return pattern +@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE) +def create_input_shift_scale() -> GraphPattern: + pattern = GraphPattern() + pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: om.OVParameterMetatype}) + shift_scale = create_shift_scale() + + pattern.join_patterns(shift_scale) + return pattern + + @OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_TRANSPOSE_PROCESSING) def create_input_transpose_processing() -> GraphPattern: pattern = GraphPattern() diff --git a/nncf/torch/hardware/fused_patterns.py b/nncf/torch/hardware/fused_patterns.py index 5ca28821586..c30e5b0d4c0 100644 --- a/nncf/torch/hardware/fused_patterns.py +++ b/nncf/torch/hardware/fused_patterns.py @@ -12,6 +12,7 @@ from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import HWFusedPatternNames from nncf.common.utils.registry import Registry +from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype from nncf.torch.graph.pattern_operations import ARITHMETIC_OPERATIONS from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS from 
nncf.torch.graph.pattern_operations import BATCH_NORMALIZATION_OPERATIONS @@ -49,6 +50,24 @@ def create_l2_norm_operations() -> GraphPattern: # COMBINATIONS +@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE) +def create_shift_scale() -> GraphPattern: + pattern = GraphPattern() + add_node = pattern.add_node(label="ADD, SUB", type=["__add__", "__sub__"]) + truediv_node = pattern.add_node(label="MUL, DIV", type=["__mul__", "__truediv__"]) + pattern.add_edge(add_node, truediv_node) + return pattern + + +@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE) +def create_input_shift_scale() -> GraphPattern: + pattern = GraphPattern() + pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: PTInputNoopMetatype}) + shift_scale = create_shift_scale() + pattern.join_patterns(shift_scale) + return pattern + + @PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ARITHMETIC) def create_linear_arithmetic_operations() -> GraphPattern: linear = linear_operations() diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__multi_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__multi_input_branch.dot new file mode 100644 index 00000000000..a1587e51232 --- /dev/null +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__multi_input_branch.dot @@ -0,0 +1,23 @@ +strict digraph { +"0 /nncf_model_input_0" [id=0, type=nncf_model_input]; +"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize]; +"2 ShiftScaleParametrized/__sub___0" [id=2, type=__sub__]; +"3 ShiftScaleParametrized/__truediv___0" [id=3, type=__truediv__]; +"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" [id=4, type=symmetric_quantize]; +"5 
ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=6, type=conv2d]; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=7, type=symmetric_quantize]; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" [id=8, type=conv2d]; +"9 /nncf_model_output_0" [id=9, type=nncf_model_output]; +"10 /nncf_model_output_1" [id=10, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0"; +"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ShiftScaleParametrized/__sub___0"; +"1 SymmetricQuantizer/symmetric_quantize_0" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"2 ShiftScaleParametrized/__sub___0" -> "3 ShiftScaleParametrized/__truediv___0"; +"3 ShiftScaleParametrized/__truediv___0" -> "4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0"; +"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "9 /nncf_model_output_0"; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" -> "10 /nncf_model_output_1"; +} diff --git 
a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot new file mode 100644 index 00000000000..079b8f9afe8 --- /dev/null +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot @@ -0,0 +1,27 @@ +strict digraph { +"0 /nncf_model_input_0" [id=0, type=nncf_model_input]; +"1 ShiftScaleParametrized/is_floating_point_0" [id=1, type=is_floating_point]; +"2 ShiftScaleParametrized/clone_0" [id=2, type=clone]; +"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_]; +"4 ShiftScaleParametrized/div__0" [id=4, type=div_]; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize]; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d]; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=8, type=symmetric_quantize]; +"9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" [id=9, type=symmetric_quantize]; +"10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" [id=10, type=conv2d]; +"11 /nncf_model_output_0" [id=11, type=nncf_model_output]; +"12 /nncf_model_output_1" [id=12, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/is_floating_point_0"; +"0 /nncf_model_input_0" -> "2 ShiftScaleParametrized/clone_0"; +"0 /nncf_model_input_0" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0"; +"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0"; +"3 ShiftScaleParametrized/sub__0" -> 
"4 ShiftScaleParametrized/div__0"; +"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "11 /nncf_model_output_0"; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" -> "10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" -> "12 /nncf_model_output_1"; +} diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot new file mode 100644 index 00000000000..4d067597486 --- /dev/null +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot @@ -0,0 +1,19 @@ +strict digraph { +"0 /nncf_model_input_0" [id=0, type=nncf_model_input]; +"1 ShiftScaleParametrized/is_floating_point_0" [id=1, type=is_floating_point]; +"2 ShiftScaleParametrized/clone_0" [id=2, type=clone]; +"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_]; +"4 ShiftScaleParametrized/div__0" [id=4, type=div_]; +"5 
ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize]; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d]; +"8 /nncf_model_output_0" [id=8, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/is_floating_point_0"; +"0 /nncf_model_input_0" -> "2 ShiftScaleParametrized/clone_0"; +"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0"; +"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0"; +"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "8 /nncf_model_output_0"; +} diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__single_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__single_input_branch.dot new file mode 100644 index 00000000000..0a2f73028fb --- /dev/null +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__single_input_branch.dot @@ -0,0 +1,15 @@ +strict digraph { +"0 /nncf_model_input_0" [id=0, type=nncf_model_input]; +"1 ShiftScaleParametrized/__sub___0" [id=1, type=__sub__]; +"2 
ShiftScaleParametrized/__truediv___0" [id=2, type=__truediv__]; +"3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" [id=3, type=symmetric_quantize]; +"4 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=4, type=symmetric_quantize]; +"5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=5, type=conv2d]; +"6 /nncf_model_output_0" [id=6, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/__sub___0"; +"1 ShiftScaleParametrized/__sub___0" -> "2 ShiftScaleParametrized/__truediv___0"; +"2 ShiftScaleParametrized/__truediv___0" -> "3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0"; +"3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" -> "5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"4 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "6 /nncf_model_output_0"; +} diff --git a/tests/torch/test_compressed_graph.py b/tests/torch/test_compressed_graph.py index e55152c73fd..dcb3e966a80 100644 --- a/tests/torch/test_compressed_graph.py +++ b/tests/torch/test_compressed_graph.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools import os from abc import ABC from abc import abstractmethod @@ -68,6 +69,7 @@ from tests.torch.test_models.synthetic import MultiOutputSameTensorModel from tests.torch.test_models.synthetic import PoolUnPool from tests.torch.test_models.synthetic import ReshapeModel +from tests.torch.test_models.synthetic import ShiftScaleParametrized from tests.torch.test_models.synthetic import TransposeModel @@ -575,6 +577,21 @@ def forward(self, x): return TestModel(self.tensor_method, **self.model_kwargs) +shift_scale_models = [] +params_combinations = list(itertools.product([True, False], repeat=2)) + + +for pair in params_combinations: + names = ["is_single_input", "use_normalize"] + kwargs = dict(zip(names, pair)) + desc = GeneralModelDesc( + model_name=ShiftScaleParametrized.get_name(**kwargs), + model_builder=partial(ShiftScaleParametrized, **kwargs), + input_sample_sizes=(ShiftScaleParametrized.INPUT_SIZES), + ) + shift_scale_models.append(desc) + + TWO_INT_INPUTS_INFO = [{"sample_size": [1], "type": "long"}, {"sample_size": [1], "type": "long"}] SYNTHETIC_MODEL_DESC_LIST = [ SingleLayerModelDesc(layer=nn.Conv1d(1, 1, 1), input_sample_sizes=[1, 1, 1]), @@ -732,6 +749,7 @@ def forward(self, x): wrap_inputs_fn=partial(n_inputs_fn, nargs=3), ), GeneralModelDesc(model_builder=MHA_single_input, input_sample_sizes=(MHA_single_input.INPUT_SIZES,)), + *shift_scale_models, ] diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py index 0c78b74e6c2..0ca77979915 100644 --- a/tests/torch/test_models/synthetic.py +++ b/tests/torch/test_models/synthetic.py @@ -16,6 +16,7 @@ from torch.nn import BatchNorm2d from torch.nn import Dropout from torch.nn import Parameter +from torchvision.transforms.functional import normalize from nncf.torch import register_module from tests.torch.helpers import create_conv @@ -332,3 +333,33 @@ def __init__(self): def forward(self, x): return self.mha(x, x, x) + + +class 
ShiftScaleParametrized(torch.nn.Module): + NUM_CHANNELS = 3 + INPUT_SIZES = [1, NUM_CHANNELS, 2, 2] + + def __init__(self, is_single_input: bool, use_normalize: bool): + super().__init__() + self.conv = create_conv(self.NUM_CHANNELS, 1, 1) + self.is_single_input = is_single_input + self.use_normalize = use_normalize + + @classmethod + def get_name(cls, is_single_input: bool, use_normalize: bool): + suffix_1 = "single" if is_single_input else "multi" + suffix_2 = "__normalize" if use_normalize else "" + return f"ShiftScale{suffix_2}__{suffix_1}_input_branch" + + def forward(self, x): + values = [1] * self.NUM_CHANNELS + if self.use_normalize: + pre_proc = normalize(x, values, values, inplace=False) + else: + vector = torch.Tensor(values).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3) + pre_proc = (x - vector) / vector + + output = self.conv(pre_proc) + if self.is_single_input: + return output + return output, self.conv(x) diff --git a/tests/torch/test_pattern_manager.py b/tests/torch/test_pattern_manager.py index 3a0b6e8309f..e7879f5988c 100644 --- a/tests/torch/test_pattern_manager.py +++ b/tests/torch/test_pattern_manager.py @@ -39,7 +39,6 @@ HWFusedPatternNames.INPUT_REVERSE_ADD: "Not relevant for Torch.", HWFusedPatternNames.INPUT_REVERSE_SCALE_SHIFT: "Not relevant for Torch.", HWFusedPatternNames.INPUT_SCALE_SHIFT: "Not relevant for Torch.", - HWFusedPatternNames.INPUT_SHIFT_SCALE: "Not relevant for Torch.", HWFusedPatternNames.INPUT_TRANSPOSE_PROCESSING: "Not relevant for Torch.", HWFusedPatternNames.INPUT_TRANSPOSE_REVERSE_ADD: "Not relevant for Torch.", HWFusedPatternNames.INPUT_TRANSPOSE_SCALE_SHIFT: "Not relevant for Torch.",