Skip to content

Commit

Permalink
[PTQ][Micro refactoring] Refactor insert null bias to insert bias wit…
Browse files Browse the repository at this point in the history
…h value (openvinotoolkit#2056)

### Changes

Refactor insert null bias to insert bias with value
### Reason for changes

To make it possible to insert bias with a value. This is needed for the
ChannelAlignment algorithm

### Related tickets

114583

### Tests
* tests/openvino/native/test_model_transformer.py updated
* tests/openvino/native/test_model_utils.py presented
  • Loading branch information
daniil-lyakhov authored and andreyanufr committed Sep 12, 2023
1 parent 45947aa commit 74655f5
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 43 deletions.
14 changes: 5 additions & 9 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.openvino.graph.node_utils import get_result_node_name
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
Expand All @@ -50,7 +50,7 @@ def __init__(self, model: TModel):
(OVModelExtractionCommand, self._apply_model_extraction_transformation),
(OVInplaceFnInsertionCommand, self._apply_insert_operation),
(OVOutputInsertionCommand, self._apply_output_insertion_transformations),
(OVNullBiasInsertionCommand, self._apply_bias_insertion_transformations),
(OVBiasInsertionCommand, self._apply_bias_insertion_transformations),
(OVMultiplyInsertionCommand, self._apply_multiply_insertion_transformations),
]

Expand Down Expand Up @@ -462,10 +462,10 @@ def _insert_inplace_operation(

@staticmethod
def _apply_bias_insertion_transformations(
model: ov.Model, transformations: List[OVNullBiasInsertionCommand]
model: ov.Model, transformations: List[OVBiasInsertionCommand]
) -> ov.Model:
"""
Inserts null bias operation after corresponding layer.
Inserts bias operation after corresponding layer.
:param transformations: List of the bias insertion transformations.
:returns: Transformed model with null biases.
Expand All @@ -476,14 +476,10 @@ def _apply_bias_insertion_transformations(
node = name_to_node_mapping[node_name]
# Since layers that may have biases mostly are Convolution or MatMul variations,
# we may use only 0 output port.
node_shape = node.output(0).partial_shape.get_max_shape()
node_output_port = node.output(transformation.target_point.port_id)
node_output_source_ports = node_output_port.get_target_inputs()

bias_shape = [1] * len(node_shape)
bias_shape[1] = node_shape[1]
const_value = np.zeros(bias_shape, dtype=node.get_element_type().to_dtype())
bias_const_node = opset.constant(const_value, dtype=node.get_element_type())
bias_const_node = opset.constant(transformation.bias_value, dtype=node.get_element_type().to_dtype())
bias_const_output_port = bias_const_node.output(0)

add_node = opset.add(node_output_port, bias_const_output_port, name=f"{node_name}/nncf_null_bias_")
Expand Down
4 changes: 3 additions & 1 deletion nncf/openvino/graph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionBackpropDataMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype
from nncf.openvino.graph.node_utils import create_bias_tensor
from nncf.openvino.graph.node_utils import is_node_with_bias
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator

Expand All @@ -45,7 +46,8 @@ def insert_null_biases(model: ov.Model, graph: NNCFGraph) -> ov.Model:
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)
for node_without_bias in nodes_without_biases:
bias_insertion_command = OVCommandCreator.create_command_to_insert_bias(node_without_bias)
const_value = create_bias_tensor(node_without_bias, graph, 0)
bias_insertion_command = OVCommandCreator.create_command_to_insert_bias(node_without_bias, const_value)
transformation_layout.register(bias_insertion_command)
return model_transformer.transform(transformation_layout)

Expand Down
18 changes: 17 additions & 1 deletion nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Tuple, Type
from typing import Any, Callable, List, Optional, Tuple, Type

import numpy as np
import openvino.runtime as ov
Expand Down Expand Up @@ -364,3 +364,19 @@ def get_channel_agnostic_reduction_shape(channel_axes: List[int], shape: List[in
for channel_axis in sorted(channel_axes, reverse=True):
del reduction_shape[channel_axis]
return tuple(reduction_shape)


def create_bias_tensor(node_without_bias: NNCFNode, graph: NNCFGraph, value: Any) -> np.ndarray:
"""
Creates bias value constant array filled by given value.
:param node_without_bias: NNCFNode to add bias to.
:param graph: Target NNCFgraph.
:param value: Value to fill bias constant array.
:return: Bias value constant array filled by given value.
"""
node_shape = graph.get_output_edges(node_without_bias)[0].tensor_shape
bias_shape = [1] * len(node_shape)
channel_axis = node_without_bias.metatype.output_channel_axis
bias_shape[channel_axis] = node_shape[1]
return np.full(bias_shape, value)
6 changes: 3 additions & 3 deletions nncf/openvino/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from nncf.common.graph.transformations.command_creation import CommandCreator
from nncf.common.graph.transformations.commands import TargetType
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand

Expand Down Expand Up @@ -54,9 +54,9 @@ def create_command_to_update_weight(
return OVWeightUpdateCommand(target_point, weight_value)

@staticmethod
def create_command_to_insert_bias(node_without_bias: NNCFNode) -> OVNullBiasInsertionCommand:
def create_command_to_insert_bias(node_without_bias: NNCFNode, bias_value: np.ndarray) -> OVBiasInsertionCommand:
target_point = OVTargetPoint(TargetType.POST_LAYER_OPERATION, node_without_bias.node_name, 0)
return OVNullBiasInsertionCommand(target_point)
return OVBiasInsertionCommand(target_point, bias_value)

@staticmethod
def multiply_insertion_command(
Expand Down
8 changes: 5 additions & 3 deletions nncf/openvino/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,18 @@ def union(self, other: "Command") -> "Command":
raise NotImplementedError()


class OVNullBiasInsertionCommand(TransformationCommand):
class OVBiasInsertionCommand(TransformationCommand):
"""
Inserts null bias for the corresponding node.
Inserts bias for the corresponding node.
"""

def __init__(self, target_point: OVTargetPoint):
def __init__(self, target_point: OVTargetPoint, bias_value: np.ndarray):
"""
:param target_point: The TargetPoint instance for the insertion that contains layer's information.
:param bias_value: Constant value for the bias layer.
"""
super().__init__(TransformationType.INSERT, target_point)
self.bias_value = bias_value

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
Expand Down
66 changes: 40 additions & 26 deletions tests/openvino/native/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
from nncf.openvino.graph.node_utils import get_ov_model_reduce_node_name
from nncf.openvino.graph.node_utils import get_result_node_name
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
Expand All @@ -56,9 +56,12 @@
TARGET_WEIGHTS_FQS = [["Add/fq_weights_1"], ["MatMul/fq_weights_1"], ["Add/fq_weights_1", "MatMul/fq_weights_1"]]


def create_transformed_model(model, target_layers, target_type, command_type, port_id=0, **kwargs):
def create_transformed_model(model, target_layers, target_type, command_type, port_id=0, command_kwargs=None):
transformation_layout = TransformationLayout()
for target_layer in target_layers:
command_kwargs = command_kwargs or {}
if isinstance(command_kwargs, dict):
command_kwargs = [command_kwargs] * len(target_layers)
for target_layer, kwargs in zip(target_layers, command_kwargs):
target_point = OVTargetPoint(target_type, target_layer, port_id=port_id)
command = command_type(target_point, **kwargs)
transformation_layout.register(command)
Expand Down Expand Up @@ -204,9 +207,11 @@ def test_inplace_fn_insertion(test_params: InplaceOpTestCase, target_type, targe
target_layers,
target_type,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=test_params.op_builder(test_params.name, test_params.reduce_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": test_params.op_builder(test_params.name, test_params.reduce_shape),
"fn_output_port_id": 0,
},
)

inplace_branches_num = 1
Expand Down Expand Up @@ -255,9 +260,11 @@ def test_split_inplace_fn_insertion(test_params: InplaceOpTestCase):
[target_layer],
TargetType.POST_LAYER_OPERATION,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=test_params.op_builder(test_params.name, test_params.reduce_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": test_params.op_builder(test_params.name, test_params.reduce_shape),
"fn_output_port_id": 0,
},
)

target_node = get_node_by_name(transformed_model, target_layer)
Expand Down Expand Up @@ -301,9 +308,11 @@ def test_inplace_reduce_fn_zero_rank_output(reduction_shape):
[target_layer],
TargetType.OPERATION_WITH_WEIGHTS,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=get_inplace_min_op(name, reduction_shape=reduction_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": get_inplace_min_op(name, reduction_shape=reduction_shape),
"fn_output_port_id": 0,
},
)
target_node = get_prev_node(get_node_by_name(transformed_model, target_layer), 1)
check_inplace_op(target_node, ["ReduceMin"], [[]], 1, 0)
Expand Down Expand Up @@ -354,9 +363,7 @@ def test_inplace_mean_per_ch_fn_dynamic_shapes(test_params: InplaceOpTestCase, i
def test_output_insertion(target_type, target_layers):
model = LinearModel().ov_model
port_id = 1 if target_type == TargetType.OPERATION_WITH_WEIGHTS else 0
transformed_model = create_transformed_model(
model, target_layers, target_type, OVOutputInsertionCommand, port_id=port_id
)
transformed_model = create_transformed_model(model, target_layers, target_type, OVOutputInsertionCommand, port_id)

if target_type == TargetType.PRE_LAYER_OPERATION:
target_layers = ["Reshape"]
Expand Down Expand Up @@ -386,7 +393,7 @@ def test_split_output_insertion(test_params: InplaceOpTestCase):
target_layer = "Split"
port_id = 1
transformed_model = create_transformed_model(
model, [target_layer], TargetType.POST_LAYER_OPERATION, OVOutputInsertionCommand, port_id=port_id
model, [target_layer], TargetType.POST_LAYER_OPERATION, OVOutputInsertionCommand, port_id
)

target_node = get_node_by_name(transformed_model, target_layer)
Expand Down Expand Up @@ -431,7 +438,7 @@ def test_fq_insertion_pre_layer(target_layers, ref_fq_names):
target_layers,
TargetType.PRE_LAYER_OPERATION,
OVQuantizerInsertionCommand,
quantizer_parameters=quantizer_parameters,
command_kwargs={"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand All @@ -452,7 +459,7 @@ def test_fq_insertion_post_layer(target_layers, ref_fq_names):
target_layers,
TargetType.POST_LAYER_OPERATION,
OVQuantizerInsertionCommand,
quantizer_parameters=quantizer_parameters,
command_kwargs={"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand All @@ -473,8 +480,8 @@ def test_fq_insertion_weights(target_layers, ref_fq_names):
target_layers,
TargetType.OPERATION_WITH_WEIGHTS,
OVQuantizerInsertionCommand,
port_id=1,
quantizer_parameters=quantizer_parameters,
1,
{"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand Down Expand Up @@ -507,7 +514,7 @@ def test_bias_correction(model_with_parameters):
refs = model_with_parameters["refs"]

transformed_model = create_transformed_model(
model, layers, TargetType.LAYER, OVBiasCorrectionCommand, port_id=1, **{"bias_value": values}
model, layers, TargetType.LAYER, OVBiasCorrectionCommand, 1, {"bias_value": values}
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

Expand Down Expand Up @@ -545,8 +552,8 @@ def infer_model_with_ones(model, shape):


MODELS_WITH_PARAMETERS = [
{"model": ConvNotBiasModel().ov_model, "layers": ["Conv"]},
{"model": WeightsModel().ov_model, "layers": ["Conv", "Conv_backprop"]},
{"model": ConvNotBiasModel().ov_model, "layers": [["Conv"], [(1, 3, 1, 1)]]},
{"model": WeightsModel().ov_model, "layers": [["Conv", "Conv_backprop"], [(1, 3, 1, 1), (1, 3, 1, 1)]]},
]


Expand All @@ -555,10 +562,17 @@ def test_null_biases_insertion(model_with_parameters):
model = model_with_parameters["model"]
layers = model_with_parameters["layers"]

transformed_model = create_transformed_model(model, layers, TargetType.LAYER, OVNullBiasInsertionCommand, port_id=0)
transformed_model = create_transformed_model(
model,
layers[0],
TargetType.LAYER,
OVBiasInsertionCommand,
port_id=0,
command_kwargs=[{"bias_value": np.zeros(shape, dtype=np.int8)} for shape in layers[1]],
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

for layer_name in layers:
for layer_name in layers[0]:
node = ops_dict[layer_name]
layer_shape = ops_dict[layer_name].shape
bias_dtype = node.get_element_type().to_dtype()
Expand Down Expand Up @@ -626,7 +640,7 @@ def test_multiply_insertion(model_with_parameters):
TargetType.POST_LAYER_OPERATION,
OVMultiplyInsertionCommand,
port_id=output_port_id,
**{"scale_value": scale, "destination_node_names": dest_nodes, "multiply_node_name": "test_name"},
command_kwargs={"scale_value": scale, "destination_node_names": dest_nodes, "multiply_node_name": "test_name"},
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

Expand Down
49 changes: 49 additions & 0 deletions tests/openvino/native/test_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.node_utils import create_bias_tensor
from tests.common.quantization.mock_graphs import NodeWithType
from tests.common.quantization.mock_graphs import create_mock_graph
from tests.common.quantization.mock_graphs import get_nncf_graph_from_mock_nx_graph

# pylint:disable=protected-access


def get_nncf_graph_for_test(edge_shape, dtype):
nodes = [
NodeWithType("Input_1", None),
NodeWithType("Conv_1", OVConvolutionMetatype),
NodeWithType("Output_1", None),
]
node_edges = [
("Input_1", "Conv_1"),
("Conv_1", "Output_1"),
]
original_mock_graph = create_mock_graph(nodes, node_edges)
nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph)
nncf_graph._nx_graph.out_edges[("1 /Conv_1_0", "2 /Output_1_0")][nncf_graph.ACTIVATION_SHAPE_EDGE_ATTR] = edge_shape
nncf_graph._nx_graph.out_edges[("1 /Conv_1_0", "2 /Output_1_0")][nncf_graph.DTYPE_EDGE_ATTR] = dtype
return nncf_graph


@pytest.mark.parametrize(
"edge_shape,dtype,ref_shape",
[((2, 3, 4, 5), np.float32, (1, 3, 1, 1)), ((1, 1, 2, 3), np.float64, (1, 1, 1, 1))],
)
def test_create_bias_constant_value(edge_shape, dtype, ref_shape):
graph = get_nncf_graph_for_test(edge_shape, dtype)
val = create_bias_tensor(graph.get_node_by_name("/Conv_1_0"), graph, 5)
assert val.shape == ref_shape
assert np.equal(val, np.full(ref_shape, 5)).all()

0 comments on commit 74655f5

Please sign in to comment.