Skip to content

Commit

Permalink
[ONNX] Fix sporadic results in BC (openvinotoolkit#3081)
Browse files Browse the repository at this point in the history
### Changes

1. This PR addresses an issue using `ONNXRuntime==1.19.2` where a tensor
used as both an input and output in a model shares the same memory. This
causes unexpected behavior: updating the input tensor inadvertently
modifies the statistics data due to memory overlap.
The issue was confirmed by calling
`np.shares_memory(input_data['image'], outputs['image'])`, which
returned `True`, indicating that the input and output tensors share
memory. After applying the proposed changes, the same check now returns
`False`, confirming that memory sharing is resolved.
To fix this, the `ONNXEngine` logic has been updated to create a copy of
any output tensor that is also used as a model input. This ensures that
the input tensor and statistics data remain independent, avoiding
unintended side effects.

2. Merge RawReducer and NoopReducer
3. Minor fixes (remove warnings + fix bug in BC)


### Reason for changes

Regression

### Related tickets

156025

### Tests

PTQ run 549
  • Loading branch information
kshpv authored and daniil-lyakhov committed Dec 2, 2024
1 parent 7c8c116 commit 1bce93a
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 35 deletions.
12 changes: 2 additions & 10 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,25 +412,17 @@ def __init__(self, tensor_collectors: List[TensorCollector]) -> None:
##################################################


class NoopReducer(TensorReducerBase):
class RawReducer(TensorReducerBase):
def __init__(self):
super().__init__(inplace=False)

def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
return None

def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
return x


class RawReducer(NoopReducer):
def __init__(self):
super().__init__()

def __call__(self, x: List[Tensor]):
return self._reduce_out_of_place(x)


class ShapeReducer(TensorReducerBase):
def __init__(self, inplace: bool = False):
super().__init__(inplace=inplace)
Expand Down
10 changes: 8 additions & 2 deletions nncf/onnx/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def infer(self, input_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
:param input_data: inputs for the model
:return output_data: models outputs
"""
output_tensors = self.sess.run([], {k: v for k, v in input_data.items() if k in self.input_names})
output_tensors = self.sess.run([], input_data)
model_outputs = self.sess.get_outputs()

return {output.name: tensor for tensor, output in zip(output_tensors, model_outputs)}
outputs_safe = {}
for tensor, output in zip(output_tensors, model_outputs):
# Workaround for https://github.com/microsoft/onnxruntime/issues/21922
# After fixing this copying should be removed
outputs_safe[output.name] = tensor.copy() if output.name in self.input_names else tensor

return outputs_safe
2 changes: 1 addition & 1 deletion nncf/onnx/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ONNXModelTransformer(ModelTransformer):
def __init__(self, model: onnx.ModelProto):
infered_model = onnx.shape_inference.infer_shapes(model)
super().__init__(infered_model)
self.onnx_model_extractor = onnx.utils.Extractor(self._model)
self.onnx_model_extractor = onnx.utils.Extractor(infered_model)

def _get_target_edge(
self,
Expand Down
5 changes: 2 additions & 3 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from nncf.experimental.common.tensor_statistics.collectors import MeanReducer
from nncf.experimental.common.tensor_statistics.collectors import MinReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
Expand Down Expand Up @@ -48,7 +47,7 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer(inplace)
else:
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
noop_reducer = NoopReducer()
raw_reducer = RawReducer()

kwargs = {
"num_samples": num_samples,
Expand All @@ -60,7 +59,7 @@ def get_mean_statistic_collector(

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
return collector


Expand Down
5 changes: 2 additions & 3 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
from nncf.experimental.common.tensor_statistics.collectors import MinReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
Expand Down Expand Up @@ -128,7 +127,7 @@ def get_mean_statistic_collector(
reducer = OVBatchMeanReducer(inplace)
else:
reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
noop_reducer = NoopReducer()
raw_reducer = RawReducer()

kwargs = {
"num_samples": num_samples,
Expand All @@ -139,7 +138,7 @@ def get_mean_statistic_collector(

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
return collector


Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Te
"""
bias_shift_magnitude = fns.max(
fns.abs(
(updated_bias_value - current_bias_value) / (current_bias_value + fns.finfo(current_bias_value).min)
(updated_bias_value - current_bias_value) / (current_bias_value + fns.finfo(current_bias_value).eps)
)
)
return bias_shift_magnitude
Expand Down
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import HessianTensorStatistic
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
Expand Down Expand Up @@ -252,7 +252,7 @@ def scale_insertion_command(source_node, next_nodes, source_node_output_port, sc
class MixedPrecisionAlgoBackend(ABC):
@staticmethod
def hawq_statistic_collector(subset_size: Optional[int] = None) -> TensorCollector:
reducer = NoopReducer()
reducer = RawReducer()
aggregator = HAWQAggregator(num_samples=subset_size)
collector = TensorCollector(HessianTensorStatistic)
collector.register_statistic_branch(HessianTensorStatistic.HESSIAN_INPUT_ACTIVATION_STATS, reducer, aggregator)
Expand Down
4 changes: 3 additions & 1 deletion nncf/quantization/fake_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,13 @@ def tune_range(
fval = -left_border * s
qval = fns.round(fval)

ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border)
with warnings.catch_warnings():
# If `qval` is 0 `rb` will equal `right_border`, and we don't want to show an unnecessary division by 0 warning
# The same for (qval - level_high)
warnings.simplefilter("ignore")
ra_then_result = qval / (qval - level_high) * right_border
rb_then_result = (qval - level_high) / qval * left_border
ra = fns.where(qval < level_high, ra_then_result, left_border)
rb = fns.where(qval > 0.0, rb_then_result, right_border)

range_a = right_border - ra
Expand Down
7 changes: 3 additions & 4 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
from nncf.experimental.common.tensor_statistics.collectors import MinReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
Expand Down Expand Up @@ -246,7 +245,7 @@ def _get_collection_without_reduction(
:return: Target statistic collector.
"""
tensor_collector = TensorCollector(statistic_cls)
reducer = NoopReducer()
reducer = RawReducer()
aggregation_axes = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes]))
aggregator = aggregator_cls(
aggregation_axes=aggregation_axes,
Expand Down Expand Up @@ -311,7 +310,7 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer()
else:
reducer = MeanPerChReducer(channel_axis=channel_axis)
noop_reducer = NoopReducer()
raw_reducer = RawReducer()

kwargs = {
"num_samples": num_samples,
Expand All @@ -322,7 +321,7 @@ def get_mean_statistic_collector(

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
return collector


Expand Down
7 changes: 3 additions & 4 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from nncf.experimental.common.tensor_statistics.collectors import MedianNoOutliersAggregator
from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
Expand Down Expand Up @@ -173,19 +172,19 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non
def cast_tensor(self, tensor, dtype: Dtype):
pass

@pytest.mark.parametrize("reducer_cls", [NoopReducer, RawReducer])
@pytest.mark.parametrize("reducer_cls", [RawReducer])
@pytest.mark.parametrize("input_data", [np.arange(24).reshape((1, 2, 3, 4)), np.array([])])
def test_other_reducers(self, reducer_cls, input_data):
reducer = reducer_cls()
tensor_data = self.get_nncf_tensor(input_data)
reduced_input = reducer([tensor_data])
if reducer_cls == NoopReducer and tensor_data.isempty():
if tensor_data.isempty():
assert reduced_input is None
else:
assert len(reduced_input) == 1
assert fns.allclose(reduced_input[0], tensor_data)

@pytest.mark.parametrize("reducer_cls", [NoopReducer, RawReducer, ShapeReducer])
@pytest.mark.parametrize("reducer_cls", [RawReducer, ShapeReducer])
def test_other_reducers_name_hash_equal(self, reducer_cls):
reducers_instances = [reducer_cls() for _ in range(2)]
assert hash(reducers_instances[0]) == hash(reducers_instances[1])
Expand Down
12 changes: 10 additions & 2 deletions tests/common/quantization/test_tune_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,21 @@
import warnings

import numpy as np
import pytest

from nncf.quantization.fake_quantize import tune_range
from nncf.tensor import Tensor


def test_tune_range_zero_division_warning():
@pytest.mark.parametrize(
"params",
(
(Tensor(np.array([0.0])), Tensor(np.array([1.0])), 8, False),
(Tensor(np.array([-1.0])), Tensor(np.array([0.0])), 8, False),
),
)
def test_tune_range_zero_division_warning(params):
with warnings.catch_warnings(record=True) as w:
# Calling tune_range should not raise a warning
tune_range(Tensor(np.array([0.0])), Tensor(np.array([1.0])), 8, False)
tune_range(*params)
assert len(w) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector


Expand Down Expand Up @@ -58,7 +58,7 @@ def check_reducer(self, collector: TensorCollector, expected_reducer_type):
@pytest.mark.parametrize(
"algo_func, aggregator_type, reducer_type",
[
("get_hawq_with_backend", HAWQAggregator, NoopReducer),
("get_hawq_with_backend", HAWQAggregator, RawReducer),
("get_mean_variance_with_backend", MeanAggregator, MeanVarianceReducer),
("get_max_variance_with_backend", MeanAggregator, MaxVarianceReducer),
("get_mean_max_with_backend", MeanAggregator, MeanAbsMaxReducer),
Expand Down

0 comments on commit 1bce93a

Please sign in to comment.