Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 24, 2024
1 parent be4d6f6 commit 8ea90a6
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 43 deletions.
12 changes: 3 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,15 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def is_node_with_weights(node: NNCFNode) -> bool:
    """
    Check whether the node carries constant (weight) inputs.

    :param node: Node to inspect.
    :return: True if the node has layer attributes and those attributes report
        non-empty constant attributes, False otherwise.
    """
    # The bare `a and b` expression yields None or the attributes object itself;
    # coerce it so the result matches the declared `bool` return type.
    return bool(node.layer_attributes and node.layer_attributes.constant_attributes)

def _get_weight_port_id(node: NNCFNode) -> int:
weight_ports = node.layer_attributes.get_const_port_ids()
if len(weight_ports) != 1:
raise RuntimeError(f"Too many weight ports for {node.node_name} node")
return weight_ports[0]

@staticmethod
def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
# Returns the single input port of `node` that is not listed among its constant (weight) ports.
# Raises RuntimeError when one of the port-count checks below fails.
# Ports reported by the layer attributes as constant are treated as weight ports.
weight_ports = node.layer_attributes.get_const_port_ids()
# Every input edge whose port id is not a weight port counts as an activation port.
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in weight_ports
]

# NOTE(review): two port-count checks appear back to back. This text comes from a diff view,
# so presumably the combined weight/activation check is the pre-change version and the
# activation-only check replaces it - confirm against the repository.
# Both checks use `!= 1`, so they also fire when there are zero ports, not only "too many".
if len(weight_ports) != 1 or len(activation_ports) != 1:
raise RuntimeError(f"Too many weight or activation ports for {node.node_name} node")
if len(activation_ports) != 1:
raise RuntimeError(f"Too many activation ports for {node.node_name} node")
return activation_ports[0]

@staticmethod
Expand Down Expand Up @@ -155,7 +149,7 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
# Returns True when the node feeding this node's weight port has more than one successor.
# NOTE(review): `weight_port_id` is assigned twice. This text comes from a diff view, so
# presumably the `_get_weight_port_id` call is the pre-change line and the
# `get_weight_tensor_port_id` call replaces it - confirm against the repository.
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node)
# The port id is used as a positional index into the list of input edges.
# NOTE(review): assumes the edges are ordered by input port id - confirm.
weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
return len(nncf_graph.get_next_nodes(weight_node)) > 1

Expand Down
24 changes: 20 additions & 4 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Callable, List, Tuple

import numpy as np
import torch

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
Expand All @@ -28,14 +29,23 @@
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.command_creation import multiply_insertion_command
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor


class SQMultiply(torch.nn.Module):
    """
    Module that multiplies its input by a fixed scale value.

    The scale is kept in the `_scale_value` attribute exactly as passed to the
    constructor and is applied with `torch.mul` on every forward call.
    """

    def __init__(self, scale_value: torch.Tensor) -> None:
        """
        :param scale_value: Scale to multiply the input by.
        """
        super().__init__()
        self._scale_value = scale_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Input tensor.
        :return: Result of `torch.mul(x, scale_value)`.
        """
        return torch.mul(x, self._scale_value)


class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -119,18 +129,24 @@ def scale_insertion_command(
scale_node_name: str,
) -> OVMultiplyInsertionCommand:
input_port_id = 0
return multiply_insertion_command(nodes, scale_value, scale_node_name, input_port_id)
target_points = []
for node in nodes:
target_points.append(
PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node.node_name, input_port_id=input_port_id)
)

return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
# Returns the activation channel axis: -1 for PTModuleLinearMetatype nodes, 1 for every other node.
# `port_id` is accepted but not read by this implementation.
if node.metatype == om.PTModuleLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul will be supported
return 1

@staticmethod
def get_weight_channel_axis(node: NNCFNode) -> int:
# Returns the weight channel axis; currently always 1, `node` is not inspected.
# TODO: Add activation axis calculation when MatMul will be supported
# NOTE(review): the TODO mentions the activation axis although this method returns the
# weight axis - presumably copied from get_activation_channel_axis; confirm the wording.
return 1

@staticmethod
Expand Down
25 changes: 0 additions & 25 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
from torch import Tensor

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand

Expand All @@ -45,24 +41,3 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
"""
target_point = PTTargetPoint(TargetType.LAYER, node.node_name)
return PTWeightUpdateCommand(target_point, weight_value)


class SQMultiply(torch.nn.Module):
    """
    Module that multiplies its input by a fixed scale value.

    The scale is kept in the `_scale_value` attribute exactly as passed to the
    constructor and is applied with `torch.mul` on every forward call.
    """

    def __init__(self, scale_value: torch.Tensor) -> None:
        """
        :param scale_value: Scale to multiply the input by.
        """
        super().__init__()
        self._scale_value = scale_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Input tensor.
        :return: Result of `torch.mul(x, scale_value)`.
        """
        return torch.mul(x, self._scale_value)


def multiply_insertion_command(
    target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTSharedFnInsertionCommand:
    """
    Build a command that inserts one shared multiply module in front of several nodes.

    :param target_nodes: Nodes whose input should be multiplied by the scale.
    :param scale_value: Scale passed to the `SQMultiply` module.
    :param scale_node_name: Name passed to the insertion command for the inserted module.
    :param input_port_id: Input port of every target node to hook.
    :return: Shared-function insertion command with one pre-hook target point per node.
    """
    # The return annotation used to say PTInsertionCommand, but the function
    # has always constructed a PTSharedFnInsertionCommand.
    target_points = [
        PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
        for target_node in target_nodes
    ]
    return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
2 changes: 1 addition & 1 deletion tests/torch/ptq/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.quantization.algorithms.smooth_quant.torch_backend import PTSmoothQuantAlgoBackend
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
from nncf.torch.graph.transformations.command_creation import SQMultiply
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.model_creation import wrap_model
from nncf.torch.nncf_network import ExtraCompressionModuleType
Expand Down
5 changes: 1 addition & 4 deletions tests/torch/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,8 @@ def setup(self):
point_for_relu_inputs,
]

@pytest.mark.parametrize(
"insert_method_name,check_tmp_ops", [("insert_at_point", False), ("temporary_insert_at_point", True)]
)
@pytest.mark.parametrize("target_point", available_points)
def test_single_insertions(self, setup, target_point: PTTargetPoint, insert_method_name: str, check_tmp_ops: bool):
def test_single_insertions(self, setup, target_point: PTTargetPoint):
insertion_point = PTInsertionPoint(
target_point.target_type,
OperationAddress.from_str(target_point.target_node_name),
Expand Down

0 comments on commit 8ea90a6

Please sign in to comment.