Skip to content

Commit

Permalink
[PTQ] CPU_SPR device support (#1979)
Browse files Browse the repository at this point in the history
### Changes

- Updated PTQ part with the CPU_SPR device

### Reason for changes

- New device

### Related tickets

- 100682

### Tests

- Updated test_graph.py
  • Loading branch information
KodiaqQ authored Jul 25, 2023
1 parent 5475110 commit 459e724
Show file tree
Hide file tree
Showing 24 changed files with 4,980 additions and 2,404 deletions.
6 changes: 5 additions & 1 deletion nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,9 @@ class IgnoredPatternNames(Enum):
Describes the patterns, which nodes should be ignored during FakeQuantize placement.
"""

MULTIHEAD_ATTENTION_OUTPUT = PatternDesc("multihead_attention_output", model_types=[ModelType.TRANSFORMER])
MULTIHEAD_ATTENTION_OUTPUT = PatternDesc(
"multihead_attention_output",
model_types=[ModelType.TRANSFORMER],
devices=[TargetDevice.ANY, TargetDevice.CPU, TargetDevice.GPU, TargetDevice.VPU],
)
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
3 changes: 1 addition & 2 deletions nncf/common/hardware/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class HWConfigType(Enum):
"CPU": HWConfigType.CPU.value,
"VPU": HWConfigType.VPU.value,
"GPU": HWConfigType.GPU.value,
"CPU_SPR": HWConfigType.CPU.value,
}


Expand All @@ -55,8 +56,6 @@ def get_hw_config_type(target_device: str) -> Optional[HWConfigType]:
"""
if target_device == "TRIAL":
return None
if target_device == "CPU_SPR":
raise ValueError(f"{target_device} target device is not supported yet")
return HWConfigType(HW_CONFIG_TYPE_TARGET_DEVICE_MAP[target_device])


Expand Down
112 changes: 104 additions & 8 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, OrderedDict, Set, TypeVar

import numpy as np

from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
from nncf.common.factory import NNCFGraphFactory
Expand Down Expand Up @@ -344,7 +346,11 @@ def _get_ignored_scope(self, inference_nncf_graph: NNCFGraph, ignored_patterns:
return IgnoredScope(names=nncf_node_names)

def _get_quantizer_setup(
self, nncf_graph: NNCFGraph, hw_patterns: GraphPattern, ignored_patterns: GraphPattern
self,
nncf_graph: NNCFGraph,
inference_nncf_graph: NNCFGraph,
hw_patterns: GraphPattern,
ignored_patterns: GraphPattern,
) -> SingleConfigQuantizerSetup:
"""
Returns SingleConfigQuantizerSetup instance based on the input NNCFGraph.
Expand All @@ -358,9 +364,6 @@ def _get_quantizer_setup(
hw_config_path = self._backend_entity.hw_config.get_path_to_hw_config(hw_config_type)
hw_config = self._backend_entity.hw_config.from_json(hw_config_path)

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.read_variable_metatypes
)
ignored_names = self._get_ignored_names(nncf_graph, inference_nncf_graph, ignored_patterns)
weight_nodes = self._backend_entity.get_weight_nodes(nncf_graph)

Expand Down Expand Up @@ -478,8 +481,14 @@ def _get_quantization_target_points(
backend=backend, device=device, model_type=model_type
)
hw_patterns = PatternsManager.get_full_hw_pattern_graph(backend=backend, device=device, model_type=model_type)
quantizer_setup = self._get_quantizer_setup(nncf_graph, hw_patterns, ignored_patterns)

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.read_variable_metatypes
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
self._apply_model_type_pass(self._model_type, quantizer_setup, nncf_graph)
self._apply_device_pass(self._target_device, quantizer_setup, inference_nncf_graph)
self._unified_scale_groups = self._collect_unified_groups(quantizer_setup)
quantization_points = list(quantizer_setup.quantization_points.values())
quantization_points = self._topological_sort_quantization_points(quantization_points, nncf_graph)
Expand Down Expand Up @@ -583,7 +592,7 @@ def _get_quantization_points_overflow_fix(
)
if overflow_fix == OverflowFix.FIRST_LAYER:
weight_quantization_points = _filter_target_points_by_metatypes(
weight_quantization_points, self._backend_entity.conv_metatype, nncf_graph
weight_quantization_points, self._backend_entity.conv_metatypes, nncf_graph
)
for input_node in nncf_graph.get_input_nodes():
nodes = self._get_first_quantized_convolutions(weight_quantization_points, input_node, nncf_graph)
Expand Down Expand Up @@ -704,8 +713,7 @@ def _apply_model_type_pass(
if quantization_point.is_activation_quantization_point():
for node_name in quantization_point.directly_quantized_operator_node_names:
node = nncf_graph.get_node_by_name(node_name)
mat_mul_metatype = self._backend_entity.mat_mul_metatype
if node.metatype != mat_mul_metatype:
if node.metatype not in self._backend_entity.mat_mul_metatypes:
continue
if (
quantization_point.qconfig.mode != QuantizationMode.SYMMETRIC
Expand All @@ -716,3 +724,91 @@ def _apply_model_type_pass(
f"Update quantization mode for the node {node_name}"
f" to the symmetric due to ModelType parameter."
)

def _apply_device_pass(
self, target_device: TargetDevice, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph
) -> None:
"""
This method applies model post-processing device passes to SingleConfigQuantizerSetup in-place.
:param target_device: TargetDevice instance.
:param quantizer_setup: SingleConfigQuantizerSetup instance to update.
:param nncf_graph: NNCFGraph.
:return: None.
"""

passes_map = {TargetDevice.CPU_SPR: self._apply_spr_pass}

if target_device not in passes_map:
return

passes_map[target_device](quantizer_setup, nncf_graph)

def _apply_spr_pass(
self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph
) -> SingleConfigQuantizerSetup:
"""
Applies CPU_SPR-related pass.
The main action is to remove one of the quantizers before elementwise layer (e.g. Add).
This action allows to get performance boost on SPR devices.
:param quantizer_setup: SingleConfigQuantizerSetup instance to update.
:param nncf_graph: NNCFGraph instance to update.
:return: Modified SingleConfigQuantizerSetup.
"""

def _is_node_after_producers(node):
input_node = node
while True:
input_node = nncf_graph.get_previous_nodes(input_node)
if len(input_node) > 1:
return False
input_node = input_node[0]
if input_node.metatype in producer_metatypes:
return True

producer_metatypes = (
self._backend_entity.conv_metatypes
+ self._backend_entity.mat_mul_metatypes
+ self._backend_entity.group_conv_metatypes
)

quantizer_setup_map = {
p.insertion_point.target_node_name: q_key for q_key, p in quantizer_setup.quantization_points.items()
}

# Walking through all Add layers.
for add_node in nncf_graph.get_nodes_by_metatypes(self._backend_entity.add_metatypes):
add_inputs = nncf_graph.get_previous_nodes(add_node)

# Filtering Add based on it's input.
# Need to find Add layer only with two activations as input.
if len(add_inputs) == 2 and all(n.node_name in quantizer_setup_map for n in add_inputs):
# Sorting of the inputs based on length of input's consumer in descending order.
add_inputs.sort(key=lambda n: len(nncf_graph.get_next_nodes(n)), reverse=True)
fq_1_producer, fq_2_producer = add_inputs
fq_1_q_key = quantizer_setup_map[fq_1_producer.node_name]
fq_2_q_key = quantizer_setup_map[fq_2_producer.node_name]

# In the case of the two quantizers where one of them produces data into branching,
# it needs to remove the quantizer without branching after it.
if (
len(nncf_graph.get_next_nodes(fq_1_producer)) > 1
and len(nncf_graph.get_next_nodes(fq_2_producer)) == 1
):
quantizer_setup.discard(fq_2_q_key, True)
continue

# In the case of the two quantizers without the brancking after them,
# it needs to check that all quantizers follows after producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)
fq_2_prod_shape = np.prod(nncf_graph.get_output_edges(fq_2_producer)[0].tensor_shape)

# Then it needs to remove quantizer with the smallest shape.
if fq_1_prod_shape >= fq_2_prod_shape:
quantizer_setup.discard(fq_1_q_key, True)
else:
quantizer_setup.discard(fq_2_q_key, True)

return quantizer_setup
22 changes: 19 additions & 3 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
ALGO_BACKENDS = Registry("algo_backends")


# pylint:disable=too-many-public-methods
class MinMaxAlgoBackend(ABC):
@property
@abstractmethod
def mat_mul_metatype(self) -> OperatorMetatype:
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific MatMul metatype.
Property for the backend-specific MatMul metatypes.
"""

@property
Expand All @@ -58,7 +59,7 @@ def shapeof_metatypes(self) -> List[OperatorMetatype]:

@property
@abstractmethod
def conv_metatype(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific Convolution metatypes.
"""
Expand All @@ -78,6 +79,21 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
"""

@property
@abstractmethod
def add_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes that also can be interpreted as Add layer.
"""

@property
@abstractmethod
def group_conv_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific Grouped Convolution metatypes.
"""

@property
@abstractmethod
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
"""
Property for the backend-specific metatypes that produces quantizers that might be unified.
Expand Down
13 changes: 11 additions & 2 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@
from nncf.scopes import IgnoredScope


# pylint:disable=too-many-public-methods
@ALGO_BACKENDS.register(BackendType.ONNX)
class ONNXMinMaxAlgoBackend(MinMaxAlgoBackend):
@property
def mat_mul_metatype(self) -> OperatorMetatype:
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return om.MATMUL_METATYPES

@property
Expand All @@ -58,7 +59,7 @@ def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def conv_metatype(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype]

@property
Expand All @@ -69,6 +70,14 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXAddLayerMetatype]

@property
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return self.conv_metatypes

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
Expand Down
63 changes: 36 additions & 27 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes import openvino_metatypes as ov_metatypes
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.metatypes.openvino_metatypes import GENERAL_WEIGHT_LAYER_METATYPES
from nncf.openvino.graph.node_utils import get_weight_channel_axes
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
Expand All @@ -45,41 +45,50 @@
from nncf.scopes import IgnoredScope


# pylint:disable=too-many-public-methods
@ALGO_BACKENDS.register(BackendType.OPENVINO)
class OVMinMaxAlgoBackend(MinMaxAlgoBackend):
@property
def mat_mul_metatype(self) -> OperatorMetatype:
return ov_metatypes.OVMatMulMetatype
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return [om.OVMatMulMetatype]

@property
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [ov_metatypes.OVTopKMetatype, ov_metatypes.OVNonMaxSuppressionMetatype]
return [om.OVTopKMetatype, om.OVNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [ov_metatypes.OVShapeOfMetatype]
return [om.OVShapeOfMetatype]

@property
def conv_metatype(self) -> List[OperatorMetatype]:
return [ov_metatypes.OVConvolutionMetatype]
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype]

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
ov_metatypes.OVConvolutionMetatype,
ov_metatypes.OVGroupConvolutionMetatype,
ov_metatypes.OVConvolutionBackpropDataMetatype,
ov_metatypes.OVGroupConvolutionBackpropDataMetatype,
ov_metatypes.OVMatMulMetatype,
om.OVConvolutionMetatype,
om.OVGroupConvolutionMetatype,
om.OVConvolutionBackpropDataMetatype,
om.OVGroupConvolutionBackpropDataMetatype,
om.OVMatMulMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [ov_metatypes.OVReadValueMetatype]
return [om.OVReadValueMetatype]

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.OVAddMetatype]

@property
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVGroupConvolutionMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {ov_metatypes.OVConcatMetatype: self.overflow_fix_metatypes}
return {om.OVConcatMetatype: self.overflow_fix_metatypes}

@property
def hw_config(self) -> HWConfig:
Expand Down Expand Up @@ -217,21 +226,21 @@ def get_ignored_scope(model_type: ModelType, device: TargetDevice) -> IgnoredSco
if model_type == ModelType.TRANSFORMER:
types = []
metatypes_to_add = [
ov_metatypes.OVAddMetatype,
ov_metatypes.OVPowerMetatype,
ov_metatypes.OVSqueezeMetatype,
ov_metatypes.OVSubtractMetatype,
ov_metatypes.OVReduceMeanMetatype,
ov_metatypes.OVReduceL2Metatype,
ov_metatypes.OVSumMetatype,
ov_metatypes.OVSquaredDifferenceMetatype,
ov_metatypes.OVMVNMetatype,
ov_metatypes.OVDivideMetatype,
ov_metatypes.OVSqrtMetatype,
ov_metatypes.OVMaximumMetatype,
om.OVAddMetatype,
om.OVPowerMetatype,
om.OVSqueezeMetatype,
om.OVSubtractMetatype,
om.OVReduceMeanMetatype,
om.OVReduceL2Metatype,
om.OVSumMetatype,
om.OVSquaredDifferenceMetatype,
om.OVMVNMetatype,
om.OVDivideMetatype,
om.OVSqrtMetatype,
om.OVMaximumMetatype,
]
if device != TargetDevice.CPU_SPR:
metatypes_to_add.append(ov_metatypes.OVMultiplyMetatype)
metatypes_to_add.append(om.OVMultiplyMetatype)
for metatype in metatypes_to_add:
types.extend(metatype.get_all_aliases())
return IgnoredScope(types=types)
Expand Down
Loading

0 comments on commit 459e724

Please sign in to comment.