reduction_shape -> reduction_axes
daniil-lyakhov committed Sep 26, 2023
1 parent bdfd115 commit 43c29f7
Showing 25 changed files with 84 additions and 84 deletions.
26 changes: 13 additions & 13 deletions nncf/openvino/graph/node_utils.py
@@ -200,37 +200,37 @@ def get_reduce_op(node: ov.Node, output_port_id: int) -> ov.Node:
     return get_reduce_op


-def get_inplace_min_op(node_name: str, reduction_shape: Tuple[int, ...]) -> InplaceInsertionFnType:
+def get_inplace_min_op(node_name: str, reduction_axes: Tuple[int, ...]) -> InplaceInsertionFnType:
     """
     Returns inplace min function that adds reduce min node to a passed node.

     :param node_name: Min reduce node name.
-    :param reduction_shape: Target reduction axes for the reduction node.
+    :param reduction_axes: Target reduction axes for the reduction node.
     :returns: Inplace insertion function to use in ModelTransformer.
     """
-    return get_inplace_reduce_op(opset.reduce_min, node_name, reduction_shape, False)
+    return get_inplace_reduce_op(opset.reduce_min, node_name, reduction_axes, False)


-def get_inplace_max_op(node_name: str, reduction_shape: Tuple[int, ...], use_abs_max: bool) -> InplaceInsertionFnType:
+def get_inplace_max_op(node_name: str, reduction_axes: Tuple[int, ...], use_abs_max: bool) -> InplaceInsertionFnType:
     """
     Returns inplace max function that adds reduce max node to a passed node.

     :param node_name: Max reduce node name.
-    :param reduction_shape: Target reduction axes for the reduction node.
+    :param reduction_axes: Target reduction axes for the reduction node.
     :param use_abs: Wheather reduce absolute values of input tensors or not.
     :returns: Inplace insertion function to use in ModelTransformer.
     """
-    return get_inplace_reduce_op(opset.reduce_max, node_name, reduction_shape, use_abs_max)
+    return get_inplace_reduce_op(opset.reduce_max, node_name, reduction_axes, use_abs_max)


-def get_inplace_mean_op(node_name: str, reduction_shape: Tuple[int, ...]) -> InplaceInsertionFnType:
+def get_inplace_mean_op(node_name: str, reduction_axes: Tuple[int, ...]) -> InplaceInsertionFnType:
     """
     Returns inplace mean function that adds reduce mean node to a passed node.

     :param node_name: Mean reduce node name.
     :returns: Inplace insertion function to use in ModelTransformer.
     """
-    return get_inplace_reduce_op(opset.reduce_mean, node_name, reduction_shape, False)
+    return get_inplace_reduce_op(opset.reduce_mean, node_name, reduction_axes, False)


 def get_inplace_batch_mean_op(node_name: str) -> InplaceInsertionFnType:
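For reference (not part of this commit), a minimal sketch of how one of the renamed helpers is called after the change; the node name and axes values below are made up for illustration:

    from nncf.openvino.graph.node_utils import get_inplace_max_op

    # Builds an inplace-insertion function that, when applied by the ModelTransformer,
    # appends a ReduceMax node reducing over axes (0, 2, 3) of the target output,
    # taking absolute values first because use_abs_max=True.
    insertion_fn = get_inplace_max_op("relu_3_output_max", reduction_axes=(0, 2, 3), use_abs_max=True)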
@@ -373,18 +373,18 @@ def get_matmul_channel_axes(weights_port_id: int, ndims: int, transpose: bool) -
     return channel_axes


-def get_channel_agnostic_reduction_shape(channel_axes: List[int], shape: List[int]) -> Tuple[int]:
+def get_channel_agnostic_reduction_axes(channel_axes: List[int], shape: List[int]) -> Tuple[int]:
     """
-    Returns filtered reduction shape without axes that corresponds channels.
+    Returns filtered reduction axes without axes that corresponds channels.

     :param channel_axes: List of the channel axes.
     :param shape: Shape that need to be filtered.
     :return: Reduction shape in tuple format.
     """
-    reduction_shape = list(range(len(shape)))
+    reduction_axes = list(range(len(shape)))
     for channel_axis in sorted(channel_axes, reverse=True):
-        del reduction_shape[channel_axis]
-    return tuple(reduction_shape)
+        del reduction_axes[channel_axis]
+    return tuple(reduction_axes)


 def create_bias_tensor(node_without_bias: NNCFNode, graph: NNCFGraph, value: Any) -> np.ndarray:
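To illustrate what the renamed helper computes (the shape and axes below are illustrative, not taken from the commit), removing the channel axis from the full list of dimensions leaves the axes to reduce over:

    from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes

    # A 4-D activation shape [N, C, Y, X] with channels on axis 1:
    # all axes are (0, 1, 2, 3); dropping the channel axis leaves (0, 2, 3).
    axes = get_channel_agnostic_reduction_axes(channel_axes=[1], shape=[1, 3, 224, 224])
    assert axes == (0, 2, 3)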
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
@@ -509,7 +509,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
                 TargetType.POST_LAYER_OPERATION, node_name, port_id=OUTPUT_PORT_OF_NODE
             )
             stat_collector = self._backend_entity.mean_statistic_collector(
-                reduction_shape=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
+                reduction_axes=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
             )
             statistic_container.add_statistic_point(
                 StatisticPoint(
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/bias_correction/backend.py
@@ -87,15 +87,15 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: TargetPoint) -
     @staticmethod
     @abstractmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> TensorStatisticCollectorBase:
         """
         Returns backend-specific mean statistic collector.

-        :param reduction_shape: Channel axis for the statistics aggregation.
+        :param reduction_axes: Channel axis for the statistics aggregation.
         :param inplace: Whether to calculate statistic inplace or not.
         :param num_samples: Maximum number of samples to collect.
         :param window_size: The maximum size of the samples queue.
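A quick usage sketch of the renamed signature (not part of the diff; the backend instance and argument values are assumed for illustration), showing that call sites now pass the first parameter under its new name:

    # `backend` stands for any concrete implementation of this abstract interface.
    collector = backend.mean_statistic_collector(
        reduction_axes=1, inplace=False, num_samples=100
    )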
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/bias_correction/onnx_backend.py
@@ -77,12 +77,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: ONNXTargetPoin

     @staticmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> ONNXMeanStatisticCollector:
-        return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size)
+        return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)

     @staticmethod
     def raw_statistic_collector(inplace: bool, num_samples: int = None) -> ONNXMeanStatisticCollector:
@@ -65,12 +65,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: OVTargetPoint)

     @staticmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> TensorCollector:
-        return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace)
+        return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)

     @staticmethod
     def raw_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector:
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
@@ -374,11 +374,11 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
         for conv_in, add_in, _ in self._get_node_pairs(graph):
             target_point, node_in = self._get_target_point_and_node_in(conv_in, add_in)
             channel_axis = conv_in.metatype.output_channel_axis
-            reduction_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape)))
-            reduction_shape.remove(channel_axis)
+            reduction_axes = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape)))
+            reduction_axes.remove(channel_axis)

             statistic_collector = self._backend_entity.get_statistic_collector(
-                tuple(reduction_shape), self._quantile, self.subset_size, self.inplace_statistics
+                tuple(reduction_axes), self._quantile, self.subset_size, self.inplace_statistics
             )
             statistic_container.add_statistic_point(
                 StatisticPoint(
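As a worked illustration of the snippet above (the tensor shape and channel axis are made up, not from the commit): for an output tensor of shape [1, 16, 28, 28] whose channel axis is 1, the resulting reduction axes are (0, 2, 3), i.e. statistics are reduced over every dimension except channels:

    tensor_shape = [1, 16, 28, 28]
    channel_axis = 1
    reduction_axes = list(range(len(tensor_shape)))
    reduction_axes.remove(channel_axis)
    assert tuple(reduction_axes) == (0, 2, 3)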
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/channel_alignment/backend.py
@@ -98,12 +98,12 @@ def get_weights_port_ids_for_node(node: NNCFNode) -> Tuple[int, int]:
     @staticmethod
     @abstractmethod
     def get_statistic_collector(
-        reduction_shape, q: float, num_samples: int, inplace: bool
+        reduction_axes, q: float, num_samples: int, inplace: bool
     ) -> TensorStatisticCollectorBase:
         """
         Get backend-specific tensor collector that collects medians of minimal and maximal quantiles.

-        :param reduction_shape: Target reduction shape for the reduction.
+        :param reduction_axes: Target reduction shape for the reduction.
         :param q: Minimal quantile for the tensor collector.
         :param num_samples: Num samples to collect by the tensor collector.
         :param inplace: Should statistic be calculated inplace or out of place.
@@ -78,10 +78,10 @@ def get_add_metatypes():

     @staticmethod
     def get_statistic_collector(
-        reduction_shape, q: float, num_samples: int, inplace: bool
+        reduction_axes, q: float, num_samples: int, inplace: bool
     ) -> TensorStatisticCollectorBase:
         tensor_collector = TensorCollector(OVMinMaxTensorStatistic)
-        quantile_reducer = OVQuantileReducer(reduction_shape, (q, 1 - q), inplace)
+        quantile_reducer = OVQuantileReducer(reduction_axes, (q, 1 - q), inplace)

         for port_id, container_key in enumerate([OVMinMaxTensorStatistic.MIN_STAT, OVMinMaxTensorStatistic.MAX_STAT]):
             aggregator = MedianAggregator(OVNNCFCollectorTensorProcessor, num_samples=num_samples)
@@ -271,7 +271,7 @@ def _add_statistic_point(self, container: StatisticPointsContainer, point: Targe
         :param axis: Channel axis for the statistics calculation.
         """
         stat_collector = self._backend_entity.mean_statistic_collector(
-            reduction_shape=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
+            reduction_axes=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
        )
        container.add_statistic_point(
            StatisticPoint(target_point=point, tensor_collector=stat_collector, algorithm=self._algorithm_key)
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/fast_bias_correction/backend.py
@@ -79,15 +79,15 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> Transform
     @staticmethod
     @abstractmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> TensorStatisticCollectorBase:
         """
         Returns backend-specific mean statistic collector.

-        :param reduction_shape: Channel axes for the statistics aggregation.
+        :param reduction_axes: Channel axes for the statistics aggregation.
         :param inplace: Whether to calculate statistic inplace or not.
         :param num_samples: Maximum number of samples to collect.
         :param window_size: The maximum size of the samples queue.
@@ -64,12 +64,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> ONNXModel

     @staticmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> ONNXMeanStatisticCollector:
-        return ONNXMeanStatisticCollector(reduction_shape, num_samples, window_size)
+        return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)

     @staticmethod
     def get_sub_input_output_names(subgraph: onnx.ModelProto) -> Tuple[str, str]:
@@ -56,12 +56,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelEx

     @staticmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> TensorCollector:
-        return get_mean_stat_collector(num_samples, reduction_shape, window_size, inplace)
+        return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)

     @staticmethod
     def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]:
@@ -68,12 +68,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> PTModelEx

     @staticmethod
     def mean_statistic_collector(
-        reduction_shape: ReductionAxes,
+        reduction_axes: ReductionAxes,
         inplace: bool,
         num_samples: Optional[int] = None,
         window_size: Optional[int] = None,
     ) -> TensorCollector:
-        return get_mean_statisitic_collector(num_samples, reduction_shape, window_size)
+        return get_mean_statisitic_collector(num_samples, reduction_axes, window_size)

     @staticmethod
     def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
12 changes: 6 additions & 6 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -27,7 +27,7 @@
 from nncf.openvino.graph.layer_attributes import OVLayerAttributes
 from nncf.openvino.graph.metatypes import openvino_metatypes as om
 from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
-from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
+from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes
 from nncf.openvino.graph.node_utils import get_weight_channel_axes
 from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
 from nncf.openvino.graph.transformations.commands import OVTargetPoint
@@ -122,7 +122,7 @@ def unify_statistics(statistics: List[OVMinMaxTensorStatistic]) -> OVMinMaxTenso
         return OVMinMaxTensorStatistic(min_values=min_values, max_values=max_values)

     @staticmethod
-    def _get_reduction_shape_and_use_abs_max(
+    def _get_reduction_axes_and_use_abs_max(
         nncf_graph: NNCFGraph, target_point: OVTargetPoint, quantizer_config: QuantizerConfig
     ) -> Tuple[ReductionAxes, bool]:
         use_abs_max = quantizer_config.mode == QuantizationMode.SYMMETRIC
@@ -140,15 +140,15 @@ def _get_reduction_shape_and_use_abs_max(

             # TODO (l-bat): Disable quantizer propogation through layout changing operations
             channel_axis = 1  # OpenVINO activations have channel first layout: [N, C, Z, Y, X]
-            axes = get_channel_agnostic_reduction_shape([channel_axis], shape)
+            axes = get_channel_agnostic_reduction_axes([channel_axis], shape)
             return axes, use_abs_max

         assert isinstance(node.layer_attributes, OVLayerAttributes)
         const_shape = node.layer_attributes.constant_attributes[target_point.port_id]["shape"]

         if quantizer_config.per_channel:
             channel_axes = get_weight_channel_axes(node, target_point.port_id)
-            axes = get_channel_agnostic_reduction_shape(channel_axes, const_shape)
+            axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape)
         else:
             axes = tuple(range(len(const_shape)))
         return axes, use_abs_max
@@ -162,7 +162,7 @@ def get_statistic_collector(
         inplace: bool,
         num_samples: int = None,
     ) -> TensorCollector:
-        reduction_shape, use_abs_max = OVMinMaxAlgoBackend._get_reduction_shape_and_use_abs_max(
+        reduction_axes, use_abs_max = OVMinMaxAlgoBackend._get_reduction_axes_and_use_abs_max(
             nncf_graph, target_point, quantizer_config
         )

@@ -181,7 +181,7 @@ def get_statistic_collector(
                 f"Aggregator type: {params.aggregator_type} is not supported for OpenVino PTQ backend yet."
             )

-        kwargs = {"reduction_axes": reduction_shape, "inplace": inplace}
+        kwargs = {"reduction_axes": reduction_axes, "inplace": inplace}
         if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
             if container_key == OVMinMaxTensorStatistic.MIN_STAT:
                 quantile = params.quantile_outlier_prob
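To make the two branches above concrete (the weight shape and channel axis are illustrative, not from the commit): for a weight constant of shape [64, 3, 3, 3] with output channels on axis 0, per-channel quantization reduces over the remaining axes, while per-tensor quantization reduces over all of them:

    from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes

    const_shape = [64, 3, 3, 3]
    per_channel_axes = get_channel_agnostic_reduction_axes([0], const_shape)  # (1, 2, 3)
    per_tensor_axes = tuple(range(len(const_shape)))                          # (0, 1, 2, 3)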
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -326,11 +326,11 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
         :return: Calculated reduction axes.
         """
         shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
-        reduction_shape = tuple([0])
+        reduction_axes = tuple([0])
         if len(shape) > 1:
             channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
-            reduction_shape = self._backend_entity.get_channel_agnostic_reduction_shape(channel_axis, shape)
-        return reduction_shape
+            reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
+        return reduction_axes

     def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id: int) -> TTensor:
         """
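A brief worked example of the logic above (the shape and channel axis are assumed for illustration): a 1-D input keeps reduction axes (0,), while an input of shape [1, 8, 128] whose activation channel axis is 2 reduces over the remaining axes (0, 1):

    shape = [1, 8, 128]
    channel_axis = 2
    reduction_axes = (0,)
    if len(shape) > 1:
        # Mirrors get_channel_agnostic_reduction_axes for a single channel axis.
        reduction_axes = tuple(ax for ax in range(len(shape)) if ax != channel_axis)
    assert reduction_axes == (0, 1)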
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/smooth_quant/backend.py
@@ -70,7 +70,7 @@ def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]

     @staticmethod
     @abstractmethod
-    def get_channel_agnostic_reduction_shape(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
+    def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
         """
         Returns filtered reduction shape without axes that corresponds channels.
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
@@ -22,7 +22,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
-from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
+from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes
 from nncf.openvino.graph.node_utils import get_weight_value
 from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
 from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
@@ -61,8 +61,8 @@ def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]
         return {"activation": activation_ports[0], "weight": weight_ports[0]}

     @staticmethod
-    def get_channel_agnostic_reduction_shape(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
-        return get_channel_agnostic_reduction_shape([channel_axis], shape)
+    def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
+        return get_channel_agnostic_reduction_axes([channel_axis], shape)

     @staticmethod
     def get_abs_max_channel_collector(
13 changes: 7 additions & 6 deletions nncf/torch/quantization/init_range.py
@@ -31,6 +31,7 @@
 from nncf.common.tensor_statistics.collectors import ReductionAxes
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
 from nncf.config.schemata.algo.quantization import RANGE_INIT_TYPES_VS_DESCRIPTIONS
+from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.initialization import DataLoaderBaseRunner
 from nncf.torch.nncf_network import NNCFNetwork
@@ -104,25 +105,25 @@ def __init__(
         self._input_shape = input_shape
         self._channel_idx = channel_idx

-    def get_reduction_axes(self, per_sample_stats) -> ReductionAxes:
+    def get_reduction_axes(self, per_sample_stats: bool) -> ReductionAxes:
         """
         Calculates the reduction axes of the tensor.

         :param per_sample_stats: Boolean flag that indicated whether statistics are collected per-sample or per-batch.
         :return: Shape to reduce to.
         """
         ndims = len(self._input_shape)
-        reduction_shape = list(range(ndims))  # type: List[int]
+        reduction_axes = list(range(ndims))  # type: List[int]
         if self._per_channel:
             val = (ndims + self._channel_idx) % ndims
-            reduction_shape.remove(val)
+            reduction_axes.remove(val)
             if not val and self.use_per_sample_stats(per_sample_stats):
                 raise RuntimeError("Batch dimension should be equal to zero")
         if self.use_per_sample_stats(per_sample_stats):
-            reduction_shape = reduction_shape[1:]  # Assumes batch is the first dimension
-        return tuple(reduction_shape)
+            reduction_axes = reduction_axes[1:]  # Assumes batch is the first dimension
+        return tuple(reduction_axes)

-    def get_aggregation_axes(self, per_sample_stats) -> Tuple[int, ...]:
+    def get_aggregation_axes(self, per_sample_stats: bool) -> AggregationAxes:
         """
         Calculates the aggregation axes of the tensor.
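A short worked example of get_reduction_axes (the input shape and channel index are illustrative): for a 4-D input [N, C, H, W] collected per-channel with channel_idx = 1, the reduction axes are (0, 2, 3); with per-sample statistics the batch axis is dropped as well:

    input_shape = (8, 16, 32, 32)  # [N, C, H, W]
    channel_idx = 1
    ndims = len(input_shape)
    reduction_axes = list(range(ndims))
    reduction_axes.remove((ndims + channel_idx) % ndims)
    assert tuple(reduction_axes) == (0, 2, 3)
    # Per-sample statistics assume batch is the first dimension and drop it.
    assert tuple(reduction_axes[1:]) == (2, 3)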