Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 6, 2023
1 parent 4a24b66 commit 0de0502
Show file tree
Hide file tree
Showing 16 changed files with 341 additions and 311 deletions.
17 changes: 9 additions & 8 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ class TensorReducerBase(ABC):
the specified rule. Could handle tensors inplace or out of place.
"""

def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True):
def __init__(self, reduction_axes: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True):
"""
:param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape)))
if empty.
:param inplace: Whether should be calculated inplace or out of place.
:param keepdims: Whether the reduced axes are kept in the result
as dimensions with size one.
"""
self._reduction_shape = reduction_shape
self._reduction_shape = reduction_axes
self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor()
self._inplace = inplace
self._keepdims = keepdims
Expand Down Expand Up @@ -469,12 +469,12 @@ def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
class QuantileReducerBase(TensorReducerBase):
def __init__(
self,
reduction_shape: Optional[ReductionShape] = None,
reduction_axes: Optional[ReductionShape] = None,
quantile: Optional[Union[float, Tuple[float]]] = None,
inplace: bool = False,
keepdims: bool = True,
):
super().__init__(reduction_shape, False, keepdims)
super().__init__(reduction_axes, False, keepdims)
self._quantile = (0.01, 0.99) if quantile is None else quantile

def __eq__(self, __o: object) -> bool:
Expand All @@ -494,11 +494,11 @@ def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
class AbsQuantileReducer(QuantileReducerBase):
def __init__(
self,
reduction_shape: Optional[ReductionShape] = None,
reduction_axes: Optional[ReductionShape] = None,
quantile: Union[float, List[float]] = 0.99,
inplace: bool = False,
):
super().__init__(reduction_shape, quantile, False)
super().__init__(reduction_axes, quantile, False)

def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = self._tensor_processor.abs(x[0])
Expand Down Expand Up @@ -624,8 +624,9 @@ def __init__(
window_size=None,
quantile: float = 0.01,
):
assert len(aggregation_axes) == 1
super().__init__(
tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples
tensor_processor, aggregation_axes=aggregation_axes[0], keepdims=keepdims, num_samples=num_samples
)
self._window_size = window_size
self._container = deque(maxlen=window_size)
Expand Down Expand Up @@ -707,7 +708,7 @@ def _aggregate_impl(self) -> Any:
return retval


class PostAggregateAggregatorHook(TensorAggregatorBase, ABC):
class PostAggregateHook(TensorAggregatorBase, ABC):
def __init__(self, aggregator: TensorAggregatorBase, post_aggregation_hook):
super().__init__(None)
self._aggregator = aggregator
Expand Down
10 changes: 5 additions & 5 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _get_processor(self):
return OVNNCFCollectorTensorProcessor

def get_inplace_fn(self):
return get_inplace_min_op(self.name, self._reduction_shape)
return get_inplace_min_op(self.name, self._reduction_axes)

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)
Expand All @@ -168,7 +168,7 @@ def _get_processor(self):
return OVNNCFCollectorTensorProcessor

def get_inplace_fn(self):
return get_inplace_max_op(self.name, self._reduction_shape, False)
return get_inplace_max_op(self.name, self._reduction_axes, False)

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)
Expand All @@ -179,7 +179,7 @@ def _get_processor(self):
return OVNNCFCollectorTensorProcessor

def get_inplace_fn(self):
return get_inplace_max_op(self.name, self._reduction_shape, True)
return get_inplace_max_op(self.name, self._reduction_axes, True)

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)
Expand All @@ -190,7 +190,7 @@ def _get_processor(self):
return OVNNCFCollectorTensorProcessor

def get_inplace_fn(self):
return get_inplace_mean_op(self.name, self._reduction_shape)
return get_inplace_mean_op(self.name, self._reduction_axes)

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)
Expand All @@ -212,7 +212,7 @@ def _get_processor(self):
return OVNNCFCollectorTensorProcessor

def get_inplace_fn(self):
return get_inplace_mean_per_ch(self.name, self._reduction_shape)
return get_inplace_mean_per_ch(self.name, self._reduction_axes)

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionShape
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
Expand All @@ -32,8 +33,8 @@
from nncf.torch.model_analyzer import is_quantized_weights
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.collectors import PTMeanStatisticCollector
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.tensor_statistics.collectors import get_mean_stat_collector


@ALGO_BACKENDS.register(BackendType.TORCH)
Expand Down Expand Up @@ -71,8 +72,8 @@ def mean_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> PTMeanStatisticCollector:
return PTMeanStatisticCollector(reduction_shape, num_samples, window_size)
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_shape, window_size)

@staticmethod
def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
Expand Down
124 changes: 69 additions & 55 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AggregatorType
Expand All @@ -49,8 +52,9 @@
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import PTQuantizerSpec
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector
from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic


Expand Down Expand Up @@ -155,32 +159,61 @@ def get_statistic_collector(
quantizer_config: QuantizerConfig,
inplace: bool,
num_samples: int = None,
) -> Union[PTMinMaxStatisticCollector, PTMeanMinMaxStatisticCollector]:
if (
range_estimator_params.min.statistics_type == StatisticsType.MIN
and range_estimator_params.min.aggregator_type == AggregatorType.MIN
and range_estimator_params.max.statistics_type == StatisticsType.MAX
and range_estimator_params.max.aggregator_type == AggregatorType.MAX
) -> TensorCollector:
collector_params = PTMinMaxAlgoBackend._default_collector_params(nncf_graph, target_point, quantizer_config)
collector_kwargs = collector_params.convert_statistic_params(per_sample_stats=False)

collector = TensorCollector(PTMinMaxTensorStatistic)
for params, container_key in zip(
[range_estimator_params.min, range_estimator_params.max],
[PTMinMaxTensorStatistic.MIN_STAT, PTMinMaxTensorStatistic.MAX_STAT],
):
collector_name = "min_max"

elif (
range_estimator_params.min.statistics_type == StatisticsType.MIN
and range_estimator_params.min.aggregator_type == AggregatorType.MEAN
and range_estimator_params.max.statistics_type == StatisticsType.MAX
and range_estimator_params.max.aggregator_type == AggregatorType.MEAN
):
collector_name = "mean_min_max"

else:
raise RuntimeError(
"The following range estimator parameters are not supported by PyTorch backend by now: "
f"{str(range_estimator_params)}"
)

return PTMinMaxAlgoBackend._statistic_collector_builder(
collector_name, nncf_graph, target_point, quantizer_config, num_samples
)
if not params.statistics_type in PT_REDUCERS_MAP:
raise RuntimeError(
f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet."
)

if not params.aggregator_type in AGGREGATORS_MAP:
raise RuntimeError(
f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet."
)

kwargs = {
"reduction_axes": collector_kwargs["reducers_axes"],
"keepdims": collector_kwargs["reducers_keepdims"],
}
if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
if container_key == PTMinMaxTensorStatistic.MIN_STAT:
quantile = params.quantile_outlier_prob
else:
quantile = 1 - params.quantile_outlier_prob
kwargs.update({"quantile": [quantile]})
# TODO(dlyakhov): merge the two quantile aggregators into one

statistic_type = params.statistics_type
if collector_params.use_abs_max and statistic_type == StatisticsType.MAX:
statistic_type = StatisticsType.ABS_MAX
reducer = PT_REDUCERS_MAP[statistic_type](**kwargs)

kwargs = {
"aggregation_axes": collector_kwargs["aggregators_axes"],
"keepdims": collector_kwargs["aggregators_keepdims"],
"num_samples": num_samples,
"tensor_processor": PTNNCFCollectorTensorProcessor,
}
aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)

if collector_kwargs["squeeze_dims"] is not None:

# Hook applied to the aggregator's result: squeezes the aggregated value along
# collector_kwargs["squeeze_dims"] (this closure is only defined when that entry is not None).
def post_aggregation_hook(aggregated_value):
# The value is wrapped in PTNNCFTensor for the tensor-processor call and
# unwrapped again through `.tensor` before being returned.
return PTNNCFCollectorTensorProcessor.squeeze(
PTNNCFTensor(aggregated_value), dim=collector_kwargs["squeeze_dims"]
).tensor

aggregator = PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)

collector.register_statistic_branch(container_key, reducer, aggregator)
return collector

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
Expand Down Expand Up @@ -223,37 +256,18 @@ def _get_input_scale_shape(
return input_shape, scale_shape, channel_idx

@staticmethod
def _default_collector_params_and_scale_shape(
def _default_collector_params(
nncf_graph: NNCFGraph, target_point: PTTargetPoint, quantizer_config: QuantizerConfig
) -> Tuple[PTRangeInitCollectorParams, Tuple[int, ...]]:
input_shape, scale_shape, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_point, quantizer_config
)
return (
PTRangeInitCollectorParams(
is_weights=target_point.is_weight_target_point(),
mode=quantizer_config.mode,
per_channel=quantizer_config.per_channel,
input_shape=input_shape,
channel_idx=channel_idx,
),
scale_shape,
)

@staticmethod
def _statistic_collector_builder(
collector_name: str,
nncf_graph: NNCFGraph,
target_point: PTTargetPoint,
quantizer_config: QuantizerConfig,
num_samples: int = None,
) -> PTMeanMinMaxStatisticCollector:
collector_params, scale_shape = PTMinMaxAlgoBackend._default_collector_params_and_scale_shape(
) -> PTRangeInitCollectorParams:
input_shape, _, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_point, quantizer_config
)
init_config = RangeInitConfig(collector_name, num_samples)
return StatCollectorGenerator.generate_stat_collector_for_range_init_config(
init_config, scale_shape, collector_params, num_samples
return PTRangeInitCollectorParams(
is_weights=target_point.is_weight_target_point(),
mode=quantizer_config.mode,
per_channel=quantizer_config.per_channel,
input_shape=input_shape,
channel_idx=channel_idx,
)

@staticmethod
Expand Down
3 changes: 0 additions & 3 deletions nncf/torch/quantization/init_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
from nncf.torch.quantization.translator import PTTargetPointTranslator
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint
from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector
from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector
from nncf.torch.tensor_statistics.collectors import PTPercentileStatisticCollector
from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector
Expand Down
10 changes: 9 additions & 1 deletion nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,22 @@ def _get_transformation_layout_extra_outputs(
) -> TransformationLayout:
transformation_layout = TransformationLayout()
transformation_commands = []

# Builds a pass-through hook around `fn`: the returned callable forwards each
# incoming torch.Tensor to `fn` (wrapped in PTNNCFTensor) and returns the tensor unchanged.
def register_inputs_fn(fn):
def register_inputs(input_: torch.Tensor):
# Hand the tensor to the callback as a PTNNCFTensor; the callback's return value is ignored.
fn(PTNNCFTensor(input_))
# Return the original tensor object itself, not the wrapper.
return input_

return register_inputs

for _statistic_points in statistic_points.values():
for _statistic_point in _statistic_points:
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
transformation_commands.append(
PTInsertionCommand(
_statistic_point.target_point,
collector.register_input,
register_inputs_fn(collector.register_unnamed_inputs),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
)
)
Expand Down
Loading

0 comments on commit 0de0502

Please sign in to comment.