Skip to content

Commit

Permalink
Enable ChannelAlignment algorithm by default
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 25, 2023
1 parent 61538a2 commit 1eb303f
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 76 deletions.
12 changes: 12 additions & 0 deletions nncf/common/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def create_command_to_update_bias(
:return: The command to update bias value.
"""

@staticmethod
@abstractmethod
def create_command_to_insert_bias(node_without_bias: NNCFNode, bias_value: Any) -> TransformationCommand:
"""
Creates command to insert bias after given node.
:param node_without_bias: The node that corresponds to the operation without bias.
:param bias_value: Bias value to insert.
:param nncf_graph: The NNCF graph.
:return: The command to insert bias value.
"""

@staticmethod
@abstractmethod
def create_command_to_update_weight(
Expand Down
18 changes: 18 additions & 0 deletions nncf/openvino/graph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from typing import Any

import numpy as np
import openvino.runtime as ov

from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.openvino.graph.metatypes.common import FAKE_QUANTIZE_OPERATIONS
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype
Expand Down Expand Up @@ -52,6 +55,21 @@ def insert_null_biases(model: ov.Model, graph: NNCFGraph) -> ov.Model:
return model_transformer.transform(transformation_layout)


def create_bias_constant_value(node_without_bias: NNCFNode, graph: NNCFGraph, value: Any) -> np.ndarray:
"""
Creates bias value constant array filled by given value.
:param node_without_bias: NNCFNode to add bias to.
:param graph: Target NNCFgraph.
:param value: Value to fill bias constant array.
:return: Bias value constant array filled by given value.
"""
node_shape = graph.get_output_edges(node_without_bias)[0].tensor_shape
bias_shape = [1] * len(node_shape)
bias_shape[1] = node_shape[1]
return np.full(bias_shape, value)


def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model:
"""
This method removes the activation Fake Quantize nodes from the model.
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class AdvancedQuantizationParameters:
overflow_fix: OverflowFix = OverflowFix.FIRST_LAYER
quantize_outputs: bool = False
inplace_statistics: bool = True
disable_channel_alignment: bool = True
disable_channel_alignment: bool = False
disable_bias_correction: bool = False
smooth_quant_alpha: float = 0.95

Expand Down
49 changes: 25 additions & 24 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.openvino.graph.model_utils import create_bias_constant_value
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
Expand Down Expand Up @@ -59,20 +60,17 @@ def __init__(
self,
subset_size: int = 100,
inplace_statistics: bool = True,
backend_params: Optional[Dict[str, Any]] = None,
):
"""
:param subset_size: Size of a subset for the statistics collection,
defaults to 100.
:param inplace_statistics: Defines wheather to calculate quantizers statistics
by backend graph operations or by default Python implementation, defaults
to True.
:param backend_params: Backend specific parameters.
"""
super().__init__()
self.subset_size = subset_size
self.inplace_statistics = inplace_statistics
self.backend_params = backend_params
self._backend_entity = None
self._quantile = 1e-4
self._algorithm_key = f"CA_{hash(self)}"
Expand Down Expand Up @@ -120,15 +118,14 @@ def filter_func(point: StatisticPoint) -> bool:
conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
conv_out_cont = ConvParamsContainer(conv_out, model, graph, self._backend_entity)

if conv_in_cont.has_bias() and conv_out_cont.has_bias():
amean = (stat.max_values + stat.min_values) * 0.5
conv_in_cont.bias, conv_out_cont.bias = self._align_means(
conv_in_cont.bias,
conv_out_cont.bias,
conv_out_cont.weight,
amean,
conv_out_cont.dims,
)
amean = (stat.max_values + stat.min_values) * 0.5
conv_in_cont.bias, conv_out_cont.bias = self._align_means(
conv_in_cont.bias,
conv_out_cont.bias,
conv_out_cont.weight,
amean,
conv_out_cont.dims,
)

ascale = (stat.max_values - stat.min_values).astype(np.float32)
eps = np.finfo(ascale.dtype).eps
Expand All @@ -153,9 +150,11 @@ def filter_func(point: StatisticPoint) -> bool:
)

if container.stated_bias.is_modified():
transformation_layout.register(
command_creator.create_command_to_update_bias(container.op, container.bias, graph),
)
if container.bias_op_exist():
command = command_creator.create_command_to_update_bias(container.op, container.bias, graph)
else:
command = command_creator.create_command_to_insert_bias(container.op, container.bias)
transformation_layout.register(command)

transformed_model = model_transformer.transform(transformation_layout)
return transformed_model
Expand Down Expand Up @@ -239,10 +238,7 @@ def _align_scales(
scale_in_shape[conv_in_descr.conv_weight_out_channels_dim] = scale_factor.shape[conv_in_descr.bias_channels_dim]
updated_conv_in_value = conv_in_value / scale_factor.reshape(scale_in_shape)

if bias_in_value is not None:
updated_bias_in_value = bias_in_value / scale_factor.reshape(bias_in_value.shape)
else:
updated_bias_in_value = None
updated_bias_in_value = bias_in_value / scale_factor.reshape(bias_in_value.shape)

scale_out_shape = np.ones(len(conv_out_value.shape), dtype=int)
scale_out_shape[conv_out_descr.conv_weight_in_channels_dim] = scale_factor.shape[
Expand Down Expand Up @@ -433,18 +429,23 @@ class ConvParamsContainer:
Convolution container class which is incapsulating common convolutional parameters collection.
"""

def __init__(self, conv_op, model, nncf_graph, backend_entity: ChannelAlignmentAlgoBackend):
def __init__(
self, conv_op: NNCFNode, model: TModel, nncf_graph: NNCFGraph, backend_entity: ChannelAlignmentAlgoBackend
):
"""
:param conv_op: Backend-specific conv node.
:param conv_op: NNCF conv node.
:param model: Backend-specific model instance.
:param nncf_graph: NNCFGraph of given backend-specific model.
:param backend_entity: Current backend entity to retrieve parameters from given conv node
"""
_, self._weights_port_id = backend_entity.get_weights_port_ids_for_node(conv_op)
self.stated_weight = StatedTensor(backend_entity.get_weight_value(conv_op, model, self._weights_port_id))
bias = None
self._bias_op_exist = False
if backend_entity.is_node_with_bias(conv_op, nncf_graph):
bias = backend_entity.get_bias_value(conv_op, model, nncf_graph)
self._bias_op_exist = True
else:
bias = create_bias_constant_value(conv_op, nncf_graph, 0)
self.stated_bias = StatedTensor(bias)
self._op = conv_op
self._dims = backend_entity.get_dims_descriptor(conv_op)
Expand Down Expand Up @@ -477,5 +478,5 @@ def weight_port_id(self):
def dims(self) -> LayoutDescriptor:
return self._dims

def has_bias(self) -> bool:
return self.bias is not None
def bias_op_exist(self) -> bool:
return self._bias_op_exist
22 changes: 4 additions & 18 deletions nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar

from nncf import Dataset
Expand All @@ -34,7 +33,6 @@
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.passes import insert_null_biases_pass
from nncf.scopes import IgnoredScope

TModel = TypeVar("TModel")
Expand All @@ -49,11 +47,6 @@ class PostTrainingQuantization(Algorithm):
3) FastBiasCorrection or BiasCorrection
"""

@dataclass
class FirstStageAlgorithm:
algorithm: "Algorithm"
pre_passes: List[TPass]

def __init__(
self,
preset: QuantizationPreset = QuantizationPreset.PERFORMANCE,
Expand Down Expand Up @@ -87,7 +80,7 @@ def __init__(
"""
super().__init__()
self.algorithms = []
self.first_stage_algorithms: List[self.FirstStageAlgorithm] = []
self.first_stage_algorithms: List[Algorithm] = []

if target_device is TargetDevice.VPU:
warning_deprecated("VPU device is deprecated and will no longer be supported in the future.")
Expand All @@ -101,15 +94,14 @@ def __init__(
inplace_statistics=advanced_parameters.inplace_statistics,
alpha=advanced_parameters.smooth_quant_alpha,
)
self.first_stage_algorithms.append(self.FirstStageAlgorithm(smooth_quant_algorithm, []))
self.first_stage_algorithms.append(smooth_quant_algorithm)

if not advanced_parameters.disable_channel_alignment:
channel_alignment = ChannelAlignment(
subset_size=subset_size,
inplace_statistics=advanced_parameters.inplace_statistics,
backend_params=advanced_parameters.backend_params,
)
self.first_stage_algorithms.append(self.FirstStageAlgorithm(channel_alignment, [insert_null_biases_pass]))
self.first_stage_algorithms.append(channel_alignment)

min_max_quantization = MinMaxQuantization(
preset=preset,
Expand Down Expand Up @@ -187,9 +179,7 @@ def apply(
modified_model_graph = graph
backend = get_backend(modified_model)

for first_stage_algorithm in self.first_stage_algorithms:
algorithm = first_stage_algorithm.algorithm

for algorithm in self.first_stage_algorithms:
if isinstance(algorithm, SmoothQuant) and backend != BackendType.OPENVINO:
nncf_logger.debug(f"{backend.name} does not support SmoothQuant algorithm yet.")
continue
Expand All @@ -198,10 +188,6 @@ def apply(
nncf_logger.debug(f"{backend.name} does not support ChannelAlignment algorithm yet.")
continue

for pre_pass in first_stage_algorithm.pre_passes:
modified_model = pre_pass(modified_model, modified_model_graph)
modified_model_graph = NNCFGraphFactory.create(modified_model)

statistics_aggregator = StatisticsAggregatorFactory.create(modified_model, dataset)
algo_statistic_points = algorithm.get_statistic_points(modified_model, modified_model_graph)
statistics_aggregator.register_statistic_points(algo_statistic_points)
Expand Down
3 changes: 2 additions & 1 deletion tests/openvino/native/quantization/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor
Expand Down Expand Up @@ -64,7 +65,7 @@ def get_constant_metatype(self):
return OVConstantMetatype

def get_transformation_commands(self):
return OVBiasCorrectionCommand, OVWeightUpdateCommand
return OVBiasInsertionCommand, OVBiasCorrectionCommand, OVWeightUpdateCommand

def mock_command_creation_factory(self, mocker) -> None:
mocker.patch("nncf.common.factory.CommandCreatorFactory.create", return_value=OVCommandCreator)
Expand Down
61 changes: 29 additions & 32 deletions tests/post_training/test_templates/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,15 @@ def test_align_means(self, conv_out_value, refs, transposed):
REF_UPDATED_CONV_OUT = np.array([[0.0, 2.0, 0.04, 600, 8], [10, 12, 0.14, 1600, 18]])
REF_UPDATED_BIAS_IN = np.array([2, 4, 600, 0.08, 10])

@pytest.mark.parametrize("bias_in_value", [np.array([2, 4, 6, 8, 10]), None])
@pytest.mark.parametrize("bias_in_value", [np.array([2, 4, 6, 8, 10])])
def test_align_scales(self, bias_in_value):
def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in):
assert updated_conv_in.shape == self.REF_UPDATED_CONV_IN.shape
assert np.allclose(updated_conv_in, self.REF_UPDATED_CONV_IN)
assert updated_conv_out.shape == self.REF_UPDATED_CONV_OUT.shape
assert np.allclose(updated_conv_out, self.REF_UPDATED_CONV_OUT)
if bias_in_value is None:
assert updated_bias_in is None
else:
assert updated_bias_in.shape == self.REF_UPDATED_BIAS_IN.shape
assert np.allclose(updated_bias_in, self.REF_UPDATED_BIAS_IN)
assert updated_bias_in.shape == self.REF_UPDATED_BIAS_IN.shape
assert np.allclose(updated_bias_in, self.REF_UPDATED_BIAS_IN)

conv_in_value = np.arange(5).reshape(5, 1)
conv_out_value = np.arange(10).reshape(2, 5) * 2
Expand Down Expand Up @@ -225,7 +222,6 @@ def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in):
updated_conv_in = updated_conv_in.reshape(updated_conv_in.shape[1:])
check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in)

GET_NODES_TEST_CASES = []
GET_NODES_TEST_CASES = [(VALID_CONV_LAYER_ATTR, VALID_CONV_LAYER_ATTR, True)]
GET_NODES_TEST_CASES.extend([(attr, VALID_CONV_LAYER_ATTR, True) for attr in INVALID_CONSUMER_CONV_LAYER_ATTRS])
GET_NODES_TEST_CASES.extend([(VALID_CONV_LAYER_ATTR, attr, False) for attr in INVALID_CONSUMER_CONV_LAYER_ATTRS])
Expand Down Expand Up @@ -370,27 +366,24 @@ class MockBackend(backend_cls):
assert len(arg.transformations) == 0
return

align_means_called = 1 if num_biases == 2 else 0
assert algorithm._align_means.call_count == align_means_called
if align_means_called:
algorithm._align_means.assert_called_once_with(
ref_bias_val + "1",
ref_bias_val + "2",
ref_weights_val + "2",
np.array(0.5, dtype=np.float32),
ref_dims_descr + "2",
)
assert algorithm._align_means.call_count == 1
args = [
np.zeros((1, 1, 1, 1)),
np.zeros((1, 1, 1, 1)),
ref_weights_val + "2",
np.array(0.5, dtype=np.float32),
ref_dims_descr + "2",
]
for i in range(num_biases):
args[i] = f"ref_bias_val{i + 1}"

algorithm._align_means.assert_called_once_with(*args)

assert algorithm._align_scales.call_count == 1
args = algorithm._align_scales.call_args.args
assert args[0] == ref_weights_val + "1"
assert args[1] == ref_weights_val + "2"
if num_biases == 2:
assert args[2] == ref_bias_in_after_align
elif num_biases == 1:
assert args[2] == ref_bias_val + "1"
else:
assert args[2] is None
assert args[2] == ref_bias_in_after_align
assert ((args[3] - 3) < EPS).all()
assert args[4] == ref_dims_descr + "1"
assert args[5] == ref_dims_descr + "2"
Expand All @@ -408,14 +401,20 @@ class MockBackend(backend_cls):
},
"/Conv_2_0": {"weight_value": ref_weights_out_after_scale_align, "bias_value": ref_bias_out_after_align},
}
bias_update_cls, weights_update_cls = self.get_transformation_commands()

bias_insert_cls, bias_update_cls, weights_update_cls = self.get_transformation_commands()
for transformation in transformations:
assert transformation.type == TransformationType.CHANGE
tp = transformation.target_point
if isinstance(transformation, bias_update_cls):
assert transformation.type == TransformationType.CHANGE
_class = bias_update_cls
_attr = "bias_value"
elif isinstance(transformation, bias_insert_cls):
assert transformation.type == TransformationType.INSERT
_class = bias_insert_cls
_attr = "bias_value"
elif isinstance(transformation, weights_update_cls):
assert transformation.type == TransformationType.CHANGE
_class = weights_update_cls
_attr = "weight_value"
else:
Expand All @@ -425,18 +424,16 @@ class MockBackend(backend_cls):
assert ref_values[tp.target_node_name][_attr] == getattr(transformation, _attr)

if num_biases == 2:
ref_len = {"/Conv_1_0": 2, "/Conv_2_0": 2}
ref_classes = {"/Conv_1_0": bias_update_cls, "/Conv_2_0": bias_update_cls}
elif num_biases == 1:
ref_len = {"/Conv_1_0": 2, "/Conv_2_0": 1}
ref_classes = {"/Conv_1_0": bias_update_cls, "/Conv_2_0": bias_insert_cls}
else:
ref_len = {"/Conv_1_0": 1, "/Conv_2_0": 1}
ref_classes = {"/Conv_1_0": bias_insert_cls, "/Conv_2_0": bias_insert_cls}

for node_name, _transformations in target_names.items():
_ref_len = ref_len[node_name]
assert len(_transformations) == _ref_len
assert len(_transformations) == 2
assert weights_update_cls in _transformations
if _ref_len == 2:
assert bias_update_cls in _transformations
assert ref_classes[node_name] in _transformations

@pytest.mark.parametrize("num_biases", [0, 1, 2])
def test_get_statistic_points(self, num_biases, mocker):
Expand Down

0 comments on commit 1eb303f

Please sign in to comment.