Skip to content

Commit

Permalink
Refactor insert null bias to insert bias with value
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 17, 2023
1 parent 12b1e90 commit 8da977c
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 40 deletions.
12 changes: 4 additions & 8 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.openvino.graph.node_utils import get_result_node_name
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
Expand All @@ -50,7 +50,7 @@ def __init__(self, model: TModel):
(OVModelExtractionCommand, self._apply_model_extraction_transformation),
(OVInplaceFnInsertionCommand, self._apply_insert_operation),
(OVOutputInsertionCommand, self._apply_output_insertion_transformations),
(OVNullBiasInsertionCommand, self._apply_bias_insertion_transformations),
(OVBiasInsertionCommand, self._apply_bias_insertion_transformations),
(OVMultiplyInsertionCommand, self._apply_multiply_insertion_transformations),
]

Expand Down Expand Up @@ -462,7 +462,7 @@ def _insert_inplace_operation(

@staticmethod
def _apply_bias_insertion_transformations(
model: ov.Model, transformations: List[OVNullBiasInsertionCommand]
model: ov.Model, transformations: List[OVBiasInsertionCommand]
) -> ov.Model:
"""
Inserts null bias operation after corresponding layer.
Expand All @@ -476,14 +476,10 @@ def _apply_bias_insertion_transformations(
node = name_to_node_mapping[node_name]
# Since layers that may have biases mostly are Convolution or MatMul variations,
# we may use only 0 output port.
node_shape = node.output(0).partial_shape.get_max_shape()
node_output_port = node.output(transformation.target_point.port_id)
node_output_source_ports = node_output_port.get_target_inputs()

bias_shape = [1] * len(node_shape)
bias_shape[1] = node_shape[1]
const_value = np.zeros(bias_shape, dtype=node.get_element_type().to_dtype())
bias_const_node = opset.constant(const_value, dtype=node.get_element_type())
bias_const_node = opset.constant(transformation.bias_value, dtype=transformation.bias_value.dtype)
bias_const_output_port = bias_const_node.output(0)

add_node = opset.add(node_output_port, bias_const_output_port, name=f"{node_name}/nncf_null_bias_")
Expand Down
19 changes: 18 additions & 1 deletion nncf/openvino/graph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from typing import Any

import numpy as np
import openvino.runtime as ov

from nncf.common.factory import ModelTransformerFactory
Expand Down Expand Up @@ -45,11 +47,26 @@ def insert_null_biases(model: ov.Model, graph: NNCFGraph) -> ov.Model:
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)
for node_without_bias in nodes_without_biases:
bias_insertion_command = OVCommandCreator.create_command_to_insert_bias(node_without_bias)
const_value = create_bias_constant_value(node_without_bias, 0)
bias_insertion_command = OVCommandCreator.create_command_to_insert_bias(node_without_bias, const_value)
transformation_layout.register(bias_insertion_command)
return model_transformer.transform(transformation_layout)


def create_bias_constant_value(node_without_bias: ov.Node, value: Any) -> np.ndarray:
"""
Creates bias value constant array filled by given value.
:param node_without_bias: Node to add bias to.
:param value: Value to fill bias constant array.
:return: Bias value constant array filled by given value.
"""
node_shape = node_without_bias.output(0).partial_shape.get_max_shape()
bias_shape = [1] * len(node_shape)
bias_shape[1] = node_shape[1]
return np.full(bias_shape, value, dtype=node_without_bias.get_element_type().to_dtype())


def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model:
"""
This method removes the activation Fake Quantize nodes from the model.
Expand Down
6 changes: 3 additions & 3 deletions nncf/openvino/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from nncf.common.graph.transformations.command_creation import CommandCreator
from nncf.common.graph.transformations.commands import TargetType
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand

Expand Down Expand Up @@ -54,9 +54,9 @@ def create_command_to_update_weight(
return OVWeightUpdateCommand(target_point, weight_value)

@staticmethod
def create_command_to_insert_bias(node_without_bias: NNCFNode) -> OVNullBiasInsertionCommand:
def create_command_to_insert_bias(node_without_bias: NNCFNode, bias_value: np.ndarray) -> OVBiasInsertionCommand:
target_point = OVTargetPoint(TargetType.POST_LAYER_OPERATION, node_without_bias.node_name, 0)
return OVNullBiasInsertionCommand(target_point)
return OVBiasInsertionCommand(target_point, bias_value)

@staticmethod
def multiply_insertion_command(
Expand Down
5 changes: 3 additions & 2 deletions nncf/openvino/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,17 @@ def union(self, other: "Command") -> "Command":
raise NotImplementedError()


class OVNullBiasInsertionCommand(TransformationCommand):
class OVBiasInsertionCommand(TransformationCommand):
"""
Inserts null bias for the corresponding node.
"""

def __init__(self, target_point: OVTargetPoint):
def __init__(self, target_point: OVTargetPoint, bias_value: np.ndarray):
"""
:param target_point: The TargetPoint instance for the insertion that contains layer's information.
"""
super().__init__(TransformationType.INSERT, target_point)
self.bias_value = bias_value

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
Expand Down
66 changes: 40 additions & 26 deletions tests/openvino/native/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
from nncf.openvino.graph.node_utils import get_ov_model_reduce_node_name
from nncf.openvino.graph.node_utils import get_result_node_name
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVNullBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
Expand All @@ -56,9 +56,12 @@
TARGET_WEIGHTS_FQS = [["Add/fq_weights_1"], ["MatMul/fq_weights_1"], ["Add/fq_weights_1", "MatMul/fq_weights_1"]]


def create_transformed_model(model, target_layers, target_type, command_type, port_id=0, **kwargs):
def create_transformed_model(model, target_layers, target_type, command_type, port_id=0, command_kwargs=None):
transformation_layout = TransformationLayout()
for target_layer in target_layers:
command_kwargs = command_kwargs or {}
if isinstance(command_kwargs, dict):
command_kwargs = [command_kwargs] * len(target_layers)
for target_layer, kwargs in zip(target_layers, command_kwargs):
target_point = OVTargetPoint(target_type, target_layer, port_id=port_id)
command = command_type(target_point, **kwargs)
transformation_layout.register(command)
Expand Down Expand Up @@ -204,9 +207,11 @@ def test_inplace_fn_insertion(test_params: InplaceOpTestCase, target_type, targe
target_layers,
target_type,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=test_params.op_builder(test_params.name, test_params.reduce_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": test_params.op_builder(test_params.name, test_params.reduce_shape),
"fn_output_port_id": 0,
},
)

inplace_branches_num = 1
Expand Down Expand Up @@ -255,9 +260,11 @@ def test_split_inplace_fn_insertion(test_params: InplaceOpTestCase):
[target_layer],
TargetType.POST_LAYER_OPERATION,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=test_params.op_builder(test_params.name, test_params.reduce_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": test_params.op_builder(test_params.name, test_params.reduce_shape),
"fn_output_port_id": 0,
},
)

target_node = get_node_by_name(transformed_model, target_layer)
Expand Down Expand Up @@ -301,9 +308,11 @@ def test_inplace_reduce_fn_zero_rank_output(reduction_shape):
[target_layer],
TargetType.OPERATION_WITH_WEIGHTS,
OVInplaceFnInsertionCommand,
port_id=port_id,
inplace_op_fn=get_inplace_min_op(name, reduction_shape=reduction_shape),
fn_output_port_id=0,
port_id,
{
"inplace_op_fn": get_inplace_min_op(name, reduction_shape=reduction_shape),
"fn_output_port_id": 0,
},
)
target_node = get_prev_node(get_node_by_name(transformed_model, target_layer), 1)
check_inplace_op(target_node, ["ReduceMin"], [[]], 1, 0)
Expand Down Expand Up @@ -354,9 +363,7 @@ def test_inplace_mean_per_ch_fn_dynamic_shapes(test_params: InplaceOpTestCase, i
def test_output_insertion(target_type, target_layers):
model = LinearModel().ov_model
port_id = 1 if target_type == TargetType.OPERATION_WITH_WEIGHTS else 0
transformed_model = create_transformed_model(
model, target_layers, target_type, OVOutputInsertionCommand, port_id=port_id
)
transformed_model = create_transformed_model(model, target_layers, target_type, OVOutputInsertionCommand, port_id)

if target_type == TargetType.PRE_LAYER_OPERATION:
target_layers = ["Reshape"]
Expand Down Expand Up @@ -386,7 +393,7 @@ def test_split_output_insertion(test_params: InplaceOpTestCase):
target_layer = "Split"
port_id = 1
transformed_model = create_transformed_model(
model, [target_layer], TargetType.POST_LAYER_OPERATION, OVOutputInsertionCommand, port_id=port_id
model, [target_layer], TargetType.POST_LAYER_OPERATION, OVOutputInsertionCommand, port_id
)

target_node = get_node_by_name(transformed_model, target_layer)
Expand Down Expand Up @@ -431,7 +438,7 @@ def test_fq_insertion_pre_layer(target_layers, ref_fq_names):
target_layers,
TargetType.PRE_LAYER_OPERATION,
OVQuantizerInsertionCommand,
quantizer_parameters=quantizer_parameters,
command_kwargs={"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand All @@ -452,7 +459,7 @@ def test_fq_insertion_post_layer(target_layers, ref_fq_names):
target_layers,
TargetType.POST_LAYER_OPERATION,
OVQuantizerInsertionCommand,
quantizer_parameters=quantizer_parameters,
command_kwargs={"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand All @@ -473,8 +480,8 @@ def test_fq_insertion_weights(target_layers, ref_fq_names):
target_layers,
TargetType.OPERATION_WITH_WEIGHTS,
OVQuantizerInsertionCommand,
port_id=1,
quantizer_parameters=quantizer_parameters,
1,
{"quantizer_parameters": quantizer_parameters},
)
fq_nodes = get_fq_nodes(transformed_model)

Expand Down Expand Up @@ -507,7 +514,7 @@ def test_bias_correction(model_with_parameters):
refs = model_with_parameters["refs"]

transformed_model = create_transformed_model(
model, layers, TargetType.LAYER, OVBiasCorrectionCommand, port_id=1, **{"bias_value": values}
model, layers, TargetType.LAYER, OVBiasCorrectionCommand, 1, {"bias_value": values}
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

Expand Down Expand Up @@ -545,8 +552,8 @@ def infer_model_with_ones(model, shape):


MODELS_WITH_PARAMETERS = [
{"model": ConvNotBiasModel().ov_model, "layers": ["Conv"]},
{"model": WeightsModel().ov_model, "layers": ["Conv", "Conv_backprop"]},
{"model": ConvNotBiasModel().ov_model, "layers": [["Conv"], [(1, 3, 1, 1)]]},
{"model": WeightsModel().ov_model, "layers": [["Conv", "Conv_backprop"], [(1, 3, 1, 1), (1, 3, 1, 1)]]},
]


Expand All @@ -555,10 +562,17 @@ def test_null_biases_insertion(model_with_parameters):
model = model_with_parameters["model"]
layers = model_with_parameters["layers"]

transformed_model = create_transformed_model(model, layers, TargetType.LAYER, OVNullBiasInsertionCommand, port_id=0)
transformed_model = create_transformed_model(
model,
layers[0],
TargetType.LAYER,
OVBiasInsertionCommand,
port_id=0,
command_kwargs=[{"bias_value": np.zeros(shape, dtype=np.float32)} for shape in layers[1]],
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

for layer_name in layers:
for layer_name in layers[0]:
node = ops_dict[layer_name]
layer_shape = ops_dict[layer_name].shape
bias_dtype = node.get_element_type().to_dtype()
Expand Down Expand Up @@ -626,7 +640,7 @@ def test_multiply_insertion(model_with_parameters):
TargetType.POST_LAYER_OPERATION,
OVMultiplyInsertionCommand,
port_id=output_port_id,
**{"scale_value": scale, "destination_node_names": dest_nodes},
command_kwargs={"scale_value": scale, "destination_node_names": dest_nodes},
)
ops_dict = {op.get_friendly_name(): op for op in transformed_model.get_ops()}

Expand Down
37 changes: 37 additions & 0 deletions tests/openvino/native/test_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from openvino.runtime import opset9 as opset

from nncf.openvino.graph.model_utils import create_bias_constant_value


def get_conv_node(input_shape, dtype):
input_node = opset.parameter(input_shape, dtype=dtype)
strides = [1, 1]
pads = [0, 0]
dilations = [1, 1]
return opset.convolution(
input_node, np.zeros((4, input_shape[1], 1, 1), dtype=dtype), strides, pads, pads, dilations
)


@pytest.mark.parametrize(
"input_shape,dtype",
[((2, 3, 4, 5), np.float32), ((1, 1, 1, 1), np.float64)],
)
def test_create_bias_constant_value(input_shape, dtype):
conv = get_conv_node(input_shape, dtype)
val = create_bias_constant_value(conv, 5)
assert val.shape == (1, 4, 1, 1)
assert np.equal(val, np.full((1, 4, 1, 1), 5)).all()

0 comments on commit 8da977c

Please sign in to comment.