Commit
SQ test is completed
daniil-lyakhov committed Nov 27, 2023
1 parent d220ca5 commit 6cffa74
Showing 11 changed files with 254 additions and 137 deletions.
19 changes: 8 additions & 11 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -132,7 +132,7 @@ def apply(
activations_value = self._backend_entity.clip_statistics(activations_value[0])

weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, graph)
weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value)
weight_statistics = self._backend_entity.clip_statistics(weight_statistics)

alpha = alpha_map[node_to_smooth.metatype]
@@ -159,7 +159,7 @@ def apply(

for node_to_smooth in nodes:
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value, graph)
weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value)
### TODO: Do it as an NNCFTensor op
scaled_weight = weight_value * weights_scale
###
@@ -302,8 +302,7 @@ def _calculate_activation_scale(
"""
activation_ports_map = {node: self._backend_entity.get_activations_port_id(node, nncf_graph) for node in nodes}
channel_axes = [
self._backend_entity.get_activation_channel_axis(node, port, activations_shape)
for node, port in activation_ports_map.items()
self._backend_entity.get_activation_channel_axis(node, port) for node, port in activation_ports_map.items()
]
channel_axis = channel_axes[0]

@@ -313,9 +312,7 @@
activations_size = len(activations_shape)
return self._backend_entity.calculate_activation_scale(scale_value, activations_size, channel_axis)
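
For intuition, a minimal numpy sketch of what a backend's calculate_activation_scale typically does: reshape the per-channel scale vector so it broadcasts along the channel axis of an activation of the given rank. The function below is an illustrative stand-in, not the actual backend implementation.

import numpy as np

def activation_scale_sketch(scale_value: np.ndarray, activations_size: int, channel_axis: int) -> np.ndarray:
    # Place the scale on `channel_axis`; every other dimension is 1, so the
    # result broadcasts against an activation tensor of rank `activations_size`.
    shape = [1] * activations_size
    shape[channel_axis] = scale_value.size
    return scale_value.reshape(shape)

print(activation_scale_sketch(np.ones(4), 4, 1).shape)  # (1, 4, 1, 1) for NCHW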

def _calculate_weight_scale(
self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor, graph: NNCFGraph
) -> TTensor:
def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor) -> TTensor:
"""
Calculates scale for weight tensor.
@@ -325,7 +322,7 @@
"""
weights_size = len(weights_value.shape)
if weights_size > 1:
channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
channel_axis = self._backend_entity.get_weight_channel_axis(node)
return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis)
return scale_value

@@ -341,11 +338,11 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
reduction_axes = tuple([])
if len(shape) > 1:
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port, shape)
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
return reduction_axes

def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NNCFGraph) -> TTensor:
def _process_weight_statistics(self, node: NNCFNode, weights: TTensor) -> TTensor:
"""
Returns processed weight statistics for node.
@@ -356,7 +353,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NN
"""
channel_axis = 0
if len(weights.shape) > 1:
channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
channel_axis = self._backend_entity.get_weight_channel_axis(node)
reduction_shape = [i for i, _ in enumerate(weights.shape)]
reduction_shape.pop(channel_axis)
return self._backend_entity.process_weight_statistics(weights, tuple(reduction_shape))
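
The reduction-shape construction above is easy to verify in isolation. A self-contained numpy sketch mirroring the two lines above; the abs-max reduction stands in for process_weight_statistics and is an assumption for illustration:

import numpy as np

weights = np.random.rand(8, 16, 3, 3)  # e.g. an OIHW convolution weight
channel_axis = 0
reduction_shape = [i for i, _ in enumerate(weights.shape)]
reduction_shape.pop(channel_axis)  # drop the channel axis, reduce the rest

stats = np.abs(weights).max(axis=tuple(reduction_shape))
print(stats.shape)  # (8,): one statistic per output channel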
15 changes: 2 additions & 13 deletions nncf/quantization/algorithms/smooth_quant/backend.py
@@ -222,7 +222,7 @@ def scale_insertion_command(

@staticmethod
@abstractmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns the axis number of the activation tensor that corresponds to its channel.
@@ -233,7 +233,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape:

@staticmethod
@abstractmethod
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
def get_weight_channel_axis(node: NNCFNode) -> int:
"""
Returns the axis number of the weight tensor that corresponds to its channel.
@@ -242,17 +242,6 @@ def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
:return: Channel axis number.
"""

@staticmethod
@abstractmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
"""
Returns port-based channel axis.
:param port_id: Specified input port id.
:param transpose: Transpose position.
:return: Channel axis.
"""

@staticmethod
@abstractmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
38 changes: 18 additions & 20 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import numpy as np
import openvino.runtime as ov
@@ -56,21 +56,22 @@ def target_point(target_node_name: str, port_id: int) -> OVTargetPoint:
def is_node_with_weights(node: NNCFNode) -> bool:
return node.layer_attributes and node.layer_attributes.constant_attributes

@staticmethod
def _get_weight_port_id(node: NNCFNode) -> int:
weight_ports = node.layer_attributes.get_const_port_ids()
if len(weight_ports) != 1:
raise RuntimeError(f"Too many weight ports for {node.node_name} node")
return weight_ports[0]

@staticmethod
def _get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]:
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
weight_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in weight_ports
]

if len(weight_ports) != 1 or len(activation_ports) != 1:
raise RuntimeError(f"Too many weight or activation ports for {node.node_name} node")

return {"activation": activation_ports[0], "weight": weight_ports[0]}

@staticmethod
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
return OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["activation"]
return activation_ports[0]
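
A toy run of the port-splitting logic above, using a hypothetical Edge stand-in for NNCFGraph input edges; only the selection logic mirrors the method:

from dataclasses import dataclass
from typing import List

@dataclass
class Edge:  # hypothetical stand-in for an NNCFGraph input edge
    input_port_id: int

def pick_activation_port(input_edges: List[Edge], weight_ports: List[int]) -> int:
    activation_ports = [e.input_port_id for e in input_edges if e.input_port_id not in weight_ports]
    if len(weight_ports) != 1 or len(activation_ports) != 1:
        raise RuntimeError("Too many weight or activation ports")
    return activation_ports[0]

# A MatMul with its constant on port 1 reads the activation from port 0.
print(pick_activation_port([Edge(0), Edge(1)], weight_ports=[1]))  # 0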

@staticmethod
def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
Expand Down Expand Up @@ -161,7 +162,7 @@ def scale_insertion_command(
)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
channel_axis = 1

if port_id > 1:
@@ -174,33 +175,30 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape:
and "transpose" in node.layer_attributes.input_attributes
):
transpose = node.layer_attributes.input_attributes["transpose"]
channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)

return channel_axis

@staticmethod
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
port_id = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["weight"]
channel_axis = 1 if node.metatype.const_channel_axis is None else node.metatype.const_channel_axis[0]

if port_id not in node.layer_attributes.constant_attributes:
raise RuntimeError(f"{node.node_name} should contain {port_id} in the attributes map.")
def get_weight_channel_axis(node: NNCFNode) -> int:
port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
channel_axis = 1

if node.metatype == OVMatMulMetatype:
if "transpose" in node.layer_attributes.constant_attributes[port_id]:
transpose = node.layer_attributes.constant_attributes[port_id]["transpose"]
channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)

return channel_axis

@staticmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
def _calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id
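
The one-liner above encodes the MatMul convention that each input's channel axis is its reduced (inner) dimension. Reproduced standalone, its outputs match the reference values in the updated parametrized tests further down:

def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
    return -2 + port_id if transpose else -1 - port_id

# For A @ B, the inner dimension of each input:
print(calculate_port_based_channel_axis(0, False))  # -1: last axis of A
print(calculate_port_based_channel_axis(1, False))  # -2: second-to-last axis of B
print(calculate_port_based_channel_axis(0, True))   # -2: A is transposed
print(calculate_port_based_channel_axis(1, True))   # -1: B is transposed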

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
ports_map = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)
weight_node = nncf_graph.get_input_edges(node)[ports_map["weight"]].from_node
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
# Skipping shared weights
return len(nncf_graph.get_next_nodes(weight_node)) > 1

14 changes: 7 additions & 7 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -160,16 +160,16 @@ def scale_insertion_command(
return multiply_insertion_command(nodes, scale_value, scale_node_name, input_port_id)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
return len(activations_shape) - 1

@staticmethod
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
if node.metatype == om.PTModuleLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul will be supported
return 1

@staticmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id
def get_weight_channel_axis(node: NNCFNode) -> int:
# TODO: Add activation axis calculation when MatMul will be supported
return 1
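
Returning axis 1 here is consistent with torch.nn.Linear, whose weight is stored as (out_features, in_features), so the input-channel dimension that smooth quant scales sits on axis 1:

import torch

linear = torch.nn.Linear(in_features=16, out_features=8)
print(tuple(linear.weight.shape))  # (8, 16): axis 1 is in_features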

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
2 changes: 0 additions & 2 deletions nncf/torch/dynamic_graph/layer_attributes_handlers.py
@@ -31,13 +31,11 @@
from nncf.common.graph.layer_attributes import PermuteLayerAttributes
from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
from nncf.common.graph.layer_attributes import TransposeLayerAttributes
from nncf.common.graph.layer_attributes import MatMulLayerAttributes
from nncf.common.graph.utils import get_concat_axis
from nncf.common.graph.utils import get_split_axis
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
from nncf.torch.graph.operator_metatypes import PTPadMetatype
from nncf.torch.graph.operator_metatypes import PTMatMulMetatype
from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
from nncf.torch.graph.operator_metatypes import PTSplitMetatype
from nncf.torch.graph.operator_metatypes import PTSqueezeMetatype
17 changes: 9 additions & 8 deletions nncf/torch/graph/transformations/command_creation.py
@@ -48,6 +48,15 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
return PTWeightUpdateCommand(target_point, weight_value)


class SQMultiply(torch.nn.Module):
def __init__(self, scale_value):
super().__init__()
self._scale_value = scale_value

def forward(self, x):
return torch.mul(x, self._scale_value)


def multiply_insertion_command(
target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTInsertionCommand:
@@ -56,12 +65,4 @@ def multiply_insertion_command(
target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))

class SQMultiply(torch.nn.Module):
def __init__(self, scale_value):
super().__init__()
self._scale_value = scale_value

def forward(self, x):
return torch.mul(x, self._scale_value)

return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
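
Moving SQMultiply to module scope (it was previously defined inside multiply_insertion_command, as the removed lines above show) makes the class importable and lets a single instance be shared by every node in a smooth-quant group. A standalone usage sketch:

import torch

class SQMultiply(torch.nn.Module):
    def __init__(self, scale_value):
        super().__init__()
        self._scale_value = scale_value

    def forward(self, x):
        return torch.mul(x, self._scale_value)

sq = SQMultiply(torch.tensor([0.5, 2.0]))  # one factor per channel
print(sq(torch.ones(3, 2)))  # columns scaled by 0.5 and 2.0 respectively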
2 changes: 1 addition & 1 deletion nncf/torch/tensor.py
@@ -37,4 +37,4 @@ def device(self) -> torch.device:
return self._tensor.device

def is_empty(self) -> bool:
return self.tensor.size == 0
return self.tensor.numel() == 0
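
This one-line change fixes a real bug: torch.Tensor.size is a method, unlike numpy's integer .size attribute, so the old expression compared a bound method to 0 and was always False. numel() returns the element count:

import torch

t = torch.empty(0)
print(t.size == 0)     # False: compares a bound method object to 0
print(t.numel() == 0)  # True: the tensor is genuinely empty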
72 changes: 59 additions & 13 deletions tests/openvino/native/test_smooth_quant.py
@@ -18,18 +18,61 @@
import torch
from openvino.tools.mo import convert_model

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

OV_LINEAR_MODEL_MM_OP_MAP = {
"MatMul1": "/MatMul",
"MatMul2": "/MatMul_1",
"MatMul3": "/MatMul_2",
"MatMul4": "/MatMul_4",
"MatMul5": "32",
"MatMul6": "37",
"MatMul7": "54",
"MatMul8": "68",
"Linear1": "/linear_2/MatMul",
"Linear2": "/linear_1/MatMul",
"Linear3": "/linear_3/MatMul",
"Linear4": "/linear_4/MatMul",
}


OV_LINEAR_MODEL_SQ_OP_MAP = {
"MatMul1": "/Reshape_0_0/nncf_smooth_quant",
"MatMul2": "/Reshape_0_0/nncf_smooth_quant",
"MatMul3": "/Reshape_1_0_0/nncf_smooth_quant",
"MatMul4": "/Reshape_1_0_1/nncf_smooth_quant",
"MatMul5": "/Reshape_2_0_0/nncf_smooth_quant",
"MatMul6": "/ReduceMax_0_0/nncf_smooth_quant",
"MatMul7": "/Reshape_3_0_0/nncf_smooth_quant",
"MatMul8": "/Reshape_4_0_0/nncf_smooth_quant",
"Linear1": "/Split_1_0/nncf_smooth_quant",
"Linear2": "/Split_0_0/nncf_smooth_quant",
"Linear3": "/Add_0_0/nncf_smooth_quant",
"Linear4": "/Add_0_0/nncf_smooth_quant",
}


class TestOVSQAlgorithm(TemplateTestSQAlgorithm):
@staticmethod
def fn_to_type(tensor) -> np.ndarray:
return np.array(tensor)

@pytest.fixture(params=[False, True], ids=["out_of_place", "inplace"])
def inplace_statistics(self, request) -> bool:
return request.param

def get_node_name_map(self) -> Dict[str, str]:
return OV_LINEAR_MODEL_MM_OP_MAP

@staticmethod
def get_target_node_name(command: TransformationCommand):
return command.target_point.target_node_name

@staticmethod
def get_transform_fn() -> Callable:
def transform_fn(data_item):
@@ -53,10 +96,14 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model:
@staticmethod
def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray]) -> None:
ops_list = {op.get_friendly_name(): op for op in model.get_ops()}
for ref_name, ref_value in reference_values.items():
node = ops_list[ref_name]
const_node = node.input(1).get_source_output().get_node()

for ref_names, ref_value in reference_values.items():
const_nodes = []
for ref_name in ref_names:
node = ops_list[OV_LINEAR_MODEL_SQ_OP_MAP[ref_name]]
const_nodes.append(node.input(1).get_source_output().get_node())
# Check unified group actually shares one constant
assert all(node is const_nodes[0] for node in const_nodes[1:])
const_node = const_nodes[0]
assert const_node.get_type_name() == "Constant"

value = const_node.data
@@ -79,18 +126,17 @@ def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port
return super().test_get_activation_channel_axis(node_metatype, layer_attributes, port_id, reference_value)

@pytest.mark.parametrize(
"node_metatype, layer_attributes, port_id, reference_value",
"node_metatype, layer_attributes, reference_value",
(
(OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 1, -2),
(OVMatMulMetatype, OVLayerAttributes({1: {"transpose": True}}), 1, -1),
(OVMatMulMetatype, OVLayerAttributes({0: {"transpose": False}}), 0, -1),
(OVMatMulMetatype, OVLayerAttributes({0: {"transpose": True}}), 0, -2),
(OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 2, RuntimeError),
(OVConvolutionMetatype, OVLayerAttributes({1: {}}), 1, 1),
(OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), -2),
(OVMatMulMetatype, OVLayerAttributes({1: {"transpose": True}}), -1),
(OVMatMulMetatype, OVLayerAttributes({0: {"transpose": False}}), -1),
(OVMatMulMetatype, OVLayerAttributes({0: {"transpose": True}}), -2),
(OVConvolutionMetatype, OVLayerAttributes({1: {}}), 1),
),
)
def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value):
return super().test_get_weight_channel_axis(node_metatype, layer_attributes, port_id, reference_value)
def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value):
return super().test_get_weight_channel_axis(node_metatype, layer_attributes, reference_value)

@staticmethod
def get_matmul_metatype():