Skip to content

Commit

Permalink
Shift+Scale and Input+Shift+Scale pattern for PT (#1989)
Browse files Browse the repository at this point in the history
### Changes

Introduced Shift + Scale fused pattern and Input+Shift+Scale pattern to
insert Fake Quantize operations optimally for CPU.

### Reason for changes

A customer has a model that is not quantized optimally:

![image](https://github.com/openvinotoolkit/nncf/assets/4014476/59a228fd-1336-4e91-80c4-b67f76febcb8)

FakeQuantize between subtraction and division is redundant and
introduces additional cost in runtime.
FakeQuantize between the input and the pre-processing is not needed when there is a
single edge from the input, because the pre-processing can be fused into the FQ
that follows it.

![image](https://github.com/openvinotoolkit/nncf/assets/4014476/2037597e-cf0d-45f3-b36b-83c1fa7f0de2)
When there are multiple edges from input and one edge is going to
pre-processing, it's optimal to have a common fake quantize for all
edges.

![image](https://github.com/openvinotoolkit/nncf/assets/4014476/a150a4b2-1e34-461a-9683-46955abb6ffc)


If pre-processing represented via normalize op from torchvision (e.g.
like here
https://github.com/PeterL1n/RobustVideoMatting/blob/master/model/mobilenetv3.py#L37),
NNCF doesn't insert FQ between subtraction and division and between
input and pre-processing.
It happens because pre-processing is implemented via in-place
operations, and since FQ is not in-place it can't be applied (see more
details here: #1565)

![image](https://github.com/openvinotoolkit/nncf/assets/4014476/17ecd154-a7d8-468c-95d0-31d99ca3185f)


### Related tickets

112934

### Tests

synthetic tests for pre-processing
  • Loading branch information
ljaljushkin authored Jul 25, 2023
1 parent 778751d commit 529523a
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 46 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class HWFusedPatternNames(Enum):
MVN_SCALE_SHIFT = PatternDesc("mvn_scale_shift")
NORMALIZE_L2_MULTIPLY = PatternDesc("normalize_l2_multiply")
SCALE_SHIFT = PatternDesc("scale_shift")
SHIFT_SCALE = PatternDesc("shift_scale")
SE_BLOCK = PatternDesc("se_block")
SOFTMAX_DIV = PatternDesc("softmax_div")

Expand Down
57 changes: 33 additions & 24 deletions nncf/onnx/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ def create_scale_shift() -> GraphPattern:
return pattern


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE)
def create_shift_scale() -> GraphPattern:
    """Shift+scale fused pattern: an (Add | Sub) node feeding a (Mul | Div) node."""
    pattern = GraphPattern()
    shift_attrs = {
        GraphPattern.LABEL_ATTR: "ADD, SUBTRACT",
        GraphPattern.METATYPE_ATTR: [om.ONNXAddLayerMetatype, om.ONNXSubMetatype],
    }
    scale_attrs = {
        GraphPattern.LABEL_ATTR: "MULTIPLY, DIV",
        GraphPattern.METATYPE_ATTR: [om.ONNXMulLayerMetatype, om.ONNXDivLayerMetatype],
    }
    shift = pattern.add_node(**shift_attrs)
    scale = pattern.add_node(**scale_attrs)
    pattern.add_edge(shift, scale)
    return pattern


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SWISH_WITH_SIGMOID)
def create_swish_with_sigmoid() -> GraphPattern:
pattern = GraphPattern()
Expand Down Expand Up @@ -113,24 +133,23 @@ def create_hswish() -> GraphPattern:
# INPUT PROCESSING


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SCALE_SHIFT)
def create_input_scale_shift() -> GraphPattern:
    """Model input immediately followed by the scale+shift fused pattern."""
    pattern = GraphPattern()
    input_attrs = {GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype}
    pattern.add_node(**input_attrs)
    pattern.join_patterns(create_scale_shift())
    return pattern


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE)
def create_input_shift_scale() -> GraphPattern:
pattern = GraphPattern()
input_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype}
)
add_node = pattern.add_node(
**{
GraphPattern.LABEL_ATTR: "ADD, SUBTRACT",
GraphPattern.METATYPE_ATTR: [om.ONNXAddLayerMetatype, om.ONNXSubMetatype],
}
)
multiply_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MULTIPLY", GraphPattern.METATYPE_ATTR: om.ONNXMulLayerMetatype}
)
pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype})
shift_scale = create_shift_scale()

pattern.add_edge(input_node, add_node)
pattern.add_edge(add_node, multiply_node)
pattern.join_patterns(shift_scale)
return pattern


Expand All @@ -151,16 +170,6 @@ def create_input_add() -> GraphPattern:
return pattern


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SCALE_SHIFT)
def create_input_scale_shift() -> GraphPattern:
pattern = GraphPattern()
pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: InputNoopMetatype})
scale_shift = create_scale_shift()

pattern.join_patterns(scale_shift)
return pattern


# COMBINATIONS


Expand Down
50 changes: 29 additions & 21 deletions nncf/openvino/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,25 @@ def create_scale_shift() -> GraphPattern:
return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE)
def create_shift_scale() -> GraphPattern:
    """Shift+scale fused pattern: an (Add | Subtract) node feeding a (Multiply | Divide) node."""
    pattern = GraphPattern()
    shift_attrs = {
        GraphPattern.LABEL_ATTR: "ADD, SUBTRACT",
        GraphPattern.METATYPE_ATTR: [om.OVAddMetatype, om.OVSubtractMetatype],
    }
    scale_attrs = {
        GraphPattern.LABEL_ATTR: "MULTIPLY, DIV",
        GraphPattern.METATYPE_ATTR: [om.OVMultiplyMetatype, om.OVDivideMetatype],
    }
    shift = pattern.add_node(**shift_attrs)
    scale = pattern.add_node(**scale_attrs)
    pattern.add_edge(shift, scale)
    return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
pattern = GraphPattern()
Expand Down Expand Up @@ -308,27 +327,6 @@ def create_softmax() -> GraphPattern:
# INPUT PROCESSING


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE)
def create_input_shift_scale() -> GraphPattern:
pattern = GraphPattern()
model_input = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: om.OVParameterMetatype}
)
add_node = pattern.add_node(
**{
GraphPattern.LABEL_ATTR: "ADD, SUBTRACT",
GraphPattern.METATYPE_ATTR: [om.OVAddMetatype, om.OVSubtractMetatype],
}
)
multiply_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MULTIPLY", GraphPattern.METATYPE_ATTR: om.OVMultiplyMetatype}
)

pattern.add_edge(model_input, add_node)
pattern.add_edge(add_node, multiply_node)
return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_CONVERT_TRANSPOSE_PROCESSING)
def create_input_convert_transpose_processing() -> GraphPattern:
input_convert_transpose = create_input_convert_transpose()
Expand Down Expand Up @@ -461,6 +459,16 @@ def create_input_scale_shift() -> GraphPattern:
return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE)
def create_input_shift_scale() -> GraphPattern:
    """Model input immediately followed by the shift+scale fused pattern."""
    pattern = GraphPattern()
    input_attrs = {GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: om.OVParameterMetatype}
    pattern.add_node(**input_attrs)
    pattern.join_patterns(create_shift_scale())
    return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_TRANSPOSE_PROCESSING)
def create_input_transpose_processing() -> GraphPattern:
pattern = GraphPattern()
Expand Down
19 changes: 19 additions & 0 deletions nncf/torch/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns import HWFusedPatternNames
from nncf.common.utils.registry import Registry
from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
from nncf.torch.graph.pattern_operations import ARITHMETIC_OPERATIONS
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
from nncf.torch.graph.pattern_operations import BATCH_NORMALIZATION_OPERATIONS
Expand Down Expand Up @@ -49,6 +50,24 @@ def create_l2_norm_operations() -> GraphPattern:
# COMBINATIONS


@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE)
def create_shift_scale() -> GraphPattern:
    """Shift+scale fused pattern: (__add__ | __sub__) feeding (__mul__ | __truediv__)."""
    pattern = GraphPattern()
    # "scale" may be either a multiplication or a division op, hence the generic name.
    shift_node = pattern.add_node(label="ADD, SUB", type=["__add__", "__sub__"])
    scale_node = pattern.add_node(label="MUL, DIV", type=["__mul__", "__truediv__"])
    pattern.add_edge(shift_node, scale_node)
    return pattern


@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE)
def create_input_shift_scale() -> GraphPattern:
    """Model input immediately followed by the shift+scale fused pattern."""
    pattern = GraphPattern()
    input_attrs = {GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: PTInputNoopMetatype}
    pattern.add_node(**input_attrs)
    pattern.join_patterns(create_shift_scale())
    return pattern


@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ARITHMETIC)
def create_linear_arithmetic_operations() -> GraphPattern:
linear = linear_operations()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
"2 ShiftScaleParametrized/__sub___0" [id=2, type=__sub__];
"3 ShiftScaleParametrized/__truediv___0" [id=3, type=__truediv__];
"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" [id=4, type=symmetric_quantize];
"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=5, type=symmetric_quantize];
"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=6, type=conv2d];
"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=7, type=symmetric_quantize];
"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" [id=8, type=conv2d];
"9 /nncf_model_output_0" [id=9, type=nncf_model_output];
"10 /nncf_model_output_1" [id=10, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ShiftScaleParametrized/__sub___0";
"1 SymmetricQuantizer/symmetric_quantize_0" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1";
"2 ShiftScaleParametrized/__sub___0" -> "3 ShiftScaleParametrized/__truediv___0";
"3 ShiftScaleParametrized/__truediv___0" -> "4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0";
"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "9 /nncf_model_output_0";
"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1";
"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" -> "10 /nncf_model_output_1";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 ShiftScaleParametrized/is_floating_point_0" [id=1, type=is_floating_point];
"2 ShiftScaleParametrized/clone_0" [id=2, type=clone];
"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_];
"4 ShiftScaleParametrized/div__0" [id=4, type=div_];
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize];
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize];
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d];
"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=8, type=symmetric_quantize];
"9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" [id=9, type=symmetric_quantize];
"10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" [id=10, type=conv2d];
"11 /nncf_model_output_0" [id=11, type=nncf_model_output];
"12 /nncf_model_output_1" [id=12, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/is_floating_point_0";
"0 /nncf_model_input_0" -> "2 ShiftScaleParametrized/clone_0";
"0 /nncf_model_input_0" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0";
"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0";
"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0";
"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0";
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "11 /nncf_model_output_0";
"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1";
"9 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" -> "10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1";
"10 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" -> "12 /nncf_model_output_1";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 ShiftScaleParametrized/is_floating_point_0" [id=1, type=is_floating_point];
"2 ShiftScaleParametrized/clone_0" [id=2, type=clone];
"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_];
"4 ShiftScaleParametrized/div__0" [id=4, type=div_];
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize];
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize];
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d];
"8 /nncf_model_output_0" [id=8, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/is_floating_point_0";
"0 /nncf_model_input_0" -> "2 ShiftScaleParametrized/clone_0";
"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0";
"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0";
"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0";
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "8 /nncf_model_output_0";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 ShiftScaleParametrized/__sub___0" [id=1, type=__sub__];
"2 ShiftScaleParametrized/__truediv___0" [id=2, type=__truediv__];
"3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" [id=3, type=symmetric_quantize];
"4 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=4, type=symmetric_quantize];
"5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=5, type=conv2d];
"6 /nncf_model_output_0" [id=6, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/__sub___0";
"1 ShiftScaleParametrized/__sub___0" -> "2 ShiftScaleParametrized/__truediv___0";
"2 ShiftScaleParametrized/__truediv___0" -> "3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0";
"3 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/__truediv___0|OUTPUT]/symmetric_quantize_0" -> "5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"4 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"5 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "6 /nncf_model_output_0";
}
18 changes: 18 additions & 0 deletions tests/torch/test_compressed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
from abc import ABC
from abc import abstractmethod
Expand Down Expand Up @@ -68,6 +69,7 @@
from tests.torch.test_models.synthetic import MultiOutputSameTensorModel
from tests.torch.test_models.synthetic import PoolUnPool
from tests.torch.test_models.synthetic import ReshapeModel
from tests.torch.test_models.synthetic import ShiftScaleParametrized
from tests.torch.test_models.synthetic import TransposeModel


Expand Down Expand Up @@ -575,6 +577,21 @@ def forward(self, x):
return TestModel(self.tensor_method, **self.model_kwargs)


# Model descriptors covering every combination of (is_single_input, use_normalize)
# for the ShiftScaleParametrized synthetic model.
shift_scale_models = []
params_combinations = list(itertools.product([True, False], repeat=2))
# Parameter names are loop-invariant; defined once instead of per iteration.
_SHIFT_SCALE_PARAM_NAMES = ["is_single_input", "use_normalize"]

for pair in params_combinations:
    kwargs = dict(zip(_SHIFT_SCALE_PARAM_NAMES, pair))
    desc = GeneralModelDesc(
        model_name=ShiftScaleParametrized.get_name(**kwargs),
        model_builder=partial(ShiftScaleParametrized, **kwargs),
        input_sample_sizes=ShiftScaleParametrized.INPUT_SIZES,
    )
    shift_scale_models.append(desc)


TWO_INT_INPUTS_INFO = [{"sample_size": [1], "type": "long"}, {"sample_size": [1], "type": "long"}]
SYNTHETIC_MODEL_DESC_LIST = [
SingleLayerModelDesc(layer=nn.Conv1d(1, 1, 1), input_sample_sizes=[1, 1, 1]),
Expand Down Expand Up @@ -732,6 +749,7 @@ def forward(self, x):
wrap_inputs_fn=partial(n_inputs_fn, nargs=3),
),
GeneralModelDesc(model_builder=MHA_single_input, input_sample_sizes=(MHA_single_input.INPUT_SIZES,)),
*shift_scale_models,
]


Expand Down
31 changes: 31 additions & 0 deletions tests/torch/test_models/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.nn import BatchNorm2d
from torch.nn import Dropout
from torch.nn import Parameter
from torchvision.transforms.functional import normalize

from nncf.torch import register_module
from tests.torch.helpers import create_conv
Expand Down Expand Up @@ -332,3 +333,33 @@ def __init__(self):

def forward(self, x):
return self.mha(x, x, x)


class ShiftScaleParametrized(torch.nn.Module):
    """Synthetic model exercising FQ placement around shift+scale pre-processing.

    The pre-processing is either torchvision's ``normalize`` (which internally
    uses in-place tensor ops) or an explicit ``(x - v) / v`` expression built
    from the magic-method ops ``__sub__``/``__truediv__``.
    With ``is_single_input=False`` a second conv branch consumes the raw input,
    so the model input node has multiple outgoing edges.
    """

    # Channel count shared by the input tensor and the convolution.
    NUM_CHANNELS = 3
    INPUT_SIZES = [1, NUM_CHANNELS, 2, 2]

    def __init__(self, is_single_input: bool, use_normalize: bool):
        super().__init__()
        self.conv = create_conv(self.NUM_CHANNELS, 1, 1)
        self.is_single_input = is_single_input
        self.use_normalize = use_normalize

    @classmethod
    def get_name(cls, is_single_input: bool, use_normalize: bool) -> str:
        """Return a human-readable model name encoding the parameter combination."""
        suffix_1 = "single" if is_single_input else "multi"
        suffix_2 = "__normalize" if use_normalize else ""
        return f"ShiftScale{suffix_2}__{suffix_1}_input_branch"

    def forward(self, x):
        values = [1] * self.NUM_CHANNELS
        if self.use_normalize:
            pre_proc = normalize(x, values, values, inplace=False)
        else:
            # torch.tensor (with an explicit float dtype to match the legacy
            # torch.Tensor constructor) is the recommended modern constructor.
            # Unsqueeze to a broadcastable (1, C, 1, 1) shape.
            vector = torch.tensor(values, dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=2).unsqueeze(dim=3)
            pre_proc = (x - vector) / vector

        output = self.conv(pre_proc)
        if self.is_single_input:
            return output
        return output, self.conv(x)
1 change: 0 additions & 1 deletion tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
HWFusedPatternNames.INPUT_REVERSE_ADD: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_REVERSE_SCALE_SHIFT: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_SCALE_SHIFT: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_SHIFT_SCALE: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_TRANSPOSE_PROCESSING: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_TRANSPOSE_REVERSE_ADD: "Not relevant for Torch.",
HWFusedPatternNames.INPUT_TRANSPOSE_SCALE_SHIFT: "Not relevant for Torch.",
Expand Down

0 comments on commit 529523a

Please sign in to comment.