Skip to content

Commit

Permalink
[PTQ][Experimental TensorCollector] Empty statistic aggregators conta…
Browse files Browse the repository at this point in the history
…iners handling (#2035)

### Changes

Cases when all registered tensors are empty or no tensors was registered
are handled by experimental TensorCollector, MinMax algorithm,
SmoothQuant and ChannelAlignment algorithms. BC and FBC algorithms are
skipped because nodes only quantized nodes biases could be corrected, so
empty statistics should raise an error during MinMax quantization

### Reason for changes

To enable models with empty branches

### Related tickets
116929

### Tests

* Test statistic collectors is updated
* Channel Alignment tests are updated
* SmoothQuant tests are updated
* Update MinMax tests
  • Loading branch information
daniil-lyakhov authored Aug 16, 2023
1 parent 4289324 commit 182903c
Show file tree
Hide file tree
Showing 11 changed files with 199 additions and 64 deletions.
41 changes: 26 additions & 15 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,21 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
:param x: Tensor to register.
"""

@abstractmethod
def aggregate(self) -> Any:
"""
Aggregates collected tensors and returns aggregated result.
In case no tensors were collected returns None.
:return: Aggregated result.
"""
if self._collected_samples:
return self._aggregate_impl()
return None

@abstractmethod
def _aggregate_impl(self) -> Any:
"""
Aggregates collected tensors and returns aggregated result.
:return: Aggregated result.
"""
Expand Down Expand Up @@ -503,7 +514,7 @@ def __init__(self, num_samples: Optional[int]):
def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container.append(x.tensor)

def aggregate(self):
def _aggregate_impl(self):
return self._container


Expand All @@ -514,7 +525,7 @@ def __init__(self):
def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def aggregate(self):
def _aggregate_impl(self):
return self._container.shape


Expand All @@ -525,7 +536,7 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
else:
self._container = self._tensor_processor.min(x, self._container)

def aggregate(self):
def _aggregate_impl(self):
return self._container.tensor


Expand All @@ -536,7 +547,7 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
else:
self._container = self._tensor_processor.max(x, self._container)

def aggregate(self):
def _aggregate_impl(self):
return self._container.tensor


Expand All @@ -555,19 +566,19 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
else:
self._container.append(x)

def _aggregate(self, fn):
def _offline_aggregation_impl(self, fn):
stacked_val = self._tensor_processor.stack(self._container)
return fn(stacked_val, axis=0, keepdims=False).tensor


class MeanAggregator(OfflineAggregatorBase):
def aggregate(self):
return self._aggregate(self._tensor_processor.mean)
def _aggregate_impl(self):
return self._offline_aggregation_impl(self._tensor_processor.mean)


class MedianAggregator(OfflineAggregatorBase):
def aggregate(self):
return self._aggregate(self._tensor_processor.median)
def _aggregate_impl(self):
return self._offline_aggregation_impl(self._tensor_processor.median)


class NoOutliersAggregatorBase(OfflineAggregatorBase, ABC):
Expand All @@ -582,7 +593,7 @@ def __init__(
super().__init__(tensor_processor, use_per_sample_stats, num_samples, window_size)
self._quantile = quantile

def _aggregate(self, fn) -> List[NNCFTensor]:
def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
result = self._tensor_processor.no_outliers_map(stacked_val, fn, axis=0, alpha=self._quantile)
return result.tensor
Expand All @@ -595,13 +606,13 @@ def __hash__(self) -> int:


class MeanNoOutliersAggregator(NoOutliersAggregatorBase):
def aggregate(self) -> Any:
return self._aggregate(self._tensor_processor.masked_mean)
def _aggregate_impl(self) -> Any:
return self._offline_aggregation_impl(self._tensor_processor.masked_mean)


class MedianNoOutliersAggregator(NoOutliersAggregatorBase):
def aggregate(self) -> Any:
return self._aggregate(self._tensor_processor.masked_median)
def _aggregate_impl(self) -> Any:
return self._offline_aggregation_impl(self._tensor_processor.masked_median)


AGGREGATORS_MAP = {
Expand Down
3 changes: 3 additions & 0 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def filter_func(point: StatisticPoint) -> bool:
)
assert len(tensor_collectors) == 1
stat = tensor_collectors[0].get_statistics()
if stat.min_values is None or stat.max_values is None:
continue

conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
conv_out_cont = ConvParamsContainer(conv_out, model, graph, self._backend_entity)

Expand Down
7 changes: 6 additions & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,10 @@ def filter_func(point: StatisticPoint) -> bool:
for tensor_collector in statistic_points.get_algo_statistics_for_node(
target_node_name, filter_func, self._algorithm_key
):
group_statistics.append(tensor_collector.get_statistics())
statistics = tensor_collector.get_statistics()
if statistics.min_values is None or statistics.max_values is None:
raise RuntimeError(f"Statistics were not collected for the node {target_node_name}")
group_statistics.append(statistics)

unified_values = self._backend_entity.unify_statistics(group_statistics)
for quantization_target_point in unified_scale_group:
Expand Down Expand Up @@ -661,6 +664,8 @@ def filter_func(point: StatisticPoint) -> bool:
half_range = quantization_target_point in quantization_points_overflow_fix
narrow_range = get_quantizer_narrow_range(qconfig, quant_group)
statistics = tensor_collector.get_statistics()
if statistics.min_values is None or statistics.max_values is None:
raise RuntimeError(f"Statistics were not collected for the node {target_node_name}")
parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, narrow_range, half_range)
command = self._backend_entity.create_quantizer_insertion_command(
graph, quantization_target_point, qconfig, parameters
Expand Down
7 changes: 7 additions & 0 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ def apply(

for group_id, nodes in tqdm(node_groups.items(), desc="Applying Smooth Quant"):
best_ratio = 0.0
empty_statistic = False
for node_to_smooth in nodes:
source_node, port_id = group_id
activations_value = self._get_statistics_for_node(statistic_points, node_to_smooth.node_name, port_id)
if any(val is None for val in activations_value):
empty_statistic = True
break
activations_value = self._backend_entity.clip_statistics(activations_value)

weights_port = self._backend_entity.get_weight_tensor_port_id(node_to_smooth)
Expand All @@ -126,6 +130,9 @@ def apply(
best_ratio = ratio
best_scale = deepcopy(scales)

if empty_statistic:
continue

activation_scales = self._backend_entity.calculate_activation_scale(best_scale, nodes)
weight_scales = self._backend_entity.calculate_weight_scale(best_scale, nodes)

Expand Down
46 changes: 45 additions & 1 deletion tests/experimental/common/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, Optional

import numpy as np
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(self, num_samples: Optional[int]):
def _register_reduced_input_impl(self, x: TensorType):
return self._container.append(x)

def aggregate(self):
def _aggregate_impl(self):
return self._container[0]


Expand Down Expand Up @@ -269,3 +270,46 @@ def test_multiple_branch_reducer():
assert len(ref_stats) == len(stats)
for key, value in ref_stats.items():
assert value == stats[key]


class TemplateTestStatisticCollector:
@abstractmethod
def get_nncf_tensor_cls(self):
pass

@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.parametrize("any_not_empty", [False, True])
def test_empty_tensors_register(self, inplace, any_not_empty):
collector = TensorCollector()
reducer = DummyTensorReducer("Dummy", inplace)
aggregator = DummyTensorAggregator(5)
collector.register_statistic_branch("A", reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])]
)

stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] is None

inputs = [full_inputs, empty_inputs, full_inputs] if any_not_empty else [empty_inputs, empty_inputs]
for input_ in inputs:
collector.register_inputs(input_)

if any_not_empty:
assert len(aggregator._container) == 2
assert aggregator._collected_samples == 2
stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] == self.get_nncf_tensor_cls()([100])
return

assert len(aggregator._container) == 0
assert aggregator._collected_samples == 0
stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] is None
5 changes: 5 additions & 0 deletions tests/openvino/native/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import openvino.runtime as ov
import torch

from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm
from tests.shared.command import Command
Expand Down Expand Up @@ -62,3 +63,7 @@ def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray]) -> No
ref_value = np.array(ref_value)
assert value.shape == ref_value.shape
assert np.all(np.isclose(value, ref_value, atol=0.0001)), f"{value} != {ref_value}"

@staticmethod
def get_matmul_metatype():
return OVMatMulMetatype
47 changes: 4 additions & 43 deletions tests/openvino/native/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.tensor import OVNNCFTensor
from tests.experimental.common.test_statistic_collector import DummyTensorAggregator
from tests.experimental.common.test_statistic_collector import DummyTensorReducer


# pylint:disable=protected-access
def test_empty_tensors_register():
collector = TensorCollector()
reducer = DummyTensorReducer("Dummy")
aggregator = DummyTensorAggregator(5)
collector.register_statistic_branch("A", reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: OVNNCFTensor(np.array([100]))}, [(hash(reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: OVNNCFTensor(np.array([]))}, [(hash(reducer), [input_name])]
)

for inputs in [full_inputs, empty_inputs, full_inputs]:
collector.register_inputs(inputs)
assert len(aggregator._container) == 2
assert aggregator._collected_samples == 2

from tests.experimental.common.test_statistic_collector import TemplateTestStatisticCollector

# pylint:disable=protected-access
def test_empty_inplace_tensors_register():
collector = TensorCollector()
inplace_reducer = DummyTensorReducer("Dummy", True)
aggregator = DummyTensorAggregator(5)
collector.register_statistic_branch("A", inplace_reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: OVNNCFTensor(np.array([100]))}, [(hash(inplace_reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: OVNNCFTensor(np.array([]))}, [(hash(inplace_reducer), [input_name])]
)

for inputs in [full_inputs, empty_inputs, full_inputs]:
collector.register_inputs(inputs)
assert len(aggregator._container) == 2
assert aggregator._collected_samples == 2
class TestOVStatisticCollector(TemplateTestStatisticCollector):
def get_nncf_tensor_cls(self):
return OVNNCFTensor
14 changes: 14 additions & 0 deletions tests/post_training/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ def forward(self, x):
return x_1, x_2


class NonZeroLinearModel(nn.Module):
INPUT_SIZE = [10]

def forward(self, x):
zeros = (x > torch.inf).float()
empty = torch.nonzero(zeros).reshape((-1, 1, 1)).float()
y = torch.matmul(empty, torch.ones((1, 5)))
y += 5
y = torch.cat((torch.ones((1, 10)), y.reshape(1, -1)), dim=1)
y = torch.matmul(y, torch.ones(10, 10))
y += 5
return y


class SplittedModel(nn.Module):
INPUT_SIZE = [1, 3, 28, 28]

Expand Down
21 changes: 17 additions & 4 deletions tests/post_training/test_templates/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,11 @@ def _get_nncf_graph(self, num_biases: int) -> NNCFGraph:
constant_metatype=self.get_constant_metatype(),
).nncf_graph

@pytest.mark.parametrize("empty_statistics", [False, True])
@pytest.mark.parametrize("num_biases", [0, 1, 2])
# pylint: disable=too-many-statements
def test_transformation_layout(self, num_biases, mocker):
# pylint: disable=too-many-branches
def test_transformation_layout(self, empty_statistics, num_biases, mocker):
mocked_transformer = mocker.MagicMock()
self.mock_model_transformer_factory(mocker, mocked_transformer)

Expand Down Expand Up @@ -325,9 +327,12 @@ def f(*args, **kwargs):

algorithm = ChannelAlignment()
tensor_collector = TensorCollector()
tensor_collector.get_statistics = get_constant_lambda(
TestTensorStats(np.array([-1], dtype=np.int32), np.array([2], dtype=np.int32))
)
if empty_statistics:
stat_value = None, None
else:
stat_value = (np.array([-1], dtype=np.int32), np.array([2], dtype=np.int32))

tensor_collector.get_statistics = get_constant_lambda(TestTensorStats(*stat_value))
statistic_points.add_statistic_point(StatisticPoint(target_point, tensor_collector, algorithm._algorithm_key))

class MockBackend(backend_cls):
Expand Down Expand Up @@ -357,6 +362,14 @@ class MockBackend(backend_cls):
)
algorithm.apply(None, nncf_graph, statistic_points)

if empty_statistics:
assert algorithm._align_means.call_count == 0
assert algorithm._align_scales.call_count == 0
mocked_transformer.transform.assert_called_once()
arg = mocked_transformer.transform.call_args.args[0]
assert len(arg.transformations) == 0
return

align_means_called = 1 if num_biases == 2 else 0
assert algorithm._align_means.call_count == align_means_called
if align_means_called:
Expand Down
Loading

0 comments on commit 182903c

Please sign in to comment.