Skip to content

Commit

Permalink
CA: insert bias only when it is needed
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 18, 2023
1 parent 0903907 commit 5ea863a
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 56 deletions.
12 changes: 12 additions & 0 deletions nncf/common/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def create_command_to_update_bias(
:return: The command to update bias value.
"""

@staticmethod
@abstractmethod
def create_command_to_insert_bias(node_without_bias: NNCFNode, bias_value: Any) -> TransformationCommand:
    """
    Creates command to insert bias after given node.
    :param node_without_bias: The node that corresponds to the operation without bias.
    :param bias_value: Bias value to insert.
    :return: The command to insert bias value.
    """

@staticmethod
@abstractmethod
def create_command_to_update_weight(
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class AdvancedQuantizationParameters:
overflow_fix: OverflowFix = OverflowFix.FIRST_LAYER
quantize_outputs: bool = False
inplace_statistics: bool = True
disable_channel_alignment: bool = True
disable_channel_alignment: bool = False
disable_bias_correction: bool = False
smooth_quant_alpha: float = 0.95

Expand Down
46 changes: 25 additions & 21 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.openvino.graph.model_utils import create_bias_constant_value
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
Expand Down Expand Up @@ -120,15 +121,14 @@ def filter_func(point: StatisticPoint) -> bool:
conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
conv_out_cont = ConvParamsContainer(conv_out, model, graph, self._backend_entity)

if conv_in_cont.has_bias() and conv_out_cont.has_bias():
amean = (stat.max_values + stat.min_values) * 0.5
conv_in_cont.bias, conv_out_cont.bias = self._align_means(
conv_in_cont.bias,
conv_out_cont.bias,
conv_out_cont.weight,
amean,
conv_out_cont.dims,
)
amean = (stat.max_values + stat.min_values) * 0.5
conv_in_cont.bias, conv_out_cont.bias = self._align_means(
conv_in_cont.bias,
conv_out_cont.bias,
conv_out_cont.weight,
amean,
conv_out_cont.dims,
)

ascale = (stat.max_values - stat.min_values).astype(np.float32)
eps = np.finfo(ascale.dtype).eps
Expand All @@ -153,9 +153,11 @@ def filter_func(point: StatisticPoint) -> bool:
)

if container.stated_bias.is_modified():
transformation_layout.register(
command_creator.create_command_to_update_bias(container.op, container.bias, graph),
)
if container.bias_op_exist():
command = command_creator.create_command_to_update_bias(container.op, container.bias, graph)
else:
command = command_creator.create_command_to_insert_bias(container.op, container.bias)
transformation_layout.register(command)

transformed_model = model_transformer.transform(transformation_layout)
return transformed_model
Expand Down Expand Up @@ -239,10 +241,7 @@ def _align_scales(
scale_in_shape[conv_in_descr.conv_weight_out_channels_dim] = scale_factor.shape[conv_in_descr.bias_channels_dim]
updated_conv_in_value = conv_in_value / scale_factor.reshape(scale_in_shape)

if bias_in_value is not None:
updated_bias_in_value = bias_in_value / scale_factor.reshape(bias_in_value.shape)
else:
updated_bias_in_value = None
updated_bias_in_value = bias_in_value / scale_factor.reshape(bias_in_value.shape)

scale_out_shape = np.ones(len(conv_out_value.shape), dtype=int)
scale_out_shape[conv_out_descr.conv_weight_in_channels_dim] = scale_factor.shape[
Expand Down Expand Up @@ -433,18 +432,23 @@ class ConvParamsContainer:
Convolution container class that encapsulates the collection of common convolution parameters.
"""

def __init__(self, conv_op, model, nncf_graph, backend_entity: ChannelAlignmentAlgoBackend):
def __init__(
self, conv_op: NNCFNode, model: TModel, nncf_graph: NNCFGraph, backend_entity: ChannelAlignmentAlgoBackend
):
"""
:param conv_op: Backend-specific conv node.
:param conv_op: NNCF conv node.
:param model: Backend-specific model instance.
:param nncf_graph: NNCFGraph of given backend-specific model.
:param backend_entity: Current backend entity to retrieve parameters from given conv node
"""
_, self._weights_port_id = backend_entity.get_weights_port_ids_for_node(conv_op)
self.stated_weight = StatedTensor(backend_entity.get_weight_value(conv_op, model, self._weights_port_id))
bias = None
self._bias_op_exist = False
if backend_entity.is_node_with_bias(conv_op, nncf_graph):
bias = backend_entity.get_bias_value(conv_op, model, nncf_graph)
self._bias_op_exist = True
else:
bias = create_bias_constant_value(conv_op, nncf_graph, 0)
self.stated_bias = StatedTensor(bias)
self._op = conv_op
self._dims = backend_entity.get_dims_descriptor(conv_op)
Expand Down Expand Up @@ -477,5 +481,5 @@ def weight_port_id(self):
def dims(self) -> LayoutDescriptor:
return self._dims

def has_bias(self) -> bool:
return self.bias is not None
def bias_op_exist(self) -> bool:
    """
    :return: True if the conv node was reported by the backend as a node with bias when this
        container was created, False if the container was initialized with a generated
        constant bias value instead (no bias operation was found for the node).
    """
    return self._bias_op_exist
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
inplace_statistics=advanced_parameters.inplace_statistics,
backend_params=advanced_parameters.backend_params,
)
self.first_stage_algorithms.append(self.FirstStageAlgorithm(channel_alignment, [insert_null_biases_pass]))
self.first_stage_algorithms.append(self.FirstStageAlgorithm(channel_alignment, []))

min_max_quantization = MinMaxQuantization(
preset=preset,
Expand Down
3 changes: 2 additions & 1 deletion tests/openvino/native/quantization/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor
Expand Down Expand Up @@ -64,7 +65,7 @@ def get_constant_metatype(self):
return OVConstantMetatype

def get_transformation_commands(self):
return OVBiasCorrectionCommand, OVWeightUpdateCommand
return OVBiasInsertionCommand, OVBiasCorrectionCommand, OVWeightUpdateCommand

def mock_command_creation_factory(self, mocker) -> None:
mocker.patch("nncf.common.factory.CommandCreatorFactory.create", return_value=OVCommandCreator)
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/tools/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def calculate_per_sample_subset_size(self, sequence_subset_size):


def quantize_model(xml_path, bin_path, accuracy_checker_config, quantization_impl, quantization_parameters):
quantization_parameters["subset_size"] = 1
ov_model = ov.Core().read_model(model=xml_path, weights=bin_path)
model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network([{"model": ov_model}])
Expand Down
61 changes: 29 additions & 32 deletions tests/post_training/test_templates/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,15 @@ def test_align_means(self, conv_out_value, refs, transposed):
REF_UPDATED_CONV_OUT = np.array([[0.0, 2.0, 0.04, 600, 8], [10, 12, 0.14, 1600, 18]])
REF_UPDATED_BIAS_IN = np.array([2, 4, 600, 0.08, 10])

@pytest.mark.parametrize("bias_in_value", [np.array([2, 4, 6, 8, 10]), None])
@pytest.mark.parametrize("bias_in_value", [np.array([2, 4, 6, 8, 10])])
def test_align_scales(self, bias_in_value):
def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in):
assert updated_conv_in.shape == self.REF_UPDATED_CONV_IN.shape
assert np.allclose(updated_conv_in, self.REF_UPDATED_CONV_IN)
assert updated_conv_out.shape == self.REF_UPDATED_CONV_OUT.shape
assert np.allclose(updated_conv_out, self.REF_UPDATED_CONV_OUT)
if bias_in_value is None:
assert updated_bias_in is None
else:
assert updated_bias_in.shape == self.REF_UPDATED_BIAS_IN.shape
assert np.allclose(updated_bias_in, self.REF_UPDATED_BIAS_IN)
assert updated_bias_in.shape == self.REF_UPDATED_BIAS_IN.shape
assert np.allclose(updated_bias_in, self.REF_UPDATED_BIAS_IN)

conv_in_value = np.arange(5).reshape(5, 1)
conv_out_value = np.arange(10).reshape(2, 5) * 2
Expand Down Expand Up @@ -225,7 +222,6 @@ def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in):
updated_conv_in = updated_conv_in.reshape(updated_conv_in.shape[1:])
check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in)

GET_NODES_TEST_CASES = []
GET_NODES_TEST_CASES = [(VALID_CONV_LAYER_ATTR, VALID_CONV_LAYER_ATTR, True)]
GET_NODES_TEST_CASES.extend([(attr, VALID_CONV_LAYER_ATTR, True) for attr in INVALID_CONSUMER_CONV_LAYER_ATTRS])
GET_NODES_TEST_CASES.extend([(VALID_CONV_LAYER_ATTR, attr, False) for attr in INVALID_CONSUMER_CONV_LAYER_ATTRS])
Expand Down Expand Up @@ -370,27 +366,24 @@ class MockBackend(backend_cls):
assert len(arg.transformations) == 0
return

align_means_called = 1 if num_biases == 2 else 0
assert algorithm._align_means.call_count == align_means_called
if align_means_called:
algorithm._align_means.assert_called_once_with(
ref_bias_val + "1",
ref_bias_val + "2",
ref_weights_val + "2",
np.array(0.5, dtype=np.float32),
ref_dims_descr + "2",
)
assert algorithm._align_means.call_count == 1
args = [
np.zeros((1, 1, 1, 1)),
np.zeros((1, 1, 1, 1)),
ref_weights_val + "2",
np.array(0.5, dtype=np.float32),
ref_dims_descr + "2",
]
for i in range(num_biases):
args[i] = f"ref_bias_val{i + 1}"

algorithm._align_means.assert_called_once_with(*args)

assert algorithm._align_scales.call_count == 1
args = algorithm._align_scales.call_args.args
assert args[0] == ref_weights_val + "1"
assert args[1] == ref_weights_val + "2"
if num_biases == 2:
assert args[2] == ref_bias_in_after_align
elif num_biases == 1:
assert args[2] == ref_bias_val + "1"
else:
assert args[2] is None
assert args[2] == ref_bias_in_after_align
assert ((args[3] - 3) < EPS).all()
assert args[4] == ref_dims_descr + "1"
assert args[5] == ref_dims_descr + "2"
Expand All @@ -408,14 +401,20 @@ class MockBackend(backend_cls):
},
"/Conv_2_0": {"weight_value": ref_weights_out_after_scale_align, "bias_value": ref_bias_out_after_align},
}
bias_update_cls, weights_update_cls = self.get_transformation_commands()

bias_insert_cls, bias_update_cls, weights_update_cls = self.get_transformation_commands()
for transformation in transformations:
assert transformation.type == TransformationType.CHANGE
tp = transformation.target_point
if isinstance(transformation, bias_update_cls):
assert transformation.type == TransformationType.CHANGE
_class = bias_update_cls
_attr = "bias_value"
elif isinstance(transformation, bias_insert_cls):
assert transformation.type == TransformationType.INSERT
_class = bias_insert_cls
_attr = "bias_value"
elif isinstance(transformation, weights_update_cls):
assert transformation.type == TransformationType.CHANGE
_class = weights_update_cls
_attr = "weight_value"
else:
Expand All @@ -425,18 +424,16 @@ class MockBackend(backend_cls):
assert ref_values[tp.target_node_name][_attr] == getattr(transformation, _attr)

if num_biases == 2:
ref_len = {"/Conv_1_0": 2, "/Conv_2_0": 2}
ref_classes = {"/Conv_1_0": bias_update_cls, "/Conv_2_0": bias_update_cls}
elif num_biases == 1:
ref_len = {"/Conv_1_0": 2, "/Conv_2_0": 1}
ref_classes = {"/Conv_1_0": bias_update_cls, "/Conv_2_0": bias_insert_cls}
else:
ref_len = {"/Conv_1_0": 1, "/Conv_2_0": 1}
ref_classes = {"/Conv_1_0": bias_insert_cls, "/Conv_2_0": bias_insert_cls}

for node_name, _transformations in target_names.items():
_ref_len = ref_len[node_name]
assert len(_transformations) == _ref_len
assert len(_transformations) == 2
assert weights_update_cls in _transformations
if _ref_len == 2:
assert bias_update_cls in _transformations
assert ref_classes[node_name] in _transformations

@pytest.mark.parametrize("num_biases", [0, 1, 2])
def test_get_statistic_points(self, num_biases, mocker):
Expand Down

0 comments on commit 5ea863a

Please sign in to comment.