Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Fix weights statistics were empty

Fix statistic collection and external quantizer hook
  • Loading branch information
daniil-lyakhov committed Nov 20, 2023
1 parent 13e794b commit 22a58ca
Show file tree
Hide file tree
Showing 12 changed files with 482 additions and 77 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TransformationPriority(IntEnum):
FP32_TENSOR_STATISTICS_OBSERVATION = 1
PRUNING_PRIORITY = 2
SPARSIFICATION_PRIORITY = 3
OP_INSERTION_PRIORITY = 4
QUANTIZATION_PRIORITY = 11


Expand Down
71 changes: 34 additions & 37 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType

# from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(

@property
def available_backends(self) -> List[BackendType]:
return [BackendType.OPENVINO]
return [BackendType.OPENVINO, BackendType.TORCH]

def _set_backend_entity(self, model: TModel) -> None:
"""
Expand All @@ -90,6 +91,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend

self._backend_entity = OVSmoothQuantAlgoBackend()
elif model_backend == BackendType.TORCH:
from nncf.quantization.algorithms.smooth_quant.torch_backend import PTSmoothQuantAlgoBackend

self._backend_entity = PTSmoothQuantAlgoBackend()
else:
raise RuntimeError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Expand Down Expand Up @@ -123,11 +128,11 @@ def apply(
if any(val is None for val in activations_value):
empty_statistic = True
break
activations_value = self._backend_entity.clip_statistics(activations_value)
assert len(activations_value) == 1
activations_value = self._backend_entity.clip_statistics(activations_value[0])

weight_port = self._backend_entity.get_weight_tensor_port_id(node_to_smooth)
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model, weight_port)
weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, weight_port)
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, graph)
weight_statistics = self._backend_entity.clip_statistics(weight_statistics)

alpha = alpha_map[node_to_smooth.metatype]
Expand All @@ -153,13 +158,12 @@ def apply(
continue

for node_to_smooth in nodes:
weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth)
weight_port = self._backend_entity.get_weight_tensor_port_id(node_to_smooth)
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model, weight_port)
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value, graph)
### TODO: DO it as NNCFTensor op
scaled_weight = weight_value * weights_scale
weight_update_command = self._backend_entity.weight_update_command(
node_to_smooth, scaled_weight, weight_port
)
###
weight_update_command = self._backend_entity.weight_update_command(node_to_smooth, scaled_weight)
transformation_layout.register(weight_update_command)

activations_shape = graph.get_output_edges(source_node)[source_output_port_id].tensor_shape
Expand Down Expand Up @@ -208,16 +212,11 @@ def _get_statistics_for_node(
:return: List of the TTensor instances.
"""

def filter_func(point: StatisticPoint) -> bool:
return (
self._algorithm_key in point.algorithm_to_tensor_collectors
and point.target_point.type == TargetType.PRE_LAYER_OPERATION
and point.target_point.port_id == act_port
)

statistics_for_node = []
for tensor_collector in statistic_points.get_algo_statistics_for_node(
node_name, filter_func, self._algorithm_key
node_name,
self._backend_entity.get_filter_fn_for_statistics(act_port),
self._algorithm_key,
):
statistics_for_node.append(tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY])
return statistics_for_node
Expand All @@ -233,7 +232,6 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
for node_data in nodes_to_smooth_data:
node_to_smooth = node_data["node_to_smooth"]
target_point = self._backend_entity.target_point(
TargetType.PRE_LAYER_OPERATION,
target_node_name=node_to_smooth.node_name,
port_id=node_data["input_act_port"],
)
Expand Down Expand Up @@ -267,10 +265,9 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
if not self._backend_entity.is_node_with_weights(node_with_weight):
continue

ports_map = self._backend_entity.get_input_ports_map(node_with_weight, nncf_graph)
activation_port_id = self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph)
input_edges = nncf_graph.get_input_edges(node_with_weight)
weight_node = input_edges[ports_map["weight"]].from_node
activation_node = input_edges[ports_map["activation"]].from_node
activation_node = input_edges[activation_port_id].from_node

# Skipping agnostic layers as inputs to propagate quantizer
# Only for Convolution layers
Expand All @@ -281,13 +278,13 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
continue

# Skipping shared weights
if len(nncf_graph.get_next_nodes(weight_node)) > 1:
if self._backend_entity.is_node_with_shared_weight(node_with_weight, nncf_graph):
continue

nodes_to_smooth_data.append(
{
"node_to_smooth": node_with_weight,
"input_act_port": ports_map["activation"],
"input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph),
}
)
return nodes_to_smooth_data
Expand All @@ -303,11 +300,10 @@ def _calculate_activation_scale(
:param nodes: List of consumers for Smooth node.
:return: Calculated per-channel activation scale.
"""
activation_ports_map = {
node: self._backend_entity.get_input_ports_map(node, nncf_graph)["activation"] for node in nodes
}
activation_ports_map = {node: self._backend_entity.get_activations_port_id(node, nncf_graph) for node in nodes}
channel_axes = [
self._backend_entity.get_activation_channel_axis(node, port) for node, port in activation_ports_map.items()
self._backend_entity.get_activation_channel_axis(node, port, activations_shape)
for node, port in activation_ports_map.items()
]
channel_axis = channel_axes[0]

Expand All @@ -317,18 +313,19 @@ def _calculate_activation_scale(
activations_size = len(activations_shape)
return self._backend_entity.calculate_activation_scale(scale_value, activations_size, channel_axis)

def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode) -> TTensor:
def _calculate_weight_scale(
self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor, graph: NNCFGraph
) -> TTensor:
"""
Calculates scale for weight tensor.
:param scale_value: Base scale value.
:param node: Consumer for Smooth node.
:return: Calculated scale for weights.
"""
port_id = self._backend_entity.get_weight_tensor_port_id(node)
weights_size = len(node.layer_attributes.constant_attributes[port_id]["shape"])
weights_size = len(weights_value.shape)
if weights_size > 1:
channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id)
channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis)
return scale_value

Expand All @@ -344,11 +341,11 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
reduction_axes = tuple([])
if len(shape) > 1:
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port, shape)
reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
return reduction_axes

def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id: int) -> TTensor:
def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NNCFGraph) -> TTensor:
"""
Returns processed weight statistics for node.
Expand All @@ -359,7 +356,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id:
"""
channel_axis = 0
if len(weights.shape) > 1:
channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id)
channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
reduction_shape = [i for i, _ in enumerate(weights.shape)]
reduction_shape.pop(channel_axis)
return self._backend_entity.process_weight_statistics(weights, tuple(reduction_shape))
Expand Down
22 changes: 15 additions & 7 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@

from abc import ABC
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector

Expand Down Expand Up @@ -55,11 +54,10 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

@staticmethod
@abstractmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
def target_point(TargetType, target_node_name: str, port_id: int) -> TargetPoint:
"""
Returns backend-specific target point.
:param target_type: Type of the location that should be modified.
:param target_node_name: Name of the located node.
:param port_id: Port ID of the tensor for the statistics distribution.
:return: Backend-specific TargetPoint.
Expand All @@ -77,7 +75,7 @@ def is_node_with_weights(node: NNCFNode) -> bool:

@staticmethod
@abstractmethod
def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]:
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
"""
Returns map with activation & weighted ports.
Expand Down Expand Up @@ -224,7 +222,7 @@ def scale_insertion_command(

@staticmethod
@abstractmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
"""
Returns axis number of the activation tensor which correspond to it channel.
Expand All @@ -235,7 +233,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:

@staticmethod
@abstractmethod
def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int:
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
"""
Returns axis number of the weight tensor which correspond to it channel.
Expand All @@ -254,3 +252,13 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
:param transpose: Transpose position.
:return: Channel axis.
"""

@staticmethod
@abstractmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
pass

@staticmethod
@abstractmethod
def get_filter_fn_for_statistics(activation_port_id: int):
pass
47 changes: 34 additions & 13 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.groups import QUANTIZE_AGNOSTIC_OPERATIONS
Expand Down Expand Up @@ -48,15 +49,15 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
return QUANTIZE_AGNOSTIC_OPERATIONS

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(target_type, target_node_name, port_id)
def target_point(target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id)

@staticmethod
def is_node_with_weights(node: NNCFNode) -> bool:
return node.layer_attributes and node.layer_attributes.constant_attributes

@staticmethod
def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]:
def _get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]:
weight_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in weight_ports
Expand All @@ -67,6 +68,10 @@ def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]

return {"activation": activation_ports[0], "weight": weight_ports[0]}

@staticmethod
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
return OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["activation"]

@staticmethod
def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
return get_channel_agnostic_reduction_axes([channel_axis], shape)
Expand All @@ -86,11 +91,16 @@ def process_weight_statistics(weights: np.ndarray, reduction_shape: Tuple[int])
return np.max(np.abs(weights), axis=reduction_shape)

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray:
def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> np.ndarray:
port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
return get_weight_value(node_with_weight, model, port_id)

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node)

@staticmethod
def _get_weight_tensor_port_id(node: NNCFNode) -> int:
const_ids = node.layer_attributes.get_const_port_ids()
if len(const_ids) != 1:
raise RuntimeError(f"Found more than 1 port for {node.node_name} node")
Expand Down Expand Up @@ -134,9 +144,8 @@ def calculate_weight_scale(scale_value: np.ndarray, weights_size: int, channel_a
return weight_scale

@staticmethod
def weight_update_command(
node_with_weight: NNCFNode, weight_value: np.ndarray, weight_port_id: int
) -> OVWeightUpdateCommand:
def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand:
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id)

@staticmethod
Expand All @@ -146,7 +155,7 @@ def scale_insertion_command(
return OVCommandCreator.multiply_insertion_command(source_node, nodes, port_id, scale_value, scale_node_name)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
channel_axis = 1

if port_id > 1:
Expand All @@ -164,11 +173,9 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
return channel_axis

@staticmethod
def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int:
channel_axis = 1

if port_id > 1:
raise RuntimeError(f"{node.metatype.name} can not take more than 2 input tensors.")
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
port_id = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["weight"]
channel_axis = 1 if node.metatype.const_channel_axis is None else node.metatype.const_channel_axis[0]

if port_id not in node.layer_attributes.constant_attributes:
raise RuntimeError(f"{node.node_name} should contain {port_id} in the attributes map.")
Expand All @@ -183,3 +190,17 @@ def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int:
@staticmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
ports_map = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)
weight_node = nncf_graph.get_input_edges(node)[ports_map["weight"]].from_node
# Skipping shared weights
return len(nncf_graph.get_next_nodes(weight_node)) > 1

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int):
def filter_func(point: StatisticPoint) -> bool:
return point.target_point.port_id == activation_port_id

return filter_func
Loading

0 comments on commit 22a58ca

Please sign in to comment.