[PTQ][Torch][KQV self attention] Align FQ placement between OV and Torch backend (#2166)

### Changes

* "unbing" and "__matmul__" ops are added to torch patterns
* Dropout removing pass is added to function
`transform_to_inference_graph`
* LayerNorm and GroupNorm metatypes are added to ignored metatypes in
MinMax algorithm
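
Below is a minimal sketch (not part of the diff) of how the MinMax algorithm now invokes the pass pipeline for the Torch backend; `nncf_graph` is assumed to be an `NNCFGraph` that was already built for the model:

```python
from copy import deepcopy

import nncf.torch.graph.operator_metatypes as om
from nncf.quantization.passes import transform_to_inference_graph

# Dropout metatypes are now passed so that Dropout nodes are removed from the
# inference graph before quantizer placement (see passes.py below).
inference_nncf_graph = transform_to_inference_graph(
    deepcopy(nncf_graph),          # `nncf_graph` is assumed to exist already
    shapeof_metatypes=[],          # the Torch backend reports no ShapeOf metatypes
    dropout_metatypes=[om.PTDropoutMetatype],
    read_variable_metatypes=[],
)
```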

### Reason for changes

To align fake-quantize (FQ) placement between the OpenVINO (OV) and Torch
backends for the following models:
* timm/crossvit_9_240
* timm/deit3_small_patch16_224
* timm/swin_base_patch4_window7_224


### Related tickets

121647

### Tests
* `test_metatypes_to_ignore` for the quantization propagation solver
daniil-lyakhov authored Oct 24, 2023
1 parent 5336405 commit d879361
Showing 16 changed files with 345 additions and 40 deletions.
12 changes: 11 additions & 1 deletion nncf/common/graph/graph.py
@@ -597,12 +597,22 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
attrs_edge = {}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]

if extended:
if edge[NNCFGraph.DTYPE_EDGE_ATTR] is Dtype.INTEGER:
attrs_edge["style"] = "dashed"
else:
attrs_edge["style"] = "solid"
attrs_edge["label"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]
label["shape"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]

if label:
if "shape" in label and len(label) == 1:
attrs_edge["label"] = label["shape"]
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph
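
For reference, a standalone sketch of the edge-label formatting introduced above (the values are hypothetical):

```python
label = {"parallel_input_port_ids": [2, 3], "shape": [1, 3, 224, 224]}
if "shape" in label and len(label) == 1:
    text = label["shape"]
else:
    text = ", ".join(f"{k}:{v}" for k, v in label.items())
print(text)  # parallel_input_port_ids:[2, 3], shape:[1, 3, 224, 224]
```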

5 changes: 4 additions & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
@@ -505,7 +505,10 @@ def _get_quantization_target_points(
hw_patterns = PatternsManager.get_full_hw_pattern_graph(backend=backend, device=device, model_type=model_type)

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.read_variable_metatypes
deepcopy(nncf_graph),
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.read_variable_metatypes,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
19 changes: 13 additions & 6 deletions nncf/quantization/algorithms/min_max/backend.py
@@ -51,23 +51,23 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:

@property
@abstractmethod
def shapeof_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific ShapeOf metatypes.
Property for the backend-specific Convolution metatypes.
"""

@property
@abstractmethod
def conv_metatypes(self) -> List[OperatorMetatype]:
def shapeof_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific Convolution metatypes.
Property for the backend-specific ShapeOf metatypes.
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
def dropout_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes for which overflow_fix is applicable.
Property for the backend-specific Dropout metatypes.
"""

@property
@@ -77,6 +77,13 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific metatypes that also can be interpreted as inputs (ReadValue).
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes for which overflow_fix is applicable.
"""

@property
@abstractmethod
def add_metatypes(self) -> List[OperatorMetatype]:
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -56,10 +56,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXTopKMetatype, om.ONNXNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype]
@@ -68,10 +64,6 @@ def conv_metatypes(self) -> List[OperatorMetatype]:
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype, om.ONNXConvolutionTransposeMetatype, *MATMUL_METATYPES]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXAddLayerMetatype]
@@ -80,6 +72,18 @@ def add_metatypes(self) -> List[OperatorMetatype]:
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return self.conv_metatypes

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -56,10 +56,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.OVTopKMetatype, om.OVNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.OVShapeOfMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype]
@@ -74,10 +70,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.OVMatMulMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.OVAddMetatype]
@@ -86,6 +78,18 @@ def add_metatypes(self) -> List[OperatorMetatype]:
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVGroupConvolutionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.OVShapeOfMetatype]

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
16 changes: 12 additions & 4 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -69,6 +69,14 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return [om.PTDropoutMetatype]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
@@ -85,10 +93,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.PTModuleConvTranspose3dMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.PTAddMetatype]
@@ -307,6 +311,10 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
om.PTDivMetatype,
om.PTMaxMetatype,
om.PTSqueezeMetatype,
om.PTLayerNormMetatype,
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
]
if device != TargetDevice.CPU_SPR:
types.append(om.PTMulMetatype)
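
These ignored metatypes are consulted when post-training quantization runs in transformer mode; a minimal sketch of such a call (the tiny model and calibration dataset are placeholders, not from the PR):

```python
import torch
import nncf

model = torch.nn.Sequential(torch.nn.LayerNorm(8), torch.nn.Linear(8, 8))
calibration_dataset = nncf.Dataset([torch.randn(1, 8) for _ in range(3)])

quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    model_type=nncf.ModelType.TRANSFORMER,  # activates the ignored metatypes above
)
```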
65 changes: 59 additions & 6 deletions nncf/quantization/passes.py
@@ -23,20 +23,23 @@
def transform_to_inference_graph(
nncf_graph: NNCFGraph,
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
This method contains pipeline of the passes that uses to provide inference graph without constant flows.
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:return: NNCFGraph in the inference style.
"""
inference_nncf_graph = remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
inference_nncf_graph = filter_constant_nodes(nncf_graph, read_variable_metatypes)
return inference_nncf_graph
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
filter_constant_nodes(nncf_graph, read_variable_metatypes)
return nncf_graph


def remove_shapeof_subgraphs(
@@ -45,7 +48,7 @@ def remove_shapeof_subgraphs(
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
Removes the ShapeOf subgraphs from the provided NNCFGraph instance.
Removes the ShapeOf subgraphs from the provided NNCFGraph instance inplace.
:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
@@ -88,11 +91,61 @@ def remove_shapeof_subgraphs(
return nncf_graph


def remove_nodes_and_reconnect_graph(
nncf_graph: NNCFGraph,
metatypes: List[OperatorMetatype],
) -> NNCFGraph:
"""
Removes nodes with metatypes specified by `metatypes` parameter from
the provided NNCFGraph instance and connects previous node of a matched node
with next nodes of a matched node inplace for each matched node.
Matched nodes should have only one input node and only one output port.
:param nncf_graph: NNCFGraph instance for the transformation.
:param metatypes: List of backend-specific metatypes.
:return: Resulting NNCFGraph.
"""
if not metatypes:
return nncf_graph

nodes_to_drop = []
for node in nncf_graph.get_nodes_by_metatypes(metatypes):
if node.metatype in metatypes:
nodes_to_drop.append(node)

prev_nodes = nncf_graph.get_previous_nodes(node)
input_edges = nncf_graph.get_input_edges(node)
assert len(prev_nodes) == len(input_edges) == 1
prev_node = prev_nodes[0]
input_edge = input_edges[0]
assert not input_edge.parallel_input_port_ids

# nncf_graph.get_next_edges is not used to preserve
# parallel_input_port_ids
for output_node in nncf_graph.get_next_nodes(node):
output_edge = nncf_graph.get_edge(node, output_node)
# Connects previous node with all next nodes
# to keep NNCFGraph connected.
assert input_edge.dtype == output_edge.dtype
assert input_edge.tensor_shape == output_edge.tensor_shape
nncf_graph.add_edge_between_nncf_nodes(
from_node_id=prev_node.node_id,
to_node_id=output_edge.to_node.node_id,
tensor_shape=input_edge.tensor_shape,
input_port_id=output_edge.input_port_id,
output_port_id=input_edge.output_port_id,
dtype=input_edge.dtype,
parallel_input_port_ids=output_edge.parallel_input_port_ids,
)
nncf_graph.remove_nodes_from(nodes_to_drop)
return nncf_graph
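
An illustrative, plain-Python sketch (not the NNCF API) of the reconnection idea implemented above: every successor of a dropped node is re-attached to that node's single producer.

```python
from typing import Dict, List

def drop_and_reconnect(successors: Dict[str, List[str]], node: str) -> Dict[str, List[str]]:
    """Remove `node` from an adjacency map and attach its successors to its producer."""
    producers = [src for src, dsts in successors.items() if node in dsts]
    assert len(producers) == 1  # matched nodes must have exactly one input
    src = producers[0]
    dropped_outputs = successors.pop(node, [])
    successors[src] = [d for d in successors[src] if d != node] + dropped_outputs
    return successors

graph = {"split": ["dropout"], "dropout": ["matmul_1", "matmul_2"]}
print(drop_and_reconnect(graph, "dropout"))  # {'split': ['matmul_1', 'matmul_2']}
```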


def filter_constant_nodes(
nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None
) -> NNCFGraph:
"""
Removes all Constant nodes from NNCFGraph, making it inference graph.
Removes all Constant nodes from NNCFGraph inplace, making it inference graph.
The traversing starts from the input nodes and nodes with weights.
:param nncf_graph: NNCFGraph instance for the transformation.
11 changes: 9 additions & 2 deletions nncf/torch/quantization/ignored_patterns.py
@@ -69,8 +69,15 @@ def _add_softmax_reshape_matmul(

@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.MULTIHEAD_ATTENTION_OUTPUT)
def create_multihead_attention_output() -> GraphPattern:
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm"]
reshape_squeeze_aliases = ["reshape", "view", "flatten", "squeeze", "unsqueeze", "squeeze", "flatten", "unsqueeze"]
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm", "__matmul__"]
reshape_squeeze_aliases = [
"reshape",
"view",
"flatten",
"unsqueeze",
"squeeze",
"unbind",
]
gather_aliases = ["gather", "index_select", "where", "index_select", "__getitem__"]
transpose_aliases = ["transpose", "permute", "transpose_"]
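
For context, a minimal PyTorch sketch (not taken from the PR) of the timm-style fused-QKV attention these aliases are meant to match: the fused tensor is split with `unbind` and attention scores are computed with the `@` operator, i.e. `__matmul__`:

```python
import torch

x = torch.randn(1, 4, 8)                            # (batch, tokens, dim)
w_qkv = torch.randn(8, 24)                          # fused QKV projection weight
qkv = (x @ w_qkv).reshape(1, 4, 3, 8).permute(2, 0, 1, 3)
q, k, v = qkv.unbind(0)                             # "unbind" alias
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # "@" resolves to __matmul__
out = attn @ v                                      # (1, 4, 8)
```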

@@ -0,0 +1,19 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /Split_1_0" [id=1, type=Split_1];
"5 /Output_1_0" [id=5, type=Output_1];
"6 /Output_2_1_0" [id=6, type=Output_2_1];
"7 /Output_2_2_0" [id=7, type=Output_2_2];
"8 /Output_2_3_0" [id=8, type=Output_2_3];
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "5 /Output_1_0";
"1 /Split_1_0" -> "6 /Output_2_1_0";
"1 /Split_1_0" -> "7 /Output_2_2_0";
"1 /Split_1_0" -> "8 /Output_2_3_0";
"1 /Split_1_0" -> "9 /Output_3_0";
"1 /Split_1_0" -> "10 /Output_2_4_0";
"1 /Split_1_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
}
@@ -0,0 +1,25 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /Split_1_0" [id=1, type=Split_1];
"2 /Dropout_1_0" [id=2, type=Dropout_1];
"3 /Dropout_2_0" [id=3, type=Dropout_2];
"4 /Dropout_3_0" [id=4, type=Dropout_3];
"5 /Output_1_0" [id=5, type=Output_1];
"6 /Output_2_1_0" [id=6, type=Output_2_1];
"7 /Output_2_2_0" [id=7, type=Output_2_2];
"8 /Output_2_3_0" [id=8, type=Output_2_3];
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "2 /Dropout_1_0";
"1 /Split_1_0" -> "3 /Dropout_2_0";
"1 /Split_1_0" -> "4 /Dropout_3_0";
"2 /Dropout_1_0" -> "5 /Output_1_0";
"3 /Dropout_2_0" -> "6 /Output_2_1_0";
"3 /Dropout_2_0" -> "7 /Output_2_2_0";
"3 /Dropout_2_0" -> "8 /Output_2_3_0";
"3 /Dropout_2_0" -> "10 /Output_2_4_0";
"4 /Dropout_3_0" -> "9 /Output_3_0";
"4 /Dropout_3_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
}